⚠️
本文有一定时效性 · 1个月前更新
最后更新: 2026年04月01日
预计阅读时间: 34.7 分钟
8672 字 250 字/分

Attension也是整个 Transformer 里最精髓的部分了, 也卡了我相当之久

前向传播解析部分

注意力公式如下, 很晦涩, 但我尽可能以简单的方式来解释这些问题

$$ Attention(Q,K,V)=softmax\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V $$

$QKV$计算

先解释 Q K V 分别是什么:

  • Q: Query 向量, 代表着我要查什么东西
  • K: Key 向量, 代表着当前被查询的这个东西
  • V: Value 向量, 实际上当前被查询的东西的具体内容
  • $d_{k}$: 向量的维度

简单来说, Q 负责确认"我要找什么", K 负责"我现在找到的是什么", V 是"找大的具体内容". 首先, 这三个矩阵的型是相同的

先看 $QK^{T}$, 两个高纬度的向量做点积, 根据高中的知识, 通常来说, 点积更大的, 这两个向量在方向上约接近

举个例子, 我们按照单词划分Token, 输入一个 3 Tokens的一句话

Token1 Token2 Token3

那对于这几个Token对应的矩阵, 就应当是三行(对应三个 Token) 二列(Token 维度为 2), 即输入矩阵 X

$$ X=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix} $$

我们可以简单的认为, 其中 Token1 对应第一行的向量(1, 2), Token2 对应第二行的(0, 1)

接下来, 根据公式 $Q=XW_{Q},\quad K=XW_{K},\quad V=XW_{V}$ 来生成我们需要的 QKV 三个矩阵, 其中 $X$ 是输入, $W_{Q,K,V}$ 是一个可训练的权重矩阵, 为了方便理解, 我们直接令:

$W_{Q}=\begin{bmatrix}1 & 0\\ 0 & 1\end{bmatrix}$ $W_{K}=\begin{bmatrix}1 & 1\\ 0 & 1\end{bmatrix}$ $W_{V}=\begin{bmatrix}1 & 0\\ 1 & 1\end{bmatrix}$

由上述 QKV 公式, 我们可以计算

$$ Q=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix}\cdot\begin{bmatrix}1 & 0\\ 0 & 1\end{bmatrix}=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix} $$

$$ K=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix}\cdot\begin{bmatrix}1 & 1\\ 0 & 1\end{bmatrix}=\begin{bmatrix}1 & 3\\ 0 & 1\\ 3 & 4\end{bmatrix} $$

$$ V=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix}\cdot\begin{bmatrix}1 & 0\\ 1 & 1\end{bmatrix}=\begin{bmatrix}3 & 2\\ 1 & 1\\ 4 & 1\end{bmatrix} $$

$QK^{T}$计算

接下来, 计算 $QK^{T}$

$$ QK^{T}=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix}\cdot\begin{bmatrix}1 & 0 & 3\\ 3 & 1 & 4\end{bmatrix}=\begin{bmatrix}7 & 2 & 11\\ 3 & 1 & 4\\ 6 & 1 & 13\end{bmatrix} $$

其中, 我们计算的结果矩阵中的每一行, 就是这个Token和其他另外的几个Token的相似程度, 并且$QK^{T}$一定是方阵, 每行和每列都对应一个Token

以第一行 Token1 为例, 通过计算得到, 在总共的三个Token中, Token1最关注Token3, 最不关注Token2, 但这只是一个初始的结果, 并不能直接用, 我们还需要进一步处理

除以 $\sqrt{d_{k}}$

为什么它偏偏要除以这个开方后的$d_k$呢, 在我们这个例子中可能看不出来, 因为我们选择的向量维度只有 2, 根据如下的公式

$$ Q_i \cdot K_j = \sum_{t=1}^{d_k} q_t k_t $$

随着我们向量维度的增大, 这个矩阵的数值也会开始爆炸, 大的数字会变得更大, 可能是这样的

$$ K^{T}=\begin{bmatrix}1 & 0 & 114514\\ 3 & 1 & 1919810114514\\ 888 & 888 & 888\end{bmatrix} $$

所以我们需要除以一个数, 来把矩阵里的每个参数的值拉回到一个正常的水平

$softmax$

接下来, 我们要把这个矩阵中的值转换为一个概率, 而最常用的方法就是Softmax, 把每个值映射到一个 0-1 的概率, 来决定注意力的分配

$$ softmax(x_{i})=\frac{e^{x_{i}}}{\sum e^{x_{j}}} $$

根据这个公式, 对每行单独进行softmax, 就计算出了每个Token对于其他Token的注意力分配, 最终得到的结果可能类似于 下面的结果

$$ softmax\left(\frac{QK^{T}}{\sqrt2}\right)\approx\begin{bmatrix}0.055 & 0.002 & 0.943\\ 0.304 & 0.074 & 0.622\\ 0.007 & 0.0002 & 0.993\end{bmatrix} $$

乘以$V$

根据上面的结果, 我们能发现, 对于Token1来说, 他对三个Token的注意力分布分别是0.055, 0.002, 0.943, 带入所有信息, 计算最终的Attention值

$$ Output=AV\approx\begin{bmatrix}3.94 & 1.06\\ 3.50 & 1.30\\ 3.99 & 1.01\end{bmatrix} $$

最终的结果的行数, 代表Token的总数, 而列数就是我们设定的向量维度的大小, 那这个向量具体的意义是什么呢? 这个向量就是综合了上下文信息后, 这个Token对应的向量

比如我原本的三个Token是苹果手机电脑, 其中苹果这个Token, 既有可能指苹果这个水果, 也可能指苹果这个品牌, 那经过Attention处理后的Token, 就相当于通过计算, 确定出了这个苹果对应的Token向量到底是哪个, 在这例子中, [3.94, 1.06] 这个向量, 就可能指"苹果"这个科技品牌

但同样需要注意的是, 经过 Attention变换后的向量, 并不是新的数据, 而是融合了上下文语义后的向量

反向传播

显然, 反向传播就需要依据上面的路径一步步计算返回的梯度值

output
← attention, V
← scores
← Q, K
← Wq, Wk, Wv, X

首先计算 output 对 Attention 和 V 的梯度, 在前向传播时, 我们有

$$ Output=Attention\cdot V $$

假设我们的损失函数为 Loss, 并且已知从上一层神经网络传回的梯度 $\frac{\partial L}{\partial Output}$, 我们需要计算 $\frac{\partial L}{\partial Attention},\quad\frac{\partial L}{\partial V}$这两个梯度

矩阵乘法的梯度公式

对于一个矩阵乘法 $Y = AB$, 若已知某函数 $L$ 对于 $Y$ 的梯度 $\frac{\partial L}{\partial Y}$, 那么则有如下两个结论

$$ \frac{\partial L}{\partial A}=\frac{\partial L}{\partial Y}B^{T} $$

$$ \frac{\partial L}{\partial B}=A^{T}\frac{\partial L}{\partial Y} $$

直接带入计算输出时的公式$O = AV$, 可得

$$ \frac{\partial L}{\partial Attention}=\frac{\partial L}{\partial Output}^{}V^{T} $$

$$ \frac{\partial L}{\partial V}=Attention^{T}\cdot\frac{\partial L}{\partial Output} $$


数学证明如下

以以下公式为例:

$$ \frac{\partial L}{\partial \text{Attention}} = \frac{\partial L}{\partial \text{Output}} \cdot V^T $$

根据$output = Attention · V$ 易知:

$$ Output_{ij}=\sum_{k}Attention_{ik}V_{kj} $$

那么对某个 $Attention_{ij}$ 求偏导, 得:

$$ \frac{\partial Output_{ij}}{\partial Attention_{im}}=V_{mj} $$

再根据链式法则有:

$$ \frac{\partial L}{\partial Attention_{im}}=\sum_{j}\frac{\partial L}{\partial Output_{ij}}\frac{\partial Output_{ij}}{\partial Attention_{im}} $$

右侧可以直接合并, 就是矩阵乘法

$$ \frac{\partial L}{\partial Attention}=\frac{\partial L}{\partial Output}V^{T} $$

另一个公式同理, 不多证明

传播至 $Softmax$

接下来, 反向传播到 $Softmax $ 层, 上一步我们得到了 $L$ 关于 $Attention$ 的梯度, 但在 $Softmax$ 层, $Attention$ 本身并不是直接参数, 而是经过 $Softmax$ 变换后得到的, 因此, 如果要继续求 $L$ 关于 $scores$ 的梯度 $\frac{\partial L}{\partial scores}$ , 就要复杂一些

当然, 因为 $Softmax$ 不是一个逐元素独立的函数, 而是把输入矩阵的每一行单独归一化, 对每一行, 有: $a_{j}=\frac{e^{s_{j}}}{\sum_{k}e^{s_{k}}}$

它的导数很诡异, 是: $\frac{\partial a_{j}}{\partial s_{m}}=a_{j}(\delta_{jm}-a_{m})$, 其中 $\delta_{jm}$ 是 Kronecker delta, 当且仅当 $ j = m $ 时为 1, 否则为 0

继续, 我们需要求出每一行第 j 个位置的梯度 $g_j = \frac{\partial L}{\partial a_j}$, 由链式法则有

$$ \frac{\partial L}{\partial s_{j}}=a_{j}\left(g_{j}-\sum_{k}g_{k}a_{k}\right) $$

代码实现

public class SingleSelfAttention {
    // 定义三个初始化后的可训练的 QKV 矩阵
    private Matrix wq;
    private Matrix wk;
    private Matrix wv;
    // 向量维度, 即文章中的 2 维
    private int dModel;

    // 临时存储, 用于反向传播计算
    private Matrix lastInput;
    private Matrix lastQ,lastK,lastV,lastAttention;
    public SingleSelfAttention(int dModel){
        // 检测输入合法性
        if(dModel <= 0){
            throw new IllegalArgumentException("dModel must be greater than 0\n");
        }
        this.dModel = dModel;

        // 初始化矩阵
        wq = Matrix.random(dModel, dModel);
        wk = Matrix.random(dModel, dModel);
        wv = Matrix.random(dModel, dModel);
    }

    public Matrix forward(Matrix matrix){

        // 计算 QKV
        lastInput = matrix;
        lastQ = matrix.multiply(wq);
        lastK = matrix.multiply(wk);
        lastV = matrix.multiply(wv);

        // 使用Attention公式
        Matrix scores = lastQ.multiply(lastK.transpose());
        scores = scores.constantMultiply(1.0 / Math.sqrt(dModel));
        scores = applyCausalAttention(scores);

        //对每行进行Softmax
        lastAttention = scores.softmaxByRow();

        return lastAttention.multiply(lastV);

    }

    public Matrix backward(Matrix gradOutput, double learningRate){

        //
        Matrix gradAttention = gradOutput.multiply(lastV.transpose());
        // output = attention * V, so dV = attention^T * dOutput
        Matrix gradV = lastAttention.transpose().multiply(gradOutput);

        //The gradient of scores
        Matrix gradScores = new Matrix (lastAttention.getRow(), lastAttention.getCol());
        for(int i = 0 ; i < lastAttention.getRow() ; i++){
            double dot = 0.0;
            for(int j = 0 ; j < lastAttention.getCol(); j++){
                dot += gradAttention.getElement(i, j) * lastAttention.getElement(i, j);
            }
            for(int j = 0; j < lastAttention.getCol(); j++){
                double value = lastAttention.getElement(i, j)* (gradAttention.getElement(i, j) - dot);
                if(j > i){
                    value = 0.0;
                }
                gradScores.setElement(i, j, value/ Math.sqrt(dModel));
            }
        }

        Matrix gradQ = gradScores.multiply(lastK);
        Matrix gradK = gradScores.transpose().multiply(lastQ);

        Matrix gradWq = lastInput.transpose().multiply(gradQ);
        Matrix gradWk = lastInput.transpose().multiply(gradK);
        Matrix gradWv = lastInput.transpose().multiply(gradV);

        Matrix gradInput = gradQ.multiply(wq.transpose())
                .add(gradK.multiply(wk.transpose()))
                .add(gradV.multiply(wv.transpose()));
        updateWeights(wq, gradWq, learningRate);
        updateWeights(wk, gradWk, learningRate);
        updateWeights(wv, gradWv, learningRate);

        return gradInput;
    }
    public void updateWeights(Matrix weights, Matrix gradWeights, double learningRate){
        for(int i = 0; i < weights.getRow(); i++){
            for(int j = 0 ; j < weights.getCol(); j++){
                double value = weights.getElement(i, j);
                weights.setElement(i, j, value - learningRate * gradWeights.getElement(i, j));
            }
        }
    }

    private Matrix applyCausalAttention(Matrix scores){
        Matrix result = scores.copy();

        for(int i = 0; i < result.getRow(); i++){
            for(int j = i + 1; j < result.getCol(); j++){
                result.setElement(i, j, -1e9);
            }
        }
        return result;
    }

    public Matrix getWq() {
        return wq;
    }
    public Matrix getWk() {
        return wk;
    }
    public Matrix getWv() {
        return wv;
    }

    public void setWq(Matrix newWq){
        validateWeightShape(newWq, "wq");
        this.wq = newWq.copy();
    }
    public void setWk(Matrix newWk){
        validateWeightShape(newWk, "wk");
        this.wk = newWk.copy();
    }
    public void setWv(Matrix newWv){
        validateWeightShape(newWv, "wv");
        this.wv = newWv.copy();
    }

    private void validateWeightShape(Matrix matrix, String weightName){
        if(matrix == null){
            throw new IllegalArgumentException("The " + weightName + " cannot be null");
        }
        if(matrix.getRow() != dModel || matrix.getCol() != dModel){
            throw new IllegalArgumentException("The shape of " + weightName + " must be (" + dModel + ", " + dModel + ")");
        }
    }
}