实现 Transformer
In [1]:
Copied!
import torch
import torch
首先实现多头注意力机制。
In [2]:
Copied!
class MultiHeadAttention(torch.nn.Module):
def __init__(self, d_model: int, num_heads: int):
super(MultiHeadAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
if d_model % num_heads != 0:
raise ValueError("d_model must be divisible by num_heads")
self.d_k = d_model // num_heads
self.sqrt_d_k = self.d_k ** 0.5
self.W_Q = torch.nn.Linear(d_model, d_model)
self.W_K = torch.nn.Linear(d_model, d_model)
self.W_V = torch.nn.Linear(d_model, d_model)
self.W_O = torch.nn.Linear(d_model, d_model)
def forward(
self, x_q: torch.Tensor, x_k: torch.Tensor, x_v: torch.Tensor,
mask: torch.Tensor | None = None
) -> torch.Tensor:
N, L_Q, D = x_q.size()
_, L_KV, _ = x_k.size()
Q = self.W_Q(x_q).reshape(N, L_Q, self.num_heads, self.d_k)
K = self.W_K(x_k).reshape(N, L_KV, self.num_heads, self.d_k)
V = self.W_V(x_v).reshape(N, L_KV, self.num_heads, self.d_k)
score = torch.einsum("nihd,njhd->nijh", Q, K) / self.sqrt_d_k
if mask is not None:
# mask: (L, L)
score = score.masked_fill(
mask.reshape(1, L_Q, L_KV, 1) == 0, float('-inf')
)
score = torch.nn.functional.softmax(score, dim=2)
value = torch.einsum("nijh,njhd->nihd", score, V).reshape(N, L_Q, self.d_model)
return self.W_O(value)
class MultiHeadSelfAttention(MultiHeadAttention):
    """Attention in which queries, keys and values come from one sequence."""

    def __init__(self, d_model: int, num_heads: int):
        # No extra state: construction is delegated entirely to the parent.
        super().__init__(d_model, num_heads)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Attend from ``x`` to itself, optionally restricted by ``mask``."""
        query = key = value = x
        return super().forward(query, key, value, mask)
class MultiHeadCrossAttention(MultiHeadAttention):
    """Attention whose keys and values both come from a second sequence."""

    def __init__(self, d_model: int, num_heads: int):
        # No extra state: construction is delegated entirely to the parent.
        super().__init__(d_model, num_heads)

    def forward(
        self, x_q: torch.Tensor, x_kv: torch.Tensor,
        mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Attend from ``x_q`` to ``x_kv``, optionally restricted by ``mask``."""
        key = value = x_kv
        return super().forward(x_q, key, value, mask)
class MultiHeadAttention(torch.nn.Module):
def __init__(self, d_model: int, num_heads: int):
super(MultiHeadAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
if d_model % num_heads != 0:
raise ValueError("d_model must be divisible by num_heads")
self.d_k = d_model // num_heads
self.sqrt_d_k = self.d_k ** 0.5
self.W_Q = torch.nn.Linear(d_model, d_model)
self.W_K = torch.nn.Linear(d_model, d_model)
self.W_V = torch.nn.Linear(d_model, d_model)
self.W_O = torch.nn.Linear(d_model, d_model)
def forward(
self, x_q: torch.Tensor, x_k: torch.Tensor, x_v: torch.Tensor,
mask: torch.Tensor | None = None
) -> torch.Tensor:
N, L_Q, D = x_q.size()
_, L_KV, _ = x_k.size()
Q = self.W_Q(x_q).reshape(N, L_Q, self.num_heads, self.d_k)
K = self.W_K(x_k).reshape(N, L_KV, self.num_heads, self.d_k)
V = self.W_V(x_v).reshape(N, L_KV, self.num_heads, self.d_k)
score = torch.einsum("nihd,njhd->nijh", Q, K) / self.sqrt_d_k
if mask is not None:
# mask: (L, L)
score = score.masked_fill(
mask.reshape(1, L_Q, L_KV, 1) == 0, float('-inf')
)
score = torch.nn.functional.softmax(score, dim=2)
value = torch.einsum("nijh,njhd->nihd", score, V).reshape(N, L_Q, self.d_model)
return self.W_O(value)
class MultiHeadSelfAttention(MultiHeadAttention):
    """Attention in which queries, keys and values come from one sequence."""

    def __init__(self, d_model: int, num_heads: int):
        # No extra state: construction is delegated entirely to the parent.
        super().__init__(d_model, num_heads)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Attend from ``x`` to itself, optionally restricted by ``mask``."""
        query = key = value = x
        return super().forward(query, key, value, mask)
class MultiHeadCrossAttention(MultiHeadAttention):
    """Attention whose keys and values both come from a second sequence."""

    def __init__(self, d_model: int, num_heads: int):
        # No extra state: construction is delegated entirely to the parent.
        super().__init__(d_model, num_heads)

    def forward(
        self, x_q: torch.Tensor, x_kv: torch.Tensor,
        mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Attend from ``x_q`` to ``x_kv``, optionally restricted by ``mask``."""
        key = value = x_kv
        return super().forward(x_q, key, value, mask)
前馈神经网络由两个全连接层和ReLU激活函数组成
In [3]:
Copied!
class FFN(torch.nn.Module):
    """Feed-forward block: Linear -> ReLU -> Linear.

    Args:
        input_dim: Feature size of the input and of the output.
        hidden_dim: Feature size between the two linear layers.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        # Sub-modules are created in the order fc1, fc2, relu.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, input_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Expand ``x`` to ``hidden_dim``, apply ReLU, project back."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)
class FFN(torch.nn.Module):
    """Feed-forward block: Linear -> ReLU -> Linear.

    Args:
        input_dim: Feature size of the input and of the output.
        hidden_dim: Feature size between the two linear layers.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        # Sub-modules are created in the order fc1, fc2, relu.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, input_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Expand ``x`` to ``hidden_dim``, apply ReLU, project back."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)
编码器层由自注意力层和前馈神经网络层组成,一个编码器由多个这样的层串联组成。
In [4]:
Copied!
class EncoderLayer(torch.nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network.

    Each sub-layer's output goes through dropout, is added to the sub-layer
    input, and the sum is layer-normalised.

    Args:
        input_dim: Feature size of the layer input and output.
        num_heads: Number of attention heads.
        ffn_dim: Hidden size of the feed-forward network; defaults to
            ``4 * input_dim`` when ``None``.
        dropout: Dropout probability applied to each sub-layer output.
        layer_norm_eps: ``eps`` used by both LayerNorm modules.
    """

    def __init__(
        self, input_dim: int, num_heads: int,
        ffn_dim: int | None = None, dropout: float = 0.1,
        layer_norm_eps: float = 1e-6
    ):
        super().__init__()
        if ffn_dim is None:
            ffn_dim = input_dim * 4
        self.attention = MultiHeadSelfAttention(input_dim, num_heads)
        self.norm1 = torch.nn.LayerNorm(input_dim, eps=layer_norm_eps)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.ffn = FFN(input_dim, ffn_dim)
        self.norm2 = torch.nn.LayerNorm(input_dim, eps=layer_norm_eps)
        self.dropout2 = torch.nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Apply the layer to ``x``.

        Args:
            x: Input tensor; passed to the self-attention module as is.
            mask: Optional attention mask forwarded to the self-attention
                module. Defaults to ``None`` (no masking), matching the
                attention module's own default.

        Returns:
            Tensor produced by the second LayerNorm.
        """
        x = x + self.dropout1(self.attention(x, mask))
        x = self.norm1(x)
        x = x + self.dropout2(self.ffn(x))
        x = self.norm2(x)
        return x
class Encoder(torch.nn.Module):
def __init__(
self, num_layers: int, input_dim: int, num_heads: int,
ffn_dim: int | None = None, dropout: float = 0.1,
layer_norm_eps: float = 1e-6
):
super(Encoder, self).__init__()
self.layers = torch.nn.ModuleList([
EncoderLayer(input_dim, num_heads, ffn_dim, dropout, layer_norm_eps)
for _ in range(num_layers)
])
def forward(self, x, mask):
# x: (N, L, D)
for layer in self.layers:
x = layer(x, mask)
return x
class EncoderLayer(torch.nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network.

    Each sub-layer's output goes through dropout, is added to the sub-layer
    input, and the sum is layer-normalised.

    Args:
        input_dim: Feature size of the layer input and output.
        num_heads: Number of attention heads.
        ffn_dim: Hidden size of the feed-forward network; defaults to
            ``4 * input_dim`` when ``None``.
        dropout: Dropout probability applied to each sub-layer output.
        layer_norm_eps: ``eps`` used by both LayerNorm modules.
    """

    def __init__(
        self, input_dim: int, num_heads: int,
        ffn_dim: int | None = None, dropout: float = 0.1,
        layer_norm_eps: float = 1e-6
    ):
        super().__init__()
        if ffn_dim is None:
            ffn_dim = input_dim * 4
        self.attention = MultiHeadSelfAttention(input_dim, num_heads)
        self.norm1 = torch.nn.LayerNorm(input_dim, eps=layer_norm_eps)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.ffn = FFN(input_dim, ffn_dim)
        self.norm2 = torch.nn.LayerNorm(input_dim, eps=layer_norm_eps)
        self.dropout2 = torch.nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Apply the layer to ``x``.

        Args:
            x: Input tensor; passed to the self-attention module as is.
            mask: Optional attention mask forwarded to the self-attention
                module. Defaults to ``None`` (no masking), matching the
                attention module's own default.

        Returns:
            Tensor produced by the second LayerNorm.
        """
        x = x + self.dropout1(self.attention(x, mask))
        x = self.norm1(x)
        x = x + self.dropout2(self.ffn(x))
        x = self.norm2(x)
        return x
class Encoder(torch.nn.Module):
def __init__(
self, num_layers: int, input_dim: int, num_heads: int,
ffn_dim: int | None = None, dropout: float = 0.1,
layer_norm_eps: float = 1e-6
):
super(Encoder, self).__init__()
self.layers = torch.nn.ModuleList([
EncoderLayer(input_dim, num_heads, ffn_dim, dropout, layer_norm_eps)
for _ in range(num_layers)
])
def forward(self, x, mask):
# x: (N, L, D)
for layer in self.layers:
x = layer(x, mask)
return x
解码器层由自注意力层,交叉注意力层和前馈神经网络层组成。
In [5]:
Copied!
class DecoderLayer(torch.nn.Module):
    """One decoder layer: self-attention, cross-attention, feed-forward.

    Each sub-layer's output goes through dropout, is added to the sub-layer
    input, and the sum is layer-normalised.

    Args:
        input_dim: Feature size of the layer input and output.
        num_heads: Number of attention heads in both attention modules.
        ffn_dim: Hidden size of the feed-forward network; defaults to
            ``4 * input_dim`` when ``None``.
        dropout: Dropout probability applied to each sub-layer output.
        layer_norm_eps: ``eps`` used by the three LayerNorm modules. The
            default of 1e-5 is ``torch.nn.LayerNorm``'s own default, so
            callers that omit it get the same modules as before.
    """

    def __init__(
        self, input_dim: int, num_heads: int,
        ffn_dim: int | None = None, dropout: float = 0.1,
        layer_norm_eps: float = 1e-5
    ):
        super().__init__()
        if ffn_dim is None:
            ffn_dim = input_dim * 4
        self.self_attention = MultiHeadSelfAttention(input_dim, num_heads)
        self.norm1 = torch.nn.LayerNorm(input_dim, eps=layer_norm_eps)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.cross_attention = MultiHeadCrossAttention(input_dim, num_heads)
        self.norm2 = torch.nn.LayerNorm(input_dim, eps=layer_norm_eps)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.ffn = FFN(input_dim, ffn_dim)
        self.norm3 = torch.nn.LayerNorm(input_dim, eps=layer_norm_eps)
        self.dropout3 = torch.nn.Dropout(dropout)

    def forward(
        self, x: torch.Tensor, memory: torch.Tensor,
        tgt_mask: torch.Tensor | None = None,
        memory_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Apply the layer to ``x`` using ``memory`` for cross-attention.

        Args:
            x: Target-side input tensor.
            memory: Tensor supplying keys and values for cross-attention.
            tgt_mask: Optional mask forwarded to the self-attention module.
            memory_mask: Optional mask forwarded to the cross-attention
                module; ``None`` (the default) leaves it unmasked.

        Returns:
            Tensor produced by the third LayerNorm.
        """
        x = x + self.dropout1(self.self_attention(x, tgt_mask))
        x = self.norm1(x)
        x = x + self.dropout2(self.cross_attention(x, memory, memory_mask))
        x = self.norm2(x)
        x = x + self.dropout3(self.ffn(x))
        x = self.norm3(x)
        return x
class Decoder(torch.nn.Module):
def __init__(
self, num_layers: int, input_dim: int, num_heads: int,
ffn_dim: int | None = None, dropout: float = 0.1
):
super(Decoder, self).__init__()
self.layers = torch.nn.ModuleList([
DecoderLayer(input_dim, num_heads, ffn_dim, dropout)
for _ in range(num_layers)
])
def forward(self, x, memory, tgt_mask):
# x: (N, L, D)
for layer in self.layers:
x = layer(x, memory, tgt_mask)
return x
class DecoderLayer(torch.nn.Module):
    """One decoder layer: self-attention, cross-attention, feed-forward.

    Each sub-layer's output goes through dropout, is added to the sub-layer
    input, and the sum is layer-normalised.

    Args:
        input_dim: Feature size of the layer input and output.
        num_heads: Number of attention heads in both attention modules.
        ffn_dim: Hidden size of the feed-forward network; defaults to
            ``4 * input_dim`` when ``None``.
        dropout: Dropout probability applied to each sub-layer output.
        layer_norm_eps: ``eps`` used by the three LayerNorm modules. The
            default of 1e-5 is ``torch.nn.LayerNorm``'s own default, so
            callers that omit it get the same modules as before.
    """

    def __init__(
        self, input_dim: int, num_heads: int,
        ffn_dim: int | None = None, dropout: float = 0.1,
        layer_norm_eps: float = 1e-5
    ):
        super().__init__()
        if ffn_dim is None:
            ffn_dim = input_dim * 4
        self.self_attention = MultiHeadSelfAttention(input_dim, num_heads)
        self.norm1 = torch.nn.LayerNorm(input_dim, eps=layer_norm_eps)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.cross_attention = MultiHeadCrossAttention(input_dim, num_heads)
        self.norm2 = torch.nn.LayerNorm(input_dim, eps=layer_norm_eps)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.ffn = FFN(input_dim, ffn_dim)
        self.norm3 = torch.nn.LayerNorm(input_dim, eps=layer_norm_eps)
        self.dropout3 = torch.nn.Dropout(dropout)

    def forward(
        self, x: torch.Tensor, memory: torch.Tensor,
        tgt_mask: torch.Tensor | None = None,
        memory_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Apply the layer to ``x`` using ``memory`` for cross-attention.

        Args:
            x: Target-side input tensor.
            memory: Tensor supplying keys and values for cross-attention.
            tgt_mask: Optional mask forwarded to the self-attention module.
            memory_mask: Optional mask forwarded to the cross-attention
                module; ``None`` (the default) leaves it unmasked.

        Returns:
            Tensor produced by the third LayerNorm.
        """
        x = x + self.dropout1(self.self_attention(x, tgt_mask))
        x = self.norm1(x)
        x = x + self.dropout2(self.cross_attention(x, memory, memory_mask))
        x = self.norm2(x)
        x = x + self.dropout3(self.ffn(x))
        x = self.norm3(x)
        return x
class Decoder(torch.nn.Module):
def __init__(
self, num_layers: int, input_dim: int, num_heads: int,
ffn_dim: int | None = None, dropout: float = 0.1
):
super(Decoder, self).__init__()
self.layers = torch.nn.ModuleList([
DecoderLayer(input_dim, num_heads, ffn_dim, dropout)
for _ in range(num_layers)
])
def forward(self, x, memory, tgt_mask):
# x: (N, L, D)
for layer in self.layers:
x = layer(x, memory, tgt_mask)
return x
Transformer模型由编码器和解码器组成,编码器处理源序列,将编码后的序列输入到解码器中,解码器生成目标序列。
In [6]:
Copied!
class CustomTransformer(torch.nn.Module):
    """Encoder-decoder model assembled from ``Encoder`` and ``Decoder``.

    Args:
        num_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        input_dim: Feature size shared by encoder and decoder.
        num_heads: Number of attention heads in every attention module.
        ffn_dim: Feed-forward hidden size forwarded to both stacks.
        dropout: Dropout probability forwarded to both stacks.
        layer_norm_eps: LayerNorm ``eps``; forwarded to the encoder only,
            the decoder is constructed without it.
    """

    def __init__(
        self, num_layers: int, num_decoder_layers: int, input_dim: int, num_heads: int,
        ffn_dim: int | None = None, dropout: float = 0.1, layer_norm_eps: float = 1e-6
    ):
        super().__init__()
        # Settings common to both stacks, in their shared positional order.
        shared_args = (input_dim, num_heads, ffn_dim, dropout)
        self.encoder = Encoder(num_layers, *shared_args, layer_norm_eps)
        self.decoder = Decoder(num_decoder_layers, *shared_args)

    def forward(self, src, tgt, src_mask: torch.Tensor | None = None, tgt_mask: torch.Tensor | None = None):
        """Encode ``src`` and decode ``tgt`` against the encoded result.

        ``src_mask`` is handed to the encoder and ``tgt_mask`` to the
        decoder; the decoder's output is returned.
        """
        encoded = self.encoder(src, src_mask)
        decoded = self.decoder(tgt, encoded, tgt_mask)
        return decoded
class CustomTransformer(torch.nn.Module):
    """Encoder-decoder model assembled from ``Encoder`` and ``Decoder``.

    Args:
        num_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        input_dim: Feature size shared by encoder and decoder.
        num_heads: Number of attention heads in every attention module.
        ffn_dim: Feed-forward hidden size forwarded to both stacks.
        dropout: Dropout probability forwarded to both stacks.
        layer_norm_eps: LayerNorm ``eps``; forwarded to the encoder only,
            the decoder is constructed without it.
    """

    def __init__(
        self, num_layers: int, num_decoder_layers: int, input_dim: int, num_heads: int,
        ffn_dim: int | None = None, dropout: float = 0.1, layer_norm_eps: float = 1e-6
    ):
        super().__init__()
        # Settings common to both stacks, in their shared positional order.
        shared_args = (input_dim, num_heads, ffn_dim, dropout)
        self.encoder = Encoder(num_layers, *shared_args, layer_norm_eps)
        self.decoder = Decoder(num_decoder_layers, *shared_args)

    def forward(self, src, tgt, src_mask: torch.Tensor | None = None, tgt_mask: torch.Tensor | None = None):
        """Encode ``src`` and decode ``tgt`` against the encoded result.

        ``src_mask`` is handed to the encoder and ``tgt_mask`` to the
        decoder; the decoder's output is returned.
        """
        encoded = self.encoder(src, src_mask)
        decoded = self.decoder(tgt, encoded, tgt_mask)
        return decoded
将 PyTorch 内置 Transformer 的权重复制到自行实现的 Transformer 中。
In [7]:
Copied!
def _attn_load_from_torch(
custom_attn: MultiHeadAttention, torch_attn: torch.nn.MultiheadAttention
):
embed_dim = custom_attn.d_model
def split_qkv(weight, embed_dim):
return weight[:embed_dim], weight[embed_dim: 2 * embed_dim], weight[2 * embed_dim:3 * embed_dim]
custom_attn.W_Q.weight.data, \
custom_attn.W_K.weight.data, \
custom_attn.W_V.weight.data = split_qkv(
torch_attn.in_proj_weight.data, embed_dim
)
custom_attn.W_Q.bias.data, \
custom_attn.W_K.bias.data, \
custom_attn.W_V.bias.data = split_qkv(
torch_attn.in_proj_bias.data, embed_dim
)
custom_attn.W_O.weight.data = torch_attn.out_proj.weight.data
custom_attn.W_O.bias.data = torch_attn.out_proj.bias.data
return custom_attn
def load_from_torch(
custom_transformer: CustomTransformer, torch_transformer: torch.nn.Transformer
):
for custom_layer, torch_layer in zip(
[*custom_transformer.encoder.layers, *custom_transformer.decoder.layers],
[*torch_transformer.encoder.layers, *torch_transformer.decoder.layers],
):
if hasattr(custom_layer, 'cross_attention'):
# Decoder
custom_layer.self_attention = _attn_load_from_torch(
custom_layer.self_attention, torch_layer.self_attn
)
custom_layer.cross_attention = _attn_load_from_torch(
custom_layer.cross_attention, torch_layer.multihead_attn
)
else:
# Encoder
custom_layer.attention = _attn_load_from_torch(
custom_layer.attention, torch_layer.self_attn
)
layer_pairs = [
(custom_layer.norm1, torch_layer.norm1),
(custom_layer.ffn.fc1, torch_layer.linear1),
(custom_layer.ffn.fc2, torch_layer.linear2),
(custom_layer.norm2, torch_layer.norm2)
]
for custom, torch in layer_pairs:
custom.weight.data = torch.weight.data
custom.bias.data = torch.bias.data
return custom_transformer
def _attn_load_from_torch(
custom_attn: MultiHeadAttention, torch_attn: torch.nn.MultiheadAttention
):
embed_dim = custom_attn.d_model
def split_qkv(weight, embed_dim):
return weight[:embed_dim], weight[embed_dim: 2 * embed_dim], weight[2 * embed_dim:3 * embed_dim]
custom_attn.W_Q.weight.data, \
custom_attn.W_K.weight.data, \
custom_attn.W_V.weight.data = split_qkv(
torch_attn.in_proj_weight.data, embed_dim
)
custom_attn.W_Q.bias.data, \
custom_attn.W_K.bias.data, \
custom_attn.W_V.bias.data = split_qkv(
torch_attn.in_proj_bias.data, embed_dim
)
custom_attn.W_O.weight.data = torch_attn.out_proj.weight.data
custom_attn.W_O.bias.data = torch_attn.out_proj.bias.data
return custom_attn
def load_from_torch(
custom_transformer: CustomTransformer, torch_transformer: torch.nn.Transformer
):
for custom_layer, torch_layer in zip(
[*custom_transformer.encoder.layers, *custom_transformer.decoder.layers],
[*torch_transformer.encoder.layers, *torch_transformer.decoder.layers],
):
if hasattr(custom_layer, 'cross_attention'):
# Decoder
custom_layer.self_attention = _attn_load_from_torch(
custom_layer.self_attention, torch_layer.self_attn
)
custom_layer.cross_attention = _attn_load_from_torch(
custom_layer.cross_attention, torch_layer.multihead_attn
)
else:
# Encoder
custom_layer.attention = _attn_load_from_torch(
custom_layer.attention, torch_layer.self_attn
)
layer_pairs = [
(custom_layer.norm1, torch_layer.norm1),
(custom_layer.ffn.fc1, torch_layer.linear1),
(custom_layer.ffn.fc2, torch_layer.linear2),
(custom_layer.norm2, torch_layer.norm2)
]
for custom, torch in layer_pairs:
custom.weight.data = torch.weight.data
custom.bias.data = torch.bias.data
return custom_transformer
通过代码验证实现的Transformer模型的正确性。
In [8]:
Copied!
# Verification script: build both models from one configuration, copy the
# PyTorch weights into the custom model and compare the outputs.
transformer_config = {
    'num_layers': 6,
    'num_decoder_layers': 6,
    'input_dim': 512,
    'num_heads': 8,
    'ffn_dim': 2048,
    'dropout': 0.1,
    # torch.nn.LayerNorm's default eps. CustomTransformer forwards this
    # value to its encoder only, so using the default keeps every LayerNorm
    # in both models on the same eps.
    'layer_norm_eps': 1e-5,
}
custom_transformer = CustomTransformer(**transformer_config)
torch_transformer = torch.nn.Transformer(
    d_model=transformer_config['input_dim'],
    nhead=transformer_config['num_heads'],
    num_encoder_layers=transformer_config['num_layers'],
    num_decoder_layers=transformer_config['num_decoder_layers'],
    dim_feedforward=transformer_config['ffn_dim'],
    dropout=transformer_config['dropout'],
    layer_norm_eps=transformer_config['layer_norm_eps'],
    batch_first=True,
)
custom_transformer = load_from_torch(custom_transformer, torch_transformer)
# Evaluation mode disables dropout so both forward passes are deterministic.
custom_transformer.eval()
torch_transformer.eval()
src = torch.randn(32, 10, transformer_config['input_dim'])
tgt = torch.randn(32, 20, transformer_config['input_dim'])
tgt_len = tgt.size(1)
# Custom convention: 1 = may attend, 0 = blocked (lower-triangular causal mask).
tgt_mask = 1 - torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1)
# PyTorch adds float masks to the attention scores, so the 1/0 mask above is
# not a valid mask for it; build the additive (-inf above the diagonal)
# causal mask instead, consistent with tgt_is_causal=True.
torch_tgt_mask = torch_transformer.generate_square_subsequent_mask(tgt_len)
custom_output = custom_transformer(src, tgt, tgt_mask=tgt_mask)
torch_output = torch_transformer(src, tgt, tgt_mask=torch_tgt_mask, tgt_is_causal=True)
custom_output[0, 0, :5], torch_output[0, 0, :5]
# Verification script: build both models from one configuration, copy the
# PyTorch weights into the custom model and compare the outputs.
transformer_config = {
    'num_layers': 6,
    'num_decoder_layers': 6,
    'input_dim': 512,
    'num_heads': 8,
    'ffn_dim': 2048,
    'dropout': 0.1,
    # torch.nn.LayerNorm's default eps. CustomTransformer forwards this
    # value to its encoder only, so using the default keeps every LayerNorm
    # in both models on the same eps.
    'layer_norm_eps': 1e-5,
}
custom_transformer = CustomTransformer(**transformer_config)
torch_transformer = torch.nn.Transformer(
    d_model=transformer_config['input_dim'],
    nhead=transformer_config['num_heads'],
    num_encoder_layers=transformer_config['num_layers'],
    num_decoder_layers=transformer_config['num_decoder_layers'],
    dim_feedforward=transformer_config['ffn_dim'],
    dropout=transformer_config['dropout'],
    layer_norm_eps=transformer_config['layer_norm_eps'],
    batch_first=True,
)
custom_transformer = load_from_torch(custom_transformer, torch_transformer)
# Evaluation mode disables dropout so both forward passes are deterministic.
custom_transformer.eval()
torch_transformer.eval()
src = torch.randn(32, 10, transformer_config['input_dim'])
tgt = torch.randn(32, 20, transformer_config['input_dim'])
tgt_len = tgt.size(1)
# Custom convention: 1 = may attend, 0 = blocked (lower-triangular causal mask).
tgt_mask = 1 - torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1)
# PyTorch adds float masks to the attention scores, so the 1/0 mask above is
# not a valid mask for it; build the additive (-inf above the diagonal)
# causal mask instead, consistent with tgt_is_causal=True.
torch_tgt_mask = torch_transformer.generate_square_subsequent_mask(tgt_len)
custom_output = custom_transformer(src, tgt, tgt_mask=tgt_mask)
torch_output = torch_transformer(src, tgt, tgt_mask=torch_tgt_mask, tgt_is_causal=True)
custom_output[0, 0, :5], torch_output[0, 0, :5]
Out[8]:
(tensor([ 0.6672, 0.4102, -0.2145, 1.8782, 0.4108], grad_fn=<SliceBackward0>), tensor([ 0.6672, 0.4102, -0.2145, 1.8781, 0.4108], grad_fn=<SliceBackward0>))
In [9]:
Copied!
# Notebook check: True when both outputs agree element-wise within an
# absolute tolerance of 1e-4 (rtol spelled out at its documented default).
torch.allclose(custom_output, torch_output, rtol=1e-05, atol=1e-4)
torch.allclose(custom_output, torch_output, rtol=1e-05, atol=1e-4)
Out[9]:
True