跳转至

大模型推理

在推理阶段,大模型根据输入的文本序列,不断预测输入文本的下一个词,直到预测出结束符号为止。不过,在预训练阶段中,不包含对bos(begin of sentence)和eos(end of sentence)的预测。所以,如果使用基座模型,模型会不断输出token。

KV-cache

把输入序列重新输出LLM,预测下一个单词推理流程的计算效率非常低,原因在于每一次预测都需要对整个序列计算注意力得分。然而,由于每一个token的表示只依赖于前面的token,在推理阶段,我们不必重算前边的token,只需要存储前边的token的KV表示即可。只需用最后一个token的Q计算注意力即可得到对下一个token的预测。这种技术即称为KV缓存。包含KV缓存的推理过程分为两个阶段,即缓存填充(prefill)阶段和解码(decode)阶段。

为何只需要KV-cache而不需要Q-cache

假设输入序列经过\(QKV\)计算后的序列为\(S^Q, S^K, S^V\),并且可以拆分为列向量表示:

\[ \begin{aligned} S^Q &= \begin{bmatrix} Q_1 & Q_2 & \cdots & Q_n \end{bmatrix} \\ S^K &= \begin{bmatrix} K_1 & K_2 & \cdots & K_n \end{bmatrix} \\ S^V &= \begin{bmatrix} V_1 & V_2 & \cdots & V_n \end{bmatrix} \]

注意力得分矩阵为

\[ \begin{bmatrix} Q_1K_1 & Q_1K_2 & \cdots & Q_1K_n \\ Q_2K_1 & Q_2K_2 & \cdots & Q_2K_n \\ \vdots & \vdots & \ddots & \vdots \\ Q_nK_1 & Q_nK_2 & \cdots & Q_nK_n \]

由于decoder的causal mask限制,最终计算的表示为

\[ \begin{aligned} O_1 &= \sum_{i=1}^1 \text{softmax}(Q_1K_i)V_i \\ O_2 &= \sum_{i=1}^2 \text{softmax}(Q_2K_i)V_i \\ \end{aligned} \]

也即,每个位置的表示,只依赖该位置的Q,和该位置及之前的K和V。因此预测下一个单词时,不需要重新计算KV,将KV存储下来即可。

将输入序列输入LLM时,首先进行缓存填充,保存下所有层的KV-cache,然后进行解码,不断预测下一个单词,并记录预测单词的KV表示。对于decoder的每一层,都需要保存一个KV-cache,每一层cache的大小为\(2\times N\times L\times D\)。对于长序列,KV-cache的大小会非常大。

评论