位置编码¶
Attention机制虽然能捕捉序列中不同位置的依赖关系,但是无法区分不同位置的元素。为了解决这个问题,Transformer模型引入了位置编码(Positional Encoding)。
绝对位置编码¶
绝对位置编码是Transformer模型中最早引入的位置编码方式,其核心思想是为序列中的每个位置分配一个唯一的向量,将这个向量和输入序列相加得到最终序列的表示。如BERT模型,限制了序列的最大长度为512,因此可以使用固定的位置编码,将每个位置映射到一个可学习的向量。
class AbsolutePE(torch.nn.Module):
def __init__(self, max_len: int, d_model: int):
super(AbsolutePE, self).__init__()
self.pe = torch.nn.Embedding(max_len, d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [N, L]
pos = torch.arange(x.size(1), device=x.device).unsqueeze(0)
return self.pe(pos) # [1, L, D], can be broadcasted to [N, L, D]
此类方法的缺点是无法处理超过最大长度的序列,并且忽视了元素之间的相对位置关系。
正弦位置编码¶
Vaswani等人提出的位置编码包含一个正弦函数和一个余弦函数,其数学表达为:
其中,\(pos\)表示位置,\(i\)表示维度,\(d_{\text{model}}\)表示模型的维度,输出的位置编码直接与输入序列相加得到最终表示。正弦位置编码不需要额外进行学习。Vaswani的实验表明,固定的位置编码和可学习的位置编码在性能上没有显著差异。
class SinusoidPE(torch.nn.Module):
def __init__(self, d_model: int):
super(SinusoidPE, self).__init__()
self.d_model = d_model
def _denominator(self, device: torch.device) -> torch.Tensor:
return 10000 ** (torch.arange(0, self.d_model, device=device) / self.d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [N, L]
pos = torch.arange(x.size(1), device=x.device).unsqueeze(0) # [1, L]
pos = pos.unsqueeze(-1) / self._denominator(x.device) # [1, L, D]
pos[:, :, 0::2] = torch.sin(pos[:, :, 0::2])
pos[:, :, 1::2] = torch.cos(pos[:, :, 1::2])
return pos # [1, L, D], can be broadcasted to [N, L, D]
旋转位置编码¶
Su等人提出的旋转位置编码(Rotary Positional Embedding,RoPE)。编码的核心思想是通过旋转矩阵将位置信息嵌入到特征空间中,从而使模型能够学习到位置信息。
其中,\(\bsR_{\Theta, m}\)为旋转矩阵,将向量每两个分量进行旋转。
注意,\(m\)为位置,取值范围为\(1, \ldots, L\),\(k\)为维度下标,取值范围为\(1, \ldots, d / 2\)。最终的旋转操作将每个位置\(m\)的向量\(\bsx_m\)应用旋转矩阵\(\bsR_{\Theta, m}\),得到新的向量。
import functools
class RoPE(torch.nn.Module):
def __init__(self, d_model: int, theta: int | float = 10000):
super(RoPE, self).__init__()
self.d_model = d_model
self.theta = theta ** -(torch.arange(0, d_model, 2) / d_model)
@functools.lru_cache(maxsize=None)
def _forward_l(self, L: int) -> torch.Tensor:
# Use lru_cache to avoid redundant computation for the same L
D = self.d_model
pos = torch.einsum(
'l,d->ld',
torch.arange(L), self.theta
) # [L, D / 2]
# Major diagonal is cos, cos, ..., cos; D elements
cos = torch.cos(pos).repeat_interleave(2)
# Minor diagonal is sin, 0, sin, 0, ..., sin; D - 1 elements
sin = torch.stack([
torch.sin(pos), torch.zeros_like(pos), dim=-1
]).reshape(L, D)[:, :-1]
result = torch.zeros(L, D, D)
result = torch.diagonal_scatter(result, cos, dim1=1, dim2=2)
result = torch.diagonal_scatter(result, sin, dim1=1, dim2=2, offset=-1)
result = torch.diagonal_scatter(result, -sin, dim1=1, dim2=2, offset=1)
return result
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [N, L, H, D]
_, L, _, D = x.size()
rot_matrix = self._forward_l(L).to(x.device) # [L, D, D]
return torch.einsum('lde,nlhe->nlhd', rot_matrix, x)
和其他位置编码方式不同,RoPE在QKV变换后进行位置编码,因此需要对注意力机制的实现进行修改。
class MultiHeadAttention(torch.nn.Module):
def __init__(self, d_model: int, num_heads: int, rope: RoPE | None = None):
...
self.rope = rope
...
def forward(
self, x_q: torch.Tensor, x_k: torch.Tensor, x_v: torch.Tensor,
padding_mask: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None
) -> torch.Tensor:
...
if self.rope is not None:
Q = self.rope(self.W_Q(x_q).reshape(N, L_Q, self.num_heads, self.d_k))
K = self.rope(self.W_K(x_k).reshape(N, L_KV, self.num_heads, self.d_k))
else:
Q = self.W_Q(x_q).reshape(N, L_Q, self.num_heads, self.d_k)
K = self.W_K(x_k).reshape(N, L_KV, self.num_heads, self.d_k)
V = self.W_V(x_v).reshape(N, L_KV, self.num_heads, self.d_k)
...
为了计算方便,Llama等部分模型的RoPE将输入向量分为前后两段,前后两段相同位置的元素分为一组进行旋转操作,而不是将整个向量按照相邻的两个元素分为一组进行旋转。