

Attention机制虽然能捕捉序列中不同位置之间的依赖关系,但其计算本身与元素的排列顺序无关,因此无法区分出现在不同位置的相同元素。为了解决这个问题,Transformer模型引入了位置编码(Positional Encoding)。



class AbsolutePE(torch.nn.Module):
    """Learned absolute positional encoding.

    Each position index in ``[0, max_len)`` owns a trainable ``d_model``-sized
    vector stored in an embedding table.
    """

    def __init__(self, max_len: int, d_model: int):
        super().__init__()
        # One learnable row per position.
        self.pe = torch.nn.Embedding(max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the encodings for the first ``x.size(1)`` positions.

        Only the second dimension of ``x`` (the sequence length) and its
        device are used. The result has shape ``[1, L, D]`` so it can be
        broadcast against a ``[N, L, D]`` batch.
        """
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device)
        # Add the leading singleton batch axis before the lookup.
        return self.pe(positions[None, :])




\[ \begin{aligned} PE_{(pos, 2i)} &= \sin(pos / 10000^{2i / d_{\text{model}}}) \\ PE_{(pos, 2i + 1)} &= \cos(pos / 10000^{2i / d_{\text{model}}}) \\ \end{aligned} \]


class SinusoidPE(torch.nn.Module):
    """Fixed sinusoidal positional encoding.

    Implements::

        PE[pos, 2i]     = sin(pos / 10000 ** (2i / d_model))
        PE[pos, 2i + 1] = cos(pos / 10000 ** (2i / d_model))

    The module has no learnable parameters.
    """

    def __init__(self, d_model: int):
        super(SinusoidPE, self).__init__()
        self.d_model = d_model

    def _denominator(self, device: torch.device) -> torch.Tensor:
        """Return the per-channel divisor ``10000 ** (2i / d_model)``, shape ``[D]``."""
        channel = torch.arange(0, self.d_model, device=device)
        # Channels 2i and 2i + 1 form one sin/cos pair and must share the
        # exponent 2i / d_model, so round every channel index down to even.
        pair_exponent = (channel - channel % 2) / self.d_model
        return 10000 ** pair_exponent

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the encodings for the first ``x.size(1)`` positions.

        Only the second dimension of ``x`` (the sequence length) and its
        device are used. The result has shape ``[1, L, D]`` so it can be
        broadcast against a ``[N, L, D]`` batch.
        """
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)  # [1, L]
        angles = positions.unsqueeze(-1) / self._denominator(x.device)  # [1, L, D]
        encoding = torch.empty_like(angles)
        encoding[..., 0::2] = torch.sin(angles[..., 0::2])  # even channels
        encoding[..., 1::2] = torch.cos(angles[..., 1::2])  # odd channels
        return encoding


Su等人提出了旋转位置编码(Rotary Position Embedding,RoPE),其核心思想是通过旋转矩阵将位置信息嵌入到特征空间中,从而使模型能够学习到位置信息。

\[ f_{q, k}(\bsx_m, m) = \bsR_{\Theta, m}^d W_{q, k} \bsx_m \]

其中,\(\bsR_{\Theta, m}\)为旋转矩阵,将向量每两个分量进行旋转。

\[ \begin{aligned} \bsR^k_{\Theta, m} &= \begin{bmatrix} \cos(m\theta_k) & -\sin(m\theta_k) \\ \sin(m\theta_k) & \cos(m\theta_k) \end{bmatrix} \\ \bsR_{\Theta, m} &= \begin{bmatrix} \bsR^1_{\Theta, m} & 0 & \cdots & 0 \\ 0 & \bsR^2_{\Theta, m} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \bsR^{d / 2}_{\Theta, m} \end{bmatrix}_{d\times d} \end{aligned} \]

注意,\(m\)为位置,取值范围为\(1, \ldots, L\);\(k\)为维度下标,取值范围为\(1, \ldots, d / 2\),对应的旋转频率为\(\theta_k = 10000^{-2(k - 1) / d}\)。最终的旋转操作对每个位置\(m\)的向量\(\bsx_m\)应用旋转矩阵\(\bsR_{\Theta, m}\),得到新的向量。

import functools

class RoPE(torch.nn.Module):
    def __init__(self, d_model: int, theta: int | float = 10000):
        super(RoPE, self).__init__()
        self.d_model = d_model
        self.theta = theta ** -(torch.arange(0, d_model, 2) / d_model)

    def _forward_l(self, L: int) -> torch.Tensor:
        # Use lru_cache to avoid redundant computation for the same L

        D = self.d_model
        pos = torch.einsum(
            torch.arange(L), self.theta
        )  # [L, D / 2]

        # Major diagonal is cos, cos, ..., cos; D elements
        cos = torch.cos(pos).repeat_interleave(2)
        # Minor diagonal is sin, 0, sin, 0, ..., sin; D - 1 elements
        sin = torch.stack([
            torch.sin(pos), torch.zeros_like(pos), dim=-1
        ]).reshape(L, D)[:, :-1]

        result = torch.zeros(L, D, D)
        result = torch.diagonal_scatter(result, cos, dim1=1, dim2=2)
        result = torch.diagonal_scatter(result, sin, dim1=1, dim2=2, offset=-1)
        result = torch.diagonal_scatter(result, -sin, dim1=1, dim2=2, offset=1)
        return result

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [N, L, H, D]
        _, L, _, D = x.size()
        rot_matrix = self._forward_l(L).to(x.device)  # [L, D, D]
        return torch.einsum('lde,nlhe->nlhd', rot_matrix, x)


class MultiHeadAttention(torch.nn.Module):
    """Multi-head attention with an optional rotary positional embedding.

    NOTE(review): this is an abbreviated excerpt that shows only the lines
    relevant to RoPE. Base-class initialisation, the ``W_Q`` / ``W_K`` /
    ``W_V`` projections, ``num_heads`` / ``d_k``, the shape variables
    ``N`` / ``L_Q`` / ``L_KV`` and the attention computation itself are not
    visible here and are left elided.
    """

    def __init__(self, d_model: int, num_heads: int, rope: RoPE | None = None):
        # ... (remaining initialisation elided in this excerpt)

        # Optional rotary embedding applied to queries and keys in forward().
        self.rope = rope

    def forward(
        self, x_q: torch.Tensor, x_k: torch.Tensor, x_v: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # ... (definition of N, L_Q and L_KV elided in this excerpt)

        # Project and split into heads.
        Q = self.W_Q(x_q).reshape(N, L_Q, self.num_heads, self.d_k)
        K = self.W_K(x_k).reshape(N, L_KV, self.num_heads, self.d_k)
        V = self.W_V(x_v).reshape(N, L_KV, self.num_heads, self.d_k)

        # Rotate only queries and keys, and only when RoPE is configured.
        # The rotated tensors must not be overwritten by the plain
        # projections afterwards.
        if self.rope is not None:
            Q = self.rope(Q)
            K = self.rope(K)

        # ... (attention computation and return elided in this excerpt)


