0. 先看结论
Scaled Dot-Product Attention 的公式是:
它做了三件事:
- 用 $QK^T$ 算“当前位置和其他位置有多相关”。
- 除以 $\sqrt{d_k}$,把分数压回稳定范围。
- 用 softmax 得到权重,再对 $V$ 做加权求和。
1 | flowchart LR |
一句话讲 scale:
点积会随着维度 $d_k$ 变大而变大;不缩放的话,softmax 很容易变得过尖,梯度变小,训练不稳定。除以 $\sqrt{d_k}$ 后,分数的方差大致回到 1。
1. Attention 要解决什么问题
语言模型处理一句话时,每个 token 都需要参考上下文。
比如:
1 | 小明把书放进书包,因为它太重了。 |
这里的“它”更可能指“书”,不是“书包”。模型需要一种机制,让当前位置可以去看前面的相关位置。Attention 做的就是这件事:
| 问题 | attention 的做法 |
|---|---|
| 当前 token 该看哪里 | 用 query 和 key 算相关性 |
| 每个位置该占多少权重 | 用 softmax 把分数变成权重 |
| 看完以后拿到什么信息 | 对 value 做加权求和 |
这不是在“查字典”,而是在每一层里动态计算上下文关系。
2. Q、K、V 分别是什么
对同一个 token 表示 $x$,Transformer 会用三个不同的线性层得到:
| 名称 | 全称 | 可以怎么理解 | 用途 |
|---|---|---|---|
| $Q$ | Query | 当前位置想找什么 | 主动发问 |
| $K$ | Key | 每个位置提供什么索引 | 被匹配 |
| $V$ | Value | 每个位置真正携带的信息 | 被加权汇总 |
同一个输入 $X$ 会被投影成三种角色。这样做的好处是:模型可以把“用来匹配的信息”和“真正要传递的信息”分开学。
1 | flowchart LR |
3. 公式里的形状
先看单个 batch、单个 head 的情况。
| 符号 | 形状 | 含义 |
|---|---|---|
| $L$ | 标量 | 序列长度 |
| $d_k$ | 标量 | query/key 的维度 |
| $d_v$ | 标量 | value 的维度 |
| $Q$ | $L \times d_k$ | 每个位置的 query |
| $K$ | $L \times d_k$ | 每个位置的 key |
| $V$ | $L \times d_v$ | 每个位置的 value |
| $QK^T$ | $L \times L$ | 每个位置对每个位置的相关性分数 |
| softmax 后 | $L \times L$ | attention 权重 |
| 输出 | $L \times d_v$ | 每个位置的新表示 |
矩阵乘法流程:
1 | Q: L x d_k |
每一行代表“某个位置看所有位置”的权重。
4. 第一步:为什么用点积
两个向量的点积:
点积越大,通常表示两个向量方向越接近。放到 attention 里:
| 点积结果 | 含义 |
|---|---|
| 很大 | 当前 query 和这个 key 很匹配 |
| 接近 0 | 关系不明显 |
| 很小或负数 | 不匹配 |
例子:
| 位置 | token | 与当前位置的分数 |
|---|---|---|
| 1 | 小明 | 0.8 |
| 2 | 书 | 3.2 |
| 3 | 书包 | 1.1 |
| 4 | 它 | 当前 token |
如果当前位置是“它”,第 2 个位置的分数最高,softmax 后它会拿到更高权重。
点积也适合硬件计算。所有位置之间的匹配可以一次矩阵乘法完成:
这比逐个位置循环计算更适合 GPU。
5. 第二步:为什么要除以 $\sqrt{d_k}$
Transformer 原论文给出的理由是:当 $d_k$ 比较大时,点积的值会变大,softmax 会进入梯度很小的区域。缩放可以缓解这个问题。
下面把这句话拆开。
5.1 点积为什么会变大
假设 query 和 key 的每个分量都满足:
| 假设 | 含义 |
|---|---|
| 均值为 0 | 正负值大致抵消 |
| 方差为 1 | 每个分量的尺度差不多 |
| 分量之间独立 | 为了方便推导 |
点积是:
如果 $q_i$ 和 $k_i$ 的均值都是 0、方差都是 1,那么乘积 $q_i k_i$ 的均值约为 0,方差约为 1。
把 $d_k$ 项加起来:
标准差就是:
所以维度越大,点积分数的典型幅度越大。
| $d_k$ | 点积分数的标准差约为 |
|---|---|
| 16 | 4 |
| 64 | 8 |
| 128 | 11.3 |
| 256 | 16 |
这就是 scale 的数学来源。
5.2 为什么是除以 $\sqrt{d_k}$,不是除以 $d_k$
如果点积的标准差约为 $\sqrt{d_k}$,那除以 $\sqrt{d_k}$ 后:
方差也回到 1 左右:
如果除以 $d_k$,会缩得太狠:
当 $d_k=64$ 时,标准差大约变成 $1/8$。分数太接近,softmax 会变得很平,注意力权重难以拉开差距。
| 缩放方式 | 分数尺度 | softmax 结果 |
|---|---|---|
| 不缩放 | 太大 | 过尖,容易接近 one-hot |
| 除以 $\sqrt{d_k}$ | 合适 | 分布有区分度,也不容易饱和 |
| 除以 $d_k$ | 太小 | 过平,接近平均看所有位置 |
5.3 不 scale 会发生什么
softmax 对大分数很敏感。
例子 1:分数尺度正常。
1 | softmax([1, 0, -1]) ≈ [0.665, 0.245, 0.090] |
三个位置都有梯度,模型还能调整。
例子 2:分数过大。
1 | softmax([10, 0, -10]) ≈ [0.99995, 0.00005, 0.00000] |
最大项几乎吃掉所有概率。这个状态下,softmax 的梯度很小,模型很难继续细调“第二相关”和“第三相关”的位置。
5.4 scale 不改变排序,但会改变分布形状
除以同一个正数不会改变大小顺序:
1 | [16, 8, 0] / 8 = [2, 1, 0] |
最大值仍然是最大值。变化在 softmax 之后:
| logits | softmax 大致结果 | 分布形状 |
|---|---|---|
[16, 8, 0] |
[0.9997, 0.0003, 0.0000] |
很尖 |
[2, 1, 0] |
[0.665, 0.245, 0.090] |
有区分,也不至于饱和 |
scale 的目的不是改变谁更相关,而是让 softmax 有合适的工作区间。
6. 第三步:softmax 把分数变成权重
对每个 query 位置,softmax 会在所有 key 位置上做归一化:
其中:
| 符号 | 含义 |
|---|---|
| $i$ | 当前 query 的位置 |
| $j$ | 被看的 key/value 位置 |
| $s_{ij}$ | 缩放后的相关性分数 |
| $\alpha_{ij}$ | 位置 $i$ 分给位置 $j$ 的注意力权重 |
一行权重会加起来等于 1:
所以 attention 可以看成“当前位置从所有 value 中按比例取信息”。
7. 第四步:对 V 做加权求和
得到权重后,输出为:
矩阵形式就是:
例子:
| 位置 | token | attention 权重 | value 信息 |
|---|---|---|---|
| 1 | 小明 | 0.10 | 人名相关信息 |
| 2 | 书 | 0.70 | 物体相关信息 |
| 3 | 书包 | 0.20 | 容器相关信息 |
输出向量就是:
1 | 0.10 * v_小明 + 0.70 * v_书 + 0.20 * v_书包 |
这也是 attention 被称为“加权检索”的原因。
8. Mask 在哪里加
实际 Transformer 里,softmax 前还会加 mask:
8.1 padding mask
padding token 是补齐长度用的,不应该被模型关注。
| token | 是否有效 | mask 值 |
|---|---|---|
| 我 | 有效 | 0 |
| 爱 | 有效 | 0 |
| NLP | 有效 | 0 |
<pad> |
无效 | $-\infty$ |
加上 $-\infty$ 后:
1 | softmax([2.0, 1.0, 0.5, -inf]) = [0.629, 0.231, 0.140, 0.000] |
8.2 causal mask
自回归语言模型不能看未来 token。
假设序列长度为 4,允许看的位置如下:
| query 位置 | 能看哪些 key 位置 |
|---|---|
| 1 | 1 |
| 2 | 1, 2 |
| 3 | 1, 2, 3 |
| 4 | 1, 2, 3, 4 |
对应矩阵:
1 | 允许 = 0 |
mask 一定要在 softmax 前加。softmax 后再把概率置 0,还需要重新归一化,容易写错。
9. 一个完整小例子
假设当前 query 对三个 key 的原始点积分数是:
1 | [16, 8, 0] |
设 $d_k=64$,那么 $\sqrt{d_k}=8$。
缩放前:
1 | softmax([16, 8, 0]) ≈ [0.9997, 0.0003, 0.0000] |
缩放后:
1 | [16, 8, 0] / 8 = [2, 1, 0] |
对 value 做加权求和:
1 | output = 0.665 * V1 + 0.245 * V2 + 0.090 * V3 |
缩放后,模型仍然更关注第一个位置,但不会完全忽略后两个位置。训练早期尤其需要这种余地。
10. 和 cosine similarity 的区别
有人会问:既然担心点积太大,为什么不用 cosine similarity?
cosine similarity 是:
它会把向量长度也归一化掉。Scaled dot-product attention 只按维度缩放,不会去掉向量模长信息。
| 对比 | scaled dot-product | cosine similarity |
|---|---|---|
| 缩放方式 | 除以固定的 $\sqrt{d_k}$ | 除以每对向量的模长 |
| 是否保留向量长度信息 | 保留 | 基本去掉 |
| 计算效率 | 很适合矩阵乘法 | 还要计算 norm |
| 在 Transformer 中 | 标准做法 | 可研究,但不是主流标准 |
向量长度有时也是模型学到的信号。直接用 cosine 会改变注意力打分的表达方式。
11. Multi-Head Attention 里的位置
Multi-Head Attention 会把 embedding 维度拆成多个 head,每个 head 各自做一遍 Scaled Dot-Product Attention。
1 | flowchart LR |
公式:
如果模型维度是 $d_{\text{model}}=768$,head 数是 $h=12$,那么每个 head 的 key 维度通常是:
此时 scale 就是:
每个 head 单独缩放。不是用整个 $d_{\text{model}}$ 去缩放。
12. 工程实现
12.1 简化版 PyTorch 代码
1 | import math |
真实工程里还会考虑:
| 问题 | 做法 |
|---|---|
fp16/bf16 下 -inf 处理 |
用 dtype 可承受的极小值,或交给框架内核 |
| 长上下文显存占用 | FlashAttention / memory-efficient attention |
| dropout | 对 attention weights 做 dropout |
| causal mask | 使用上三角 mask |
| KV cache | 推理时缓存历史 K/V,避免重复计算 |
12.2 数值稳定 softmax
softmax 实现通常会先减去最大值:
这不改变结果,但能避免 $e^x$ 溢出。
13. KV Cache 和推理
训练时,整段序列通常一次性进入模型:
1 | Q, K, V 都来自整段序列 |
自回归推理时,每次只生成一个新 token。历史 token 的 $K,V$ 不会变,可以缓存起来:
1 | 第 t 步: |
这样每一步不用重新计算所有历史 token 的 $K,V$。
| 阶段 | Q | K/V |
|---|---|---|
| 训练 | 全序列 | 全序列 |
| 推理第 t 步 | 当前 token | 历史 cache + 当前 token |
注意:KV cache 省的是重复计算 K/V,不会让 attention 对历史长度的依赖消失。当前 token 仍然要和历史 K 做匹配。
14. FlashAttention 做了什么
普通实现会显式生成 $L \times L$ 的 attention score 矩阵:
1 | scores = QK^T |
当 $L$ 很大时,$L \times L$ 矩阵很占显存。
FlashAttention 的思路是:不长期保存完整 attention 矩阵,而是分块计算,并用 online softmax 累积结果。
| 对比 | 普通 attention | FlashAttention |
|---|---|---|
| 数学结果 | exact | exact |
| 是否显式保存完整 $L \times L$ 权重 | 通常会 | 不长期保存 |
| 主要收益 | 实现简单 | 降低 HBM 读写,更快更省显存 |
它没有改变 Scaled Dot-Product Attention 的公式,只是改变计算组织方式。
15. 常见误解
| 误解 | 更准确的说法 |
|---|---|
| scale 是为了让分数都变小 | scale 是为了让分数尺度和维度解耦,避免 softmax 饱和 |
| 除以 $d_k$ 更合理,因为点积加了 $d_k$ 项 | 点积的方差随 $d_k$ 增长,标准差随 $\sqrt{d_k}$ 增长 |
| scale 不影响结果 | 它不改排序,但会改 softmax 后的权重分布 |
| attention 权重就是解释 | attention 权重能显示信息流,但不等同完整可解释性 |
| mask 可以 softmax 后再加 | 推荐 softmax 前加,否则要重新归一化 |
| Multi-head 里用 $d_{\text{model}}$ 缩放 | 每个 head 用自己的 $d_k$ 缩放 |
| FlashAttention 是近似 | FlashAttention 的主线实现是 exact attention |
16. 手算流程
给定:
1 | q = [1, 1] |
点积:
1 | q·k1 = 1 |
$d_k=2$,所以:
1 | sqrt(d_k) = 1.414 |
缩放:
1 | [1, 1, 2] / 1.414 ≈ [0.707, 0.707, 1.414] |
softmax:
1 | softmax([0.707, 0.707, 1.414]) ≈ [0.248, 0.248, 0.503] |
如果三个 value 是 $v_1,v_2,v_3$:
1 | output = 0.248*v1 + 0.248*v2 + 0.503*v3 |
第三个 key 和 query 最匹配,所以第三个 value 权重最高。
17. 和前一篇 softmax 文档的关系
Scaled Dot-Product Attention 里有两个层次:
| 层次 | 对应内容 |
|---|---|
| softmax 文档 | 解释“分数怎么变成概率/权重” |
| 本文 | 解释“attention 分数从哪里来,为什么要缩放” |
可以把本文的公式拆开看:
负责生成稳定的 attention scores;
负责把 scores 变成权重;
负责按权重汇总信息。
18. 小结
Scaled Dot-Product Attention 的内部逻辑可以按这条线记:
1 | Q/K 算相关性 -> 除以 sqrt(d_k) 稳定分数 -> softmax 得到权重 -> 加权汇总 V |
scale 的来源是点积方差:
除以 $\sqrt{d_k}$ 后,softmax 输入的尺度更稳定。这样既能保留点积注意力的高效矩阵乘法,又能避免高维点积分数过大导致 softmax 饱和。
19. 参考资料
- Transformer 原论文:Vaswani et al., Attention Is All You Need
- 原论文 HTML 版本:Attention Is All You Need - arXiv HTML
- FlashAttention:Dao et al., FlashAttention
- FlashAttention-2:Dao, FlashAttention-2