1. 前言
上一篇我们已经从概念上理解了自注意力(Self-Attention):
query、key、value 都来自同一个序列
每个位置都可以动态关注序列中其他位置
它比 RNN 更擅长建模长距离依赖
它也是后面 Transformer 和 BERT 的核心基础
这一篇就继续按李沐的节奏,把自注意力真正落实到代码上。
这一节最关键的不是记 API,
而是把这条计算链真正看懂:
输入序列怎么变成
Q、K、VQK^T为什么能得到相关性分数softmax在哪个维度做V怎么被加权汇总输出形状到底怎么理解
如果一句话概括这一节代码的核心,那就是:
先做两两匹配,再做归一化加权求和。
2. 自注意力代码到底在算什么
自注意力最核心的公式通常写成:
Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V这条式子你后面会反复见到。
它可以拆成三步:
第一步:算分数矩阵
QK^T表示每个 query 和每个 key 的相似度。
第二步:缩放并 softmax
softmax(QK^T / sqrt(d))把相似度变成注意力权重分布。
第三步:对 V 做加权和
attention_weights @ V得到每个位置新的上下文化表示。
所以整个自注意力并不神秘,本质就是:
分数矩阵 → 权重矩阵 → 加权输出矩阵
3. 先从输入张量形状开始
理解自注意力代码,第一步一定是先把形状看清楚。
假设输入序列表示为:
X.shape = (batch_size, num_steps, num_hiddens)这里:
batch_size:一批样本数量num_steps:序列长度num_hiddens:每个位置的特征维度
例如一句话长度为 5,隐藏维度为 100,
那么每个样本就可以看成一个:
5 × 100的矩阵。
自注意力的目标就是:
让这 5 个位置彼此交互,重新得到 5 个新的表示。
4. 为什么要先做Q、K、V线性变换
虽然自注意力里三者都来自同一个输入X,
但它们的角色不同。
所以通常会先做三次线性投影:
Q = XW_q K = XW_k V = XW_v在代码里,常见写法类似:
self.W_q = nn.Linear(num_hiddens, num_hiddens, bias=False) self.W_k = nn.Linear(num_hiddens, num_hiddens, bias=False) self.W_v = nn.Linear(num_hiddens, num_hiddens, bias=False)然后前向传播时:
queries = self.W_q(X) keys = self.W_k(X) values = self.W_v(X)这样同一个输入序列就变成了三套不同用途的表示。
5. 为什么不能直接用X当Q、K、V
理论上有些极简情形可以直接用,
但实际中通常还是会做线性投影。
原因很简单:
第一,角色不同
query 表示“我要找什么”
key 表示“我能被怎么匹配”
value 表示“我真正提供什么内容”
同一个原始表示未必适合同时扮演这三种角色。
第二,提升表达能力
不同投影矩阵能让模型学到更灵活的匹配方式。
第三,便于后面多头注意力拆分
Transformer 里这是标准做法。
所以线性投影几乎可以看成自注意力的标配。
6. 最基础的缩放点积注意力代码怎么写
李沐这里常见的基础实现会先写一个“缩放点积注意力”模块,大致如下:
class DotProductAttention(nn.Module): def __init__(self, dropout, **kwargs): super(DotProductAttention, self).__init__(**kwargs) self.dropout = nn.Dropout(dropout) def forward(self, queries, keys, values, valid_lens=None): d = queries.shape[-1] scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d) self.attention_weights = masked_softmax(scores, valid_lens) return torch.bmm(self.dropout(self.attention_weights), values)这段代码就是自注意力实现里最关键的主体之一。
7.keys.transpose(1, 2)为什么要转置
先看形状。
假设:
queries.shape = (batch_size, num_queries, d) keys.shape = (batch_size, num_kv_pairs, d)如果要做每个 query 和每个 key 的内积,
那么就需要把 keys 变成:
(batch_size, d, num_kv_pairs)所以要写:
keys.transpose(1, 2)这样后面:
torch.bmm(queries, keys.transpose(1, 2))结果才会是:
(batch_size, num_queries, num_kv_pairs)这正好就是“分数矩阵”。
8.torch.bmm(queries, keys.transpose(1, 2))到底算出了什么
这一步算的是:
每个 query 和每个 key 的点积相似度
如果某个 batch 里有:
5 个 query
5 个 key
那结果就是一个:
5 × 5的分数矩阵。
其中矩阵第i, j个元素表示:
第 i 个位置作为 query 时,对第 j 个位置 key 的相关性评分
所以这一步本质上是在建立:
序列中任意两个位置之间的关系图
这也是自注意力强大的原因之一。
9. 为什么要除以sqrt(d)
这一句:
scores = ... / math.sqrt(d)就是缩放点积注意力里的“缩放”。
为什么要做?
因为如果维度d很大,
点积结果的数值幅度通常也会偏大。
而后面还要做 softmax,分数太大容易导致:
分布特别尖锐
梯度变小
训练不稳定
所以除以:
sqrt(d)就是为了把分数控制在更合适的范围里。
这一步看似小,实际上非常关键。
10.masked_softmax(scores, valid_lens)在这里起什么作用
这一步的作用和前面注意力代码一样:
把 padding 位置屏蔽掉
例如一个 batch 中有些句子长度较短,
后面补了<pad>。
那么这些 pad 位置不应该参与自注意力分配。
所以:
真实 token 位置会正常参与 softmax
pad 位置会被 mask 成极小值
softmax 后权重几乎为 0
也就是说,自注意力虽然是“全局关注”,
但这个“全局”仍然只是在有效长度范围内。
11.torch.bmm(attention_weights, values)为什么就是加权和
假设:
attention_weights.shape = (batch_size, num_queries, num_kv_pairs) values.shape = (batch_size, num_kv_pairs, value_dim)那么矩阵乘法后得到:
(batch_size, num_queries, value_dim)这意味着:
对于每个 query
attention_weights 给出了对所有 value 的权重分布
bmm 就是按照这些权重对 value 做加权和
所以最后输出的每个位置表示,本质上就是:
把序列中其他位置的信息,按相关性加权汇总过来。
这一步就是“自注意力后的新表示”。
12. 为什么说自注意力输出是“上下文化表示”
因为原始输入位置x_i只包含:
当前位置本身的信息
而经过自注意力后,第i个位置的新表示包含了:
自己的信息
和自己相关的其他位置的信息
所以它不再是“孤立 token 表示”,
而是:
融合了整句上下文后的表示
这就是“上下文化表示”的含义。
例如同一个词出现在不同句子里,
经过自注意力后的表示会不同,
因为它吸收的上下文不同。
13. 如果是自注意力,num_queries和num_kv_pairs有什么关系
在最典型的自注意力里,三者都来自同一个序列,
所以通常有:
num_queries = num_kv_pairs = num_steps也就是说:
序列有几个位置
就有几个 query
也有几个 key/value
于是分数矩阵通常是:
(batch_size, num_steps, num_steps)这就是著名的:
每个位置和每个位置之间的关系矩阵
所以自注意力的核心其实就是在构造这样一张“全序列关系图”。
14. 一个具体形状例子帮助理解
假设:
batch_size = 2num_steps = 4num_hiddens = 8
那么输入X形状是:
(2, 4, 8)做完Q、K、V线性变换后,形状仍然可能是:
Q.shape = K.shape = V.shape = (2, 4, 8)然后:
分数矩阵
scores.shape = (2, 4, 4)表示每个 batch 里 4 个位置两两打分。
注意力输出
output.shape = (2, 4, 8)表示最后仍然是 4 个位置,每个位置一个新的 8 维表示。
你会发现:
自注意力通常不会改变“位置数”,而是更新“每个位置的表示内容”。
15. 为什么自注意力比 RNN 更容易并行
因为这里所有位置之间的关系,
都可以通过矩阵运算一次性算出来。
例如:
QK^T一次就得到所有 query-key 配对分数softmax一次处理完整分数矩阵attention_weights @ V一次得到全部位置输出
不像 RNN 那样必须:
第 1 步算完才能算第 2 步
第 2 步算完才能算第 3 步
所以自注意力特别适合 GPU/TPU 这类擅长大规模矩阵并行的硬件。
这也是它后来能彻底改变 NLP 的一个重要工程原因。
16. 代码里为什么常常保存attention_weights
和前面普通注意力一样,自注意力实现里也常会保存:
self.attention_weights原因主要有两个:
第一,后续可能需要继续使用
例如调试或某些分析模块。
第二,便于可视化
自注意力特别适合画热力图。
你可以直接看到:
第 i 个位置到底关注了哪些位置
是否关注了长距离词
是否学到了某种句法或语义关系
这也是自注意力特别“可解释”的一个地方。
17. 自注意力代码和前面注意力代码最本质的区别是什么
如果一句话说清楚,就是:
前面普通注意力通常是一个序列看另一个序列,
自注意力则是同一个序列看自己。
所以代码层面最本质的变化就是:
query、key、value 不再来自不同模块
而是都由同一个输入
X线性变换得到
其他主线其实没变:
还是先算分数
再 softmax
再加权和
所以自注意力不是“完全不同的机制”,
而是注意力机制的一个非常重要的特殊形式。
18. 这一节最该掌握什么
如果从学习重点来看,最关键的是这几件事。
18.1 看懂Q、K、V是怎么从X得来的
知道为什么要做三次线性投影。
18.2 看懂QK^T为什么是分数矩阵
这是自注意力最核心的一步。
18.3 看懂softmax的作用
它把分数变成注意力分布。
18.4 看懂attention_weights @ V为什么是加权和
这是输出形成的关键。
18.5 理解输出形状通常和输入位置数保持一致
只是每个位置的表示被更新了。
19. 这一节和后面 Transformer / BERT 的关系
这一节实际上非常关键,因为:
Transformer 的核心模块,本质上就是自注意力。
后面的:
多头注意力
位置编码
Transformer Encoder
BERT
都建立在这一节基础上。
所以你一定要意识到:
自注意力代码这一节,不只是“又一个模块”,而是后面整个现代 NLP 架构的地基。
如果这一节真的看懂了,
后面 Transformer 和 BERT 会顺很多。
20. 本节总结
这一节我们学习了自注意力的代码实现,核心内容可以总结为以下几点。
20.1 自注意力通常先把输入序列映射成 Q、K、V 三组表示
虽然都来自同一输入,但角色不同。
20.2QK^T用于计算序列中任意两个位置之间的相关性分数
这是自注意力的核心关系矩阵。
20.3 分数经过缩放和 softmax 后变成注意力权重
用于表示每个位置对其他位置的关注程度。
20.4 输出由注意力权重对 V 做加权和得到
从而形成新的上下文化表示。
20.5 自注意力特别适合并行计算,也是后面 Transformer 和 BERT 的核心基础
这一点非常重要。
21. 学习感悟
这一节特别有价值,因为它让你真正看到:
“一个位置如何去读整个序列”
这件事,原来是可以被精确写成矩阵运算的。
以前我们更多把序列建模理解成:
沿时间递推
慢慢传递信息
而自注意力一下把这个过程改写成了:
我直接和所有位置做匹配,然后把最相关的信息一次性取回来。
这不只是模型结构上的变化,
更是一种完全不同的信息流动方式。
也正因为如此,自注意力才会成为后面 Transformer 和 BERT 的真正起点。