Implementing Llama-2
In this section, we implement a Llama-2 model and verify it against the Hugging Face reference implementation.
Model Implementation¶
In [1]:
import functools
import torch
import transformers
Positional Encoding¶
Llama uses RoPE positional encodings. Note that this differs from the RoPE implemented in RoFormer: RoFormer groups two adjacent elements into a pair and rotates them, whereas Llama's RoPE splits the vector into a first and a second half and rotates the pairs formed by elements at the same offset in each half (see the sketch after the code cell below).
In [2]:
class RoPE(torch.nn.Module):
    def __init__(self, d_model: int, theta: int | float = 10000):
        super(RoPE, self).__init__()
        self.d_model = d_model
        self.theta = theta ** -(torch.arange(0, d_model, 2) / d_model)

    @functools.lru_cache(maxsize=None)
    def _forward_l(self, L: int) -> torch.Tensor:
        # Use lru_cache to avoid redundant computation for the same L
        D = self.d_model
        pos = torch.einsum(
            'l,d->ld',
            torch.arange(L), self.theta
        )  # [L, D / 2]
        # element 0 is paired with element d // 2, element 1 with d // 2 + 1, ...
        cos = torch.cos(pos).repeat([1, 2])
        sin = torch.sin(pos)
        # Here we do not use the efficient formulation, but construct an explicit rotation matrix
        result = torch.zeros(L, D, D)
        result = torch.diagonal_scatter(result, cos, dim1=1, dim2=2)
        result = torch.diagonal_scatter(result, sin, dim1=1, dim2=2, offset=-D // 2)
        result = torch.diagonal_scatter(result, -sin, dim1=1, dim2=2, offset=D // 2)
        return result

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [N, L, H, D]
        _, L, _, D = x.size()
        rot_matrix = self._forward_l(L).to(x.device)  # [L, D, D]
        return torch.einsum('lde,nlhe->nlhd', rot_matrix, x)
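For intuition, the explicit rotation matrix built above is equivalent to the common "rotate_half" formulation, which rotates element i together with element i + D/2. The following is a minimal sketch; the helper rotate_half_rope and the sizes are illustrative only, not part of the model:
def rotate_half_rope(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    # x: [..., D], angles: [..., D // 2]; element i is rotated together with element i + D // 2
    half = x.shape[-1] // 2
    cos, sin = torch.cos(angles), torch.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

rope_check = RoPE(d_model=8)
x_check = torch.randn(1, 5, 2, 8)  # [N, L, H, D]
angles_check = torch.einsum('l,d->ld', torch.arange(5).float(), rope_check.theta)  # [L, D // 2]
print(torch.allclose(rope_check(x_check), rotate_half_rope(x_check, angles_check[None, :, None, :]), atol=1e-5))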
Attention Mechanism¶
RoPE is applied only after the multi-head attention has split the projections into heads. Also note that Llama's attention layers contain no bias terms.
In [3]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model: int, num_heads: int, rope: RoPE):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.d_k = d_model // num_heads
        self.sqrt_d_k = self.d_k ** 0.5
        self.rope = rope
        self.W_Q = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_K = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_V = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_O = torch.nn.Linear(d_model, d_model, bias=False)

    def forward(
        self, x_q: torch.Tensor, x_k: torch.Tensor, x_v: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # x_q: (N, L_Q, D), x_k: (N, L_KV, D), x_v: (N, L_KV, D)
        # padding_mask: (N, L_KV), attention_mask: (L_Q, L_KV)
        N, L_Q, D = x_q.size()
        _, L_KV, _ = x_k.size()
        # RoPE is applied per head, after splitting Q and K into heads
        Q = self.rope(self.W_Q(x_q).reshape(N, L_Q, self.num_heads, self.d_k))
        K = self.rope(self.W_K(x_k).reshape(N, L_KV, self.num_heads, self.d_k))
        V = self.W_V(x_v).reshape(N, L_KV, self.num_heads, self.d_k)
        score = torch.einsum('nihd,njhd->nijh', Q, K) / self.sqrt_d_k
        # score: (N, L_Q, L_KV, num_heads)
        # Apply attention mask
        if attention_mask is not None:
            score = score.masked_fill(
                attention_mask.reshape(1, L_Q, L_KV, 1) == 0, float('-inf')
            )
        # Apply padding mask
        if padding_mask is not None:
            score = score.masked_fill(
                padding_mask.reshape(N, 1, L_KV, 1) == 0, float('-inf')
            )
        score = torch.nn.functional.softmax(score, dim=2)
        value = torch.einsum(
            'nijh,njhd->nihd', score, V
        ).reshape(N, L_Q, self.d_model)
        return self.W_O(value)


class MultiHeadSelfAttention(MultiHeadAttention):
    def __init__(self, d_model: int, num_heads: int, rope: RoPE):
        super(MultiHeadSelfAttention, self).__init__(d_model, num_heads, rope)

    def forward(
        self, x: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        return super().forward(x, x, x, padding_mask, attention_mask)
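As a quick sanity check, here is a forward pass with small illustrative sizes (d_model=8, num_heads=2); the variable names below are placeholders:
rope_small = RoPE(d_model=8 // 2)  # RoPE operates on the per-head dimension d_k = d_model / num_heads
attn_small = MultiHeadSelfAttention(8, 2, rope_small)
x_small = torch.randn(1, 4, 8)
causal_small = 1 - torch.triu(torch.ones(4, 4), diagonal=1)  # lower-triangular causal mask
print(attn_small(x_small, attention_mask=causal_small).shape)  # expected: torch.Size([1, 4, 8])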
FFN¶
In the feed-forward network, Llama uses SwiGLU as the activation. This part likewise contains no bias terms.
In [4]:
class FFN(torch.nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int):
        super(FFN, self).__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim, bias=False)
        self.fc2 = torch.nn.Linear(hidden_dim, input_dim, bias=False)
        self.gate = torch.nn.Linear(input_dim, hidden_dim, bias=False)
        self.act = torch.nn.SiLU()

    def forward(self, x):
        return self.fc2(self.act(self.gate(x)) * self.fc1(x))
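The forward pass above is exactly the SwiGLU feed-forward, fc2(SiLU(gate(x)) * fc1(x)), where fc1 and fc2 play the roles of the up and down projections. A small equivalence check with illustrative sizes:
ffn_small = FFN(input_dim=8, hidden_dim=32)
x_small = torch.randn(2, 4, 8)
manual = ffn_small.fc2(torch.nn.functional.silu(ffn_small.gate(x_small)) * ffn_small.fc1(x_small))
print(torch.allclose(ffn_small(x_small), manual))  # expected: True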
LayerNorm¶
Llama uses RMSNorm in place of LayerNorm.
In [5]:
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-8):
        super(RMSNorm, self).__init__()
        self.weight = torch.nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        return x / torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps) * self.weight
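Unlike LayerNorm, RMSNorm neither subtracts the mean nor adds a bias. If your PyTorch version is 2.4 or newer (which ships torch.nn.RMSNorm), a quick cross-check is possible; this is an optional sketch under that version assumption:
x_small = torch.randn(2, 3, 8)
ours = RMSNorm(8, eps=1e-6)
reference = torch.nn.RMSNorm(8, eps=1e-6)  # requires PyTorch >= 2.4
print(torch.allclose(ours(x_small), reference(x_small), atol=1e-6))  # expected: True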
Decoder Layer¶
Llama uses pre-norm: normalize first, then apply the attention/FFN sublayer, and finally add the residual connection. With these pieces we can build the basic structure of the Llama decoder layer.
In [6]:
class LlamaDecoder(torch.nn.Module):
    def __init__(
        self, d_model: int, num_heads: int, rope: RoPE, hidden_dim: int, layernorm_eps: float
    ):
        super().__init__()
        self.norm1 = RMSNorm(d_model, layernorm_eps)
        self.attention = MultiHeadSelfAttention(d_model, num_heads, rope)
        self.norm2 = RMSNorm(d_model, layernorm_eps)
        self.ffn = FFN(d_model, hidden_dim)

    def forward(self, x: torch.Tensor, padding_mask: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None):
        # Pre-norm residual blocks; avoid in-place "+=" so the caller's input tensor is not modified
        hidden = x + self.attention(self.norm1(x), padding_mask, attention_mask)
        hidden = hidden + self.ffn(self.norm2(hidden))
        return hidden
Building the Llama Model¶
Stacking the components above yields the full Llama model.
In [7]:
class LlamaModel(torch.nn.Module):
    def __init__(
        self, config: transformers.LlamaConfig
    ):
        super().__init__()
        self.embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size)
        # RoPE acts on the per-head dimension
        self.rope = RoPE(
            config.hidden_size // config.num_attention_heads,
            config.rope_theta
        )
        self.layers = torch.nn.ModuleList([
            LlamaDecoder(
                config.hidden_size,
                config.num_attention_heads,
                self.rope,
                config.intermediate_size,
                config.rms_norm_eps
            )
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self, input_ids: torch.Tensor,
        padding_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        hidden = self.embedding(input_ids)
        # Causal (lower-triangular) attention mask
        attention_mask = 1 - torch.triu(
            torch.ones(hidden.size(1), hidden.size(1)), diagonal=1
        ).to(hidden.device)
        for layer in self.layers:
            hidden = layer(hidden, padding_mask, attention_mask)
        hidden = self.norm(hidden)
        hidden = self.lm_head(hidden)
        return hidden
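Before loading the 7B weights, a cheap smoke test with a tiny configuration confirms the wiring; the config values below are illustrative only:
tiny_config = transformers.LlamaConfig(
    vocab_size=100, hidden_size=16, intermediate_size=32,
    num_hidden_layers=2, num_attention_heads=2
)
tiny_model = LlamaModel(tiny_config)
print(tiny_model(torch.randint(0, 100, (1, 5))).shape)  # expected: torch.Size([1, 5, 100])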
Model Verification¶
First, load the model configuration from Llama-2-7b-hf and initialize our model from it.
In [8]:
llama_config = transformers.AutoConfig.from_pretrained('meta-llama/Llama-2-7b-hf')
llama = LlamaModel(llama_config)
Load the Llama model provided by Hugging Face and copy the pretrained parameters into our implementation.
In [9]:
hf_llama = transformers.LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')

def load_params(custom_model: LlamaModel, hf_model: transformers.LlamaForCausalLM):
    layer_pairs = [
        (custom_model.embedding, hf_model.model.embed_tokens),
        (custom_model.lm_head, hf_model.lm_head),
        (custom_model.norm, hf_model.model.norm)
    ]
    for custom_layer, hf_layer in zip(custom_model.layers, hf_model.model.layers):
        layer_pairs.extend([
            (custom_layer.norm1, hf_layer.input_layernorm),
            (custom_layer.norm2, hf_layer.post_attention_layernorm),
            (custom_layer.attention.W_Q, hf_layer.self_attn.q_proj),
            (custom_layer.attention.W_K, hf_layer.self_attn.k_proj),
            (custom_layer.attention.W_V, hf_layer.self_attn.v_proj),
            (custom_layer.attention.W_O, hf_layer.self_attn.o_proj),
            (custom_layer.ffn.gate, hf_layer.mlp.gate_proj),
            (custom_layer.ffn.fc1, hf_layer.mlp.up_proj),
            (custom_layer.ffn.fc2, hf_layer.mlp.down_proj)
        ])
    for custom_layer, hf_layer in layer_pairs:
        custom_layer.weight.data.copy_(hf_layer.weight.data)
    return custom_model

llama = load_params(llama, hf_llama)
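As an optional sanity check, the two models should expose the same total number of parameters:
print(sum(p.numel() for p in llama.parameters()))
print(sum(p.numel() for p in hf_llama.parameters()))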
Load the tokenizer and build the input batch. Note that Llama-2 has no padding token, so one has to be set manually (here we reuse the EOS token).
In [10]:
llama_tokenizer = transformers.LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
llama_tokenizer.pad_token = llama_tokenizer.eos_token
tokenized_sentence = llama_tokenizer([
    'User: Hello Llama! How are you?\nAssistant:',
    'User: Hello World!\nAssistant:'
], return_tensors='pt', padding=True)
input_ids = tokenized_sentence['input_ids']
padding_mask = tokenized_sentence['attention_mask']
Run the Hugging Face model on the inputs and record the output of every layer.
In [11]:
llama_layers_output = hf_llama(input_ids, attention_mask=padding_mask, output_hidden_states=True)
Now we verify the outputs, starting with the embedding layer.
In [12]:
llama_layers_output.hidden_states[0]
hidden = llama.embedding(input_ids)
attention_mask = 1 - torch.triu(
    torch.ones(hidden.size(1), hidden.size(1)), diagonal=1
).to(device=hidden.device)
# hidden_states[0] should equal the embedding output
torch.allclose(
    llama_layers_output.hidden_states[0],
    hidden
)
Out[12]:
True
Next, verify the output of each decoder layer. The last layer checked may show a small discrepancy, most likely due to accumulated floating-point error.
In [13]:
# Verify layer outputs
layer_results = []
for i in range(llama_config.num_hidden_layers - 1):
    hidden = llama.layers[i](hidden, padding_mask, attention_mask)
    layer_results.append(torch.allclose(
        llama_layers_output.hidden_states[i + 1], hidden, atol=1e-4
    ))
all(layer_results), all(layer_results[:-1])
Out[13]:
(False, True)
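To see how large the discrepancy actually is, an optional diagnostic is to track the maximum absolute difference layer by layer instead of a boolean check (re-running from the embedding; shown for the first few layers only):
with torch.no_grad():
    h = llama.embedding(input_ids)
    for i in range(4):  # first few layers only, for brevity
        h = llama.layers[i](h, padding_mask, attention_mask)
        diff = (llama_layers_output.hidden_states[i + 1] - h).abs().max().item()
        print(f'layer {i}: max abs diff = {diff:.2e}')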
Verify the output of the LM head.
In [14]:
# Verify output logits
logits = llama(input_ids, padding_mask)
hf_logits = llama_layers_output.logits
torch.allclose(logits, hf_logits, atol=1e-4)
Out[14]:
True
Finally, try decoding the predicted tokens.
In [15]:
# Decode the top-5 predicted tokens at the last position
max_prob_tokens = logits[0, -1].topk(5).indices.tolist()
for token_id in max_prob_tokens:
    print(llama_tokenizer.decode(token_id))
I
Hello
Fine
Hi
Oh
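Going one step further, the custom model can also generate text autoregressively. Below is a minimal greedy-decoding sketch (no KV cache, so the full sequence is recomputed at every step; for illustration only):
prompt_ids = llama_tokenizer(
    'User: Hello Llama! How are you?\nAssistant:', return_tensors='pt'
)['input_ids']
generated = prompt_ids
with torch.no_grad():
    for _ in range(10):  # greedily generate 10 tokens
        next_logits = llama(generated)[:, -1]  # logits at the last position
        next_token = next_logits.argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
print(llama_tokenizer.decode(generated[0], skip_special_tokens=True))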