


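A language model is a sequence model: given a sequence, it predicts, at each position, a probability distribution over the token at the next position. Text is first tokenized into discrete symbols (tokens), and an embedding layer then maps each token to a vector in a \(D\)-dimensional space. The input to sequence modeling is therefore a tensor of shape \(N\times L\times D\), where \(N\) is the batch size, \(L\) is the sequence length, and \(D\) is the embedding dimension.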
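
The mapping from token ids to an \(N\times L\times D\) tensor can be sketched as follows. This is a minimal illustration, not a full tokenization pipeline: the vocabulary size, the dimensions, and the random token ids are assumptions made up for the example, and torch.nn.Embedding stands in for the embedding layer.

```python
import torch

# Hypothetical toy sizes, chosen only for illustration.
vocab_size, D = 100, 16   # vocabulary size and embedding dimension
N, L = 2, 5               # batch size and sequence length

# Token ids as a tokenizer might produce them: integers in [0, vocab_size).
token_ids = torch.randint(0, vocab_size, (N, L))

# The embedding layer is a lookup table with one D-dimensional row per token id.
embedding = torch.nn.Embedding(vocab_size, D)
x = embedding(token_ids)

print(x.shape)  # torch.Size([2, 5, 16])
```

Each position of x is simply the row of embedding.weight selected by the corresponding token id.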



\[ \text{Attention}(\bsQ, \bsK, \bsV) = \text{softmax}\left(\frac{\bsQ\bsK^\top}{\sqrt{d_k}}\right)\bsV \]

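Self-attention means that the queries, keys, and values all come from the same sequence. In an implementation, the input sequence \(\bsx\) is passed through three linear maps \(\bsW^Q, \bsW^K, \bsW^V\) to obtain the queries \(\bsQ\), keys \(\bsK\), and values \(\bsV\); the attention scores are then computed, and the output is a weighted sum of the values. Cross-attention means that the queries come from one sequence while the keys and values come from another.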
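
For a single head and a single sequence, the steps just described might look like the sketch below. The function name scaled_dot_product_attention and the bias-free projection layers W_Q, W_K, W_V are assumptions introduced for this example; masking and batching are left out.

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    # Q: (L_Q, d_k), K: (L_KV, d_k), V: (L_KV, d_v)
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (L_Q, L_KV)
    weights = torch.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ V, weights                        # weighted sum of values

torch.manual_seed(0)
L, D = 4, 8
x = torch.randn(L, D)

# Hypothetical projections standing in for W^Q, W^K, W^V (bias omitted).
W_Q = torch.nn.Linear(D, D, bias=False)
W_K = torch.nn.Linear(D, D, bias=False)
W_V = torch.nn.Linear(D, D, bias=False)

# Self-attention: queries, keys, and values are all computed from the same x.
out, weights = scaled_dot_product_attention(W_Q(x), W_K(x), W_V(x))
print(out.shape, weights.shape)  # torch.Size([4, 8]) torch.Size([4, 4])
```

For cross-attention, the only change would be to compute the keys and values from a second sequence instead of x.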







\[ \text{softmax}(\bsz)_i = \frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}} \]
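
A direct implementation of this formula can overflow when some \(z_i\) is large. One common remedy, sketched here under the assumption that a hand-written softmax is wanted at all, is to subtract the maximum before exponentiating; the common factor cancels between numerator and denominator, so the result is unchanged.

```python
import torch

def softmax(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Shift by the maximum along `dim` so the largest exponent is exp(0) = 1.
    z = z - z.max(dim=dim, keepdim=True).values
    e = torch.exp(z)
    return e / e.sum(dim=dim, keepdim=True)

z = torch.tensor([1.0, 2.0, 3.0])
p = softmax(z)

# Without the shift, exp(1000.) would overflow to inf in float32.
p_large = softmax(torch.tensor([1000.0, 1001.0]))
```

The result agrees with torch.softmax up to floating-point error, which is the function the implementation later in this section relies on.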






einsum is short for Einstein summation notation, a notation for tensor operations. In PyTorch, the einsum function has the signature:

torch.einsum(equation, *operands)

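Here equation is a string that describes the operation, and operands is one or more tensors. For example, torch.einsum("ij,jk->ik", A, B) computes the product of the matrices \(A\) and \(B\). It can be read as a multiply-then-sum procedure: each letter in equation labels one axis of the corresponding operand, entries are multiplied where the labels match, and every label that does not appear in the output is summed over. Ignoring the batch dimension, we need to compute \(\bsQ\bsK^\top\), where \(\bsQ\) has shape \((L, D)\), \(\bsK\) has shape \((L, D)\), and the output has shape \((L, L)\), so the equation is "id,jd->ij".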
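
Both equations can be checked against ordinary matrix products. The tensor sizes below are arbitrary and chosen only for the example.

```python
import torch

torch.manual_seed(0)

# Matrix product: j is shared by both operands and absent from the output,
# so it is summed over.
A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = torch.einsum("ij,jk->ik", A, B)

# Attention scores without the batch dimension: d is summed over,
# leaving one score per (query position i, key position j) pair.
L, D = 6, 8
Q = torch.randn(L, D)
K = torch.randn(L, D)
scores = torch.einsum("id,jd->ij", Q, K)

print(torch.allclose(C, A @ B, atol=1e-5))         # True
print(torch.allclose(scores, Q @ K.T, atol=1e-5))  # True
```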

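Below we implement multi-head attention in PyTorch. Before implementing attention itself, we need the linear maps and the softmax function. We then implement the MultiHeadSelfAttention class for multi-head attention, whose input is an \(N\times L\times D\) tensor and whose output is also an \(N\times L\times D\) tensor.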

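  1. padding mask: shape [N, L]. Masks out the padding elements of a sequence so that they take no part in the attention computation.
  2. attention mask: shape [L, L]. Used in self-attention to control which elements of the sequence may depend on which. Typically a causal mask (also called a look-ahead mask) restricts each position to depend only on itself and earlier positions, never on later ones.
  3. memory mask (in cross-attention): shape [L_Q, L_KV]. Controls which keys each query may depend on.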
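
The first two masks might be built as in the sketch below. It assumes the convention that 1 marks a position that may be attended to and 0 a masked position, which matches the `== 0` test in the implementation that follows; the lengths tensor is a made-up example.

```python
import torch

N, L = 2, 4
# Hypothetical true lengths of the two sequences in the batch;
# the second one has two padding positions at the end.
lengths = torch.tensor([4, 2])

# Padding mask (N, L): 1 for real tokens, 0 for padding.
padding_mask = (torch.arange(L).unsqueeze(0) < lengths.unsqueeze(1)).int()

# Causal mask (L, L): row i has 1 in columns j <= i, so position i
# can attend to itself and to earlier positions only.
causal_mask = torch.tril(torch.ones(L, L, dtype=torch.int))

print(padding_mask.tolist())  # [[1, 1, 1, 1], [1, 1, 0, 0]]
print(causal_mask.tolist())   # [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
```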



import torch


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_k = d_model // num_heads
        self.sqrt_d_k = self.d_k ** 0.5

        self.W_Q = torch.nn.Linear(d_model, d_model)
        self.W_K = torch.nn.Linear(d_model, d_model)
        self.W_V = torch.nn.Linear(d_model, d_model)
        self.W_O = torch.nn.Linear(d_model, d_model)

    def forward(
        self, x_q: torch.Tensor, x_k: torch.Tensor, x_v: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # x_q: (N, L_Q, D), x_k: (N, L_KV, D), x_v: (N, L_KV, D)
        # padding_mask: (N, L_KV), attention_mask: (L_Q, L_KV)
        # In both masks, 0 marks a position to mask out; nonzero positions are kept.
        N, L_Q, D = x_q.size()
        _, L_KV, _ = x_k.size()

        # Linear transformation -> Split heads
        Q = self.W_Q(x_q).reshape(N, L_Q, self.num_heads, self.d_k)
        K = self.W_K(x_k).reshape(N, L_KV, self.num_heads, self.d_k)
        V = self.W_V(x_v).reshape(N, L_KV, self.num_heads, self.d_k)

        # Compute attention score
        score = torch.einsum('nihd,njhd->nijh', Q, K) / self.sqrt_d_k

        # Apply attention mask
        if attention_mask is not None:
            score = score.masked_fill(
                attention_mask.reshape(1, L_Q, L_KV, 1) == 0, float('-inf')
            )

        # Apply padding mask
        if padding_mask is not None:
            score = score.masked_fill(
                padding_mask.reshape(N, 1, L_KV, 1) == 0, float('-inf')
            )

        # Softmax -> Weighted sum -> Merge heads -> Output transformation
        score = torch.nn.functional.softmax(score, dim=2)
        value = torch.einsum(
            'nijh,njhd->nihd', score, V
        ).reshape(N, L_Q, self.d_model)
        return self.W_O(value)

class MultiHeadSelfAttention(MultiHeadAttention):
    def __init__(self, d_model: int, num_heads: int):
        super(MultiHeadSelfAttention, self).__init__(d_model, num_heads)

    def forward(
        self, x: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # Self attention is applied to the same input
        return super().forward(x, x, x, padding_mask, attention_mask)

class MultiHeadCrossAttention(MultiHeadAttention):
    def __init__(self, d_model: int, num_heads: int):
        super(MultiHeadCrossAttention, self).__init__(d_model, num_heads)

    def forward(
        self, x_q: torch.Tensor, x_kv: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # Cross attention is applying query on another kv sequence
        return super().forward(x_q, x_kv, x_kv, padding_mask, attention_mask)
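

The four-index equation 'nihd,njhd->nijh' used in forward() is the batched, per-head version of "id,jd->ij". As a sanity check, the sketch below compares it with an explicit per-head matrix product; the sizes are arbitrary and the names Qh, Kh, and ref are introduced only for this example.

```python
import torch

torch.manual_seed(0)
N, L_Q, L_KV, H, d_k = 2, 3, 5, 4, 8
Q = torch.randn(N, L_Q, H, d_k)
K = torch.randn(N, L_KV, H, d_k)

# Same equation as in forward(): d is summed over; n, i, j, h are kept.
score = torch.einsum('nihd,njhd->nijh', Q, K)          # (N, L_Q, L_KV, H)

# Reference: move the head axis next to the batch axis and use matmul.
Qh = Q.permute(0, 2, 1, 3)                             # (N, H, L_Q, d_k)
Kh = K.permute(0, 2, 1, 3)                             # (N, H, L_KV, d_k)
ref = (Qh @ Kh.transpose(-2, -1)).permute(0, 2, 3, 1)  # (N, L_Q, L_KV, H)

print(torch.allclose(score, ref, atol=1e-5))  # True
```

Because the key axis j sits at index 2 of the score tensor, the softmax in forward() is taken over dim=2.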
