0. 先看结论

Scaled Dot-Product Attention 的公式是:

它做了三件事:

  1. 用 $QK^T$ 算“当前位置和其他位置有多相关”。
  2. 除以 $\sqrt{d_k}$,把分数压回稳定范围。
  3. 用 softmax 得到权重,再对 $V$ 做加权求和。
1
2
3
4
5
6
7
8
flowchart LR
A["Q query"] --> D["QK 相似度分数"]
B["K key"] --> D
D --> E["除以 sqrt dk"]
E --> F["softmax 得到权重"]
C["V value"] --> G["加权求和"]
F --> G
G --> H["新的 token 表示"]

一句话讲 scale:

点积会随着维度 $d_k$ 变大而变大;不缩放的话,softmax 很容易变得过尖,梯度变小,训练不稳定。除以 $\sqrt{d_k}$ 后,分数的方差大致回到 1。


1. Attention 要解决什么问题

语言模型处理一句话时,每个 token 都需要参考上下文。

比如:

1
小明把书放进书包,因为它太重了。

这里的“它”更可能指“书”,不是“书包”。模型需要一种机制,让当前位置可以去看前面的相关位置。Attention 做的就是这件事:

问题 attention 的做法
当前 token 该看哪里 用 query 和 key 算相关性
每个位置该占多少权重 用 softmax 把分数变成权重
看完以后拿到什么信息 对 value 做加权求和

这不是在“查字典”,而是在每一层里动态计算上下文关系。


2. Q、K、V 分别是什么

对同一个 token 表示 $x$,Transformer 会用三个不同的线性层得到:

名称 全称 可以怎么理解 用途
$Q$ Query 当前位置想找什么 主动发问
$K$ Key 每个位置提供什么索引 被匹配
$V$ Value 每个位置真正携带的信息 被加权汇总

同一个输入 $X$ 会被投影成三种角色。这样做的好处是:模型可以把“用来匹配的信息”和“真正要传递的信息”分开学。

1
2
3
4
5
6
7
flowchart LR
A["输入 X"] --> B["WQ"]
A --> C["WK"]
A --> D["WV"]
B --> E["Q"]
C --> F["K"]
D --> G["V"]

3. 公式里的形状

先看单个 batch、单个 head 的情况。

符号 形状 含义
$L$ 标量 序列长度
$d_k$ 标量 query/key 的维度
$d_v$ 标量 value 的维度
$Q$ $L \times d_k$ 每个位置的 query
$K$ $L \times d_k$ 每个位置的 key
$V$ $L \times d_v$ 每个位置的 value
$QK^T$ $L \times L$ 每个位置对每个位置的相关性分数
softmax 后 $L \times L$ attention 权重
输出 $L \times d_v$ 每个位置的新表示

矩阵乘法流程:

1
2
3
4
5
6
7
Q:        L x d_k
K^T: d_k x L
QK^T: L x L

softmax(QK^T / sqrt(d_k)): L x L
V: L x d_v
output: L x d_v

每一行代表“某个位置看所有位置”的权重。


4. 第一步:为什么用点积

两个向量的点积:

点积越大,通常表示两个向量方向越接近。放到 attention 里:

点积结果 含义
很大 当前 query 和这个 key 很匹配
接近 0 关系不明显
很小或负数 不匹配

例子:

位置 token 与当前位置的分数
1 小明 0.8
2 3.2
3 书包 1.1
4 当前 token

如果当前位置是“它”,第 2 个位置的分数最高,softmax 后它会拿到更高权重。

点积也适合硬件计算。所有位置之间的匹配可以一次矩阵乘法完成:

这比逐个位置循环计算更适合 GPU。


5. 第二步:为什么要除以 $\sqrt{d_k}$

Transformer 原论文给出的理由是:当 $d_k$ 比较大时,点积的值会变大,softmax 会进入梯度很小的区域。缩放可以缓解这个问题。

下面把这句话拆开。

5.1 点积为什么会变大

假设 query 和 key 的每个分量都满足:

假设 含义
均值为 0 正负值大致抵消
方差为 1 每个分量的尺度差不多
分量之间独立 为了方便推导

点积是:

如果 $q_i$ 和 $k_i$ 的均值都是 0、方差都是 1,那么乘积 $q_i k_i$ 的均值约为 0,方差约为 1。

把 $d_k$ 项加起来:

标准差就是:

所以维度越大,点积分数的典型幅度越大。

$d_k$ 点积分数的标准差约为
16 4
64 8
128 11.3
256 16

这就是 scale 的数学来源。

5.2 为什么是除以 $\sqrt{d_k}$,不是除以 $d_k$

如果点积的标准差约为 $\sqrt{d_k}$,那除以 $\sqrt{d_k}$ 后:

方差也回到 1 左右:

如果除以 $d_k$,会缩得太狠:

当 $d_k=64$ 时,标准差大约变成 $1/8$。分数太接近,softmax 会变得很平,注意力权重难以拉开差距。

缩放方式 分数尺度 softmax 结果
不缩放 太大 过尖,容易接近 one-hot
除以 $\sqrt{d_k}$ 合适 分布有区分度,也不容易饱和
除以 $d_k$ 太小 过平,接近平均看所有位置

5.3 不 scale 会发生什么

softmax 对大分数很敏感。

例子 1:分数尺度正常。

1
softmax([1, 0, -1]) ≈ [0.665, 0.245, 0.090]

三个位置都有梯度,模型还能调整。

例子 2:分数过大。

1
softmax([10, 0, -10]) ≈ [0.99995, 0.00005, 0.00000]

最大项几乎吃掉所有概率。这个状态下,softmax 的梯度很小,模型很难继续细调“第二相关”和“第三相关”的位置。

5.4 scale 不改变排序,但会改变分布形状

除以同一个正数不会改变大小顺序:

1
[16, 8, 0] / 8 = [2, 1, 0]

最大值仍然是最大值。变化在 softmax 之后:

logits softmax 大致结果 分布形状
[16, 8, 0] [0.9997, 0.0003, 0.0000] 很尖
[2, 1, 0] [0.665, 0.245, 0.090] 有区分,也不至于饱和

scale 的目的不是改变谁更相关,而是让 softmax 有合适的工作区间。


6. 第三步:softmax 把分数变成权重

对每个 query 位置,softmax 会在所有 key 位置上做归一化:

其中:

符号 含义
$i$ 当前 query 的位置
$j$ 被看的 key/value 位置
$s_{ij}$ 缩放后的相关性分数
$\alpha_{ij}$ 位置 $i$ 分给位置 $j$ 的注意力权重

一行权重会加起来等于 1:

所以 attention 可以看成“当前位置从所有 value 中按比例取信息”。


7. 第四步:对 V 做加权求和

得到权重后,输出为:

矩阵形式就是:

例子:

位置 token attention 权重 value 信息
1 小明 0.10 人名相关信息
2 0.70 物体相关信息
3 书包 0.20 容器相关信息

输出向量就是:

1
0.10 * v_小明 + 0.70 * v_书 + 0.20 * v_书包

这也是 attention 被称为“加权检索”的原因。


8. Mask 在哪里加

实际 Transformer 里,softmax 前还会加 mask:

8.1 padding mask

padding token 是补齐长度用的,不应该被模型关注。

token 是否有效 mask 值
有效 0
有效 0
NLP 有效 0
<pad> 无效 $-\infty$

加上 $-\infty$ 后:

1
softmax([2.0, 1.0, 0.5, -inf]) = [0.629, 0.231, 0.140, 0.000]

8.2 causal mask

自回归语言模型不能看未来 token。

假设序列长度为 4,允许看的位置如下:

query 位置 能看哪些 key 位置
1 1
2 1, 2
3 1, 2, 3
4 1, 2, 3, 4

对应矩阵:

1
2
3
4
5
6
7
8
9
允许 = 0
禁止 = -inf

[
[0, -inf, -inf, -inf],
[0, 0, -inf, -inf],
[0, 0, 0, -inf],
[0, 0, 0, 0 ],
]

mask 一定要在 softmax 前加。softmax 后再把概率置 0,还需要重新归一化,容易写错。


9. 一个完整小例子

假设当前 query 对三个 key 的原始点积分数是:

1
[16, 8, 0]

设 $d_k=64$,那么 $\sqrt{d_k}=8$。

缩放前:

1
softmax([16, 8, 0]) ≈ [0.9997, 0.0003, 0.0000]

缩放后:

1
2
[16, 8, 0] / 8 = [2, 1, 0]
softmax([2, 1, 0]) ≈ [0.665, 0.245, 0.090]

对 value 做加权求和:

1
output = 0.665 * V1 + 0.245 * V2 + 0.090 * V3

缩放后,模型仍然更关注第一个位置,但不会完全忽略后两个位置。训练早期尤其需要这种余地。


10. 和 cosine similarity 的区别

有人会问:既然担心点积太大,为什么不用 cosine similarity?

cosine similarity 是:

它会把向量长度也归一化掉。Scaled dot-product attention 只按维度缩放,不会去掉向量模长信息。

对比 scaled dot-product cosine similarity
缩放方式 除以固定的 $\sqrt{d_k}$ 除以每对向量的模长
是否保留向量长度信息 保留 基本去掉
计算效率 很适合矩阵乘法 还要计算 norm
在 Transformer 中 标准做法 可研究,但不是主流标准

向量长度有时也是模型学到的信号。直接用 cosine 会改变注意力打分的表达方式。


11. Multi-Head Attention 里的位置

Multi-Head Attention 会把 embedding 维度拆成多个 head,每个 head 各自做一遍 Scaled Dot-Product Attention。

1
2
3
4
5
6
flowchart LR
A["X"] --> B["投影成 Q K V"]
B --> C["切成多个 head"]
C --> D["每个 head 做 scaled attention"]
D --> E["拼接 heads"]
E --> F["输出投影"]

公式:

如果模型维度是 $d_{\text{model}}=768$,head 数是 $h=12$,那么每个 head 的 key 维度通常是:

此时 scale 就是:

每个 head 单独缩放。不是用整个 $d_{\text{model}}$ 去缩放。


12. 工程实现

12.1 简化版 PyTorch 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
# q, k: [batch, heads, seq_len, d_k]
# v: [batch, heads, seq_len, d_v]
d_k = q.size(-1)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))

weights = torch.softmax(scores, dim=-1)
output = torch.matmul(weights, v)
return output, weights

真实工程里还会考虑:

问题 做法
fp16/bf16 下 -inf 处理 用 dtype 可承受的极小值,或交给框架内核
长上下文显存占用 FlashAttention / memory-efficient attention
dropout 对 attention weights 做 dropout
causal mask 使用上三角 mask
KV cache 推理时缓存历史 K/V,避免重复计算

12.2 数值稳定 softmax

softmax 实现通常会先减去最大值:

这不改变结果,但能避免 $e^x$ 溢出。


13. KV Cache 和推理

训练时,整段序列通常一次性进入模型:

1
Q, K, V 都来自整段序列

自回归推理时,每次只生成一个新 token。历史 token 的 $K,V$ 不会变,可以缓存起来:

1
2
3
4
第 t 步:
新 Q: 当前 token
历史 K/V: 从 cache 里取
新 K/V: 当前 token 算出来后追加进 cache

这样每一步不用重新计算所有历史 token 的 $K,V$。

阶段 Q K/V
训练 全序列 全序列
推理第 t 步 当前 token 历史 cache + 当前 token

注意:KV cache 省的是重复计算 K/V,不会让 attention 对历史长度的依赖消失。当前 token 仍然要和历史 K 做匹配。


14. FlashAttention 做了什么

普通实现会显式生成 $L \times L$ 的 attention score 矩阵:

1
2
3
scores = QK^T
weights = softmax(scores)
output = weights V

当 $L$ 很大时,$L \times L$ 矩阵很占显存。

FlashAttention 的思路是:不长期保存完整 attention 矩阵,而是分块计算,并用 online softmax 累积结果。

对比 普通 attention FlashAttention
数学结果 exact exact
是否显式保存完整 $L \times L$ 权重 通常会 不长期保存
主要收益 实现简单 降低 HBM 读写,更快更省显存

它没有改变 Scaled Dot-Product Attention 的公式,只是改变计算组织方式。


15. 常见误解

误解 更准确的说法
scale 是为了让分数都变小 scale 是为了让分数尺度和维度解耦,避免 softmax 饱和
除以 $d_k$ 更合理,因为点积加了 $d_k$ 项 点积的方差随 $d_k$ 增长,标准差随 $\sqrt{d_k}$ 增长
scale 不影响结果 它不改排序,但会改 softmax 后的权重分布
attention 权重就是解释 attention 权重能显示信息流,但不等同完整可解释性
mask 可以 softmax 后再加 推荐 softmax 前加,否则要重新归一化
Multi-head 里用 $d_{\text{model}}$ 缩放 每个 head 用自己的 $d_k$ 缩放
FlashAttention 是近似 FlashAttention 的主线实现是 exact attention

16. 手算流程

给定:

1
2
3
4
q = [1, 1]
k1 = [1, 0]
k2 = [0, 1]
k3 = [1, 1]

点积:

1
2
3
q·k1 = 1
q·k2 = 1
q·k3 = 2

$d_k=2$,所以:

1
sqrt(d_k) = 1.414

缩放:

1
[1, 1, 2] / 1.414 ≈ [0.707, 0.707, 1.414]

softmax:

1
softmax([0.707, 0.707, 1.414]) ≈ [0.248, 0.248, 0.503]

如果三个 value 是 $v_1,v_2,v_3$:

1
output = 0.248*v1 + 0.248*v2 + 0.503*v3

第三个 key 和 query 最匹配,所以第三个 value 权重最高。


17. 和前一篇 softmax 文档的关系

Scaled Dot-Product Attention 里有两个层次:

层次 对应内容
softmax 文档 解释“分数怎么变成概率/权重”
本文 解释“attention 分数从哪里来,为什么要缩放”

可以把本文的公式拆开看:

负责生成稳定的 attention scores;

负责把 scores 变成权重;

负责按权重汇总信息。


18. 小结

Scaled Dot-Product Attention 的内部逻辑可以按这条线记:

1
Q/K 算相关性 -> 除以 sqrt(d_k) 稳定分数 -> softmax 得到权重 -> 加权汇总 V

scale 的来源是点积方差:

除以 $\sqrt{d_k}$ 后,softmax 输入的尺度更稳定。这样既能保留点积注意力的高效矩阵乘法,又能避免高维点积分数过大导致 softmax 饱和。


19. 参考资料