news 2026/4/15 17:45:05

动手学深度学习——自注意力代码

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
动手学深度学习——自注意力代码

1. 前言

上一篇我们已经从概念上理解了自注意力(Self-Attention)

  • query、key、value 都来自同一个序列

  • 每个位置都可以动态关注序列中其他位置

  • 它比 RNN 更擅长建模长距离依赖

  • 它也是后面 Transformer 和 BERT 的核心基础

这一篇就继续按李沐的节奏,把自注意力真正落实到代码上。

这一节最关键的不是记 API,
而是把这条计算链真正看懂:

  • 输入序列怎么变成Q、K、V

  • QK^T为什么能得到相关性分数

  • softmax在哪个维度做

  • V怎么被加权汇总

  • 输出形状到底怎么理解

如果一句话概括这一节代码的核心,那就是:

先做两两匹配,再做归一化加权求和。


2. 自注意力代码到底在算什么

自注意力最核心的公式通常写成:

Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V

这条式子你后面会反复见到。

它可以拆成三步:

第一步:算分数矩阵

QK^T

表示每个 query 和每个 key 的相似度。

第二步:缩放并 softmax

softmax(QK^T / sqrt(d))

把相似度变成注意力权重分布。

第三步:对 V 做加权和

attention_weights @ V

得到每个位置新的上下文化表示。

所以整个自注意力并不神秘,本质就是:

分数矩阵 → 权重矩阵 → 加权输出矩阵


3. 先从输入张量形状开始

理解自注意力代码,第一步一定是先把形状看清楚。

假设输入序列表示为:

X.shape = (batch_size, num_steps, num_hiddens)

这里:

  • batch_size:一批样本数量

  • num_steps:序列长度

  • num_hiddens:每个位置的特征维度

例如一句话长度为 5,隐藏维度为 100,
那么每个样本就可以看成一个:

5 × 100

的矩阵。

自注意力的目标就是:

让这 5 个位置彼此交互,重新得到 5 个新的表示。


4. 为什么要先做Q、K、V线性变换

虽然自注意力里三者都来自同一个输入X
但它们的角色不同。

所以通常会先做三次线性投影:

Q = XW_q K = XW_k V = XW_v

在代码里,常见写法类似:

self.W_q = nn.Linear(num_hiddens, num_hiddens, bias=False) self.W_k = nn.Linear(num_hiddens, num_hiddens, bias=False) self.W_v = nn.Linear(num_hiddens, num_hiddens, bias=False)

然后前向传播时:

queries = self.W_q(X) keys = self.W_k(X) values = self.W_v(X)

这样同一个输入序列就变成了三套不同用途的表示。


5. 为什么不能直接用XQ、K、V

理论上有些极简情形可以直接用,
但实际中通常还是会做线性投影。

原因很简单:

第一,角色不同

  • query 表示“我要找什么”

  • key 表示“我能被怎么匹配”

  • value 表示“我真正提供什么内容”

同一个原始表示未必适合同时扮演这三种角色。

第二,提升表达能力

不同投影矩阵能让模型学到更灵活的匹配方式。

第三,便于后面多头注意力拆分

Transformer 里这是标准做法。

所以线性投影几乎可以看成自注意力的标配。


6. 最基础的缩放点积注意力代码怎么写

李沐这里常见的基础实现会先写一个“缩放点积注意力”模块,大致如下:

class DotProductAttention(nn.Module): def __init__(self, dropout, **kwargs): super(DotProductAttention, self).__init__(**kwargs) self.dropout = nn.Dropout(dropout) def forward(self, queries, keys, values, valid_lens=None): d = queries.shape[-1] scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d) self.attention_weights = masked_softmax(scores, valid_lens) return torch.bmm(self.dropout(self.attention_weights), values)

这段代码就是自注意力实现里最关键的主体之一。


7.keys.transpose(1, 2)为什么要转置

先看形状。

假设:

queries.shape = (batch_size, num_queries, d) keys.shape = (batch_size, num_kv_pairs, d)

如果要做每个 query 和每个 key 的内积,
那么就需要把 keys 变成:

(batch_size, d, num_kv_pairs)

所以要写:

keys.transpose(1, 2)

这样后面:

torch.bmm(queries, keys.transpose(1, 2))

结果才会是:

(batch_size, num_queries, num_kv_pairs)

这正好就是“分数矩阵”。


8.torch.bmm(queries, keys.transpose(1, 2))到底算出了什么

这一步算的是:

每个 query 和每个 key 的点积相似度

如果某个 batch 里有:

  • 5 个 query

  • 5 个 key

那结果就是一个:

5 × 5

的分数矩阵。

其中矩阵第i, j个元素表示:

第 i 个位置作为 query 时,对第 j 个位置 key 的相关性评分

所以这一步本质上是在建立:

序列中任意两个位置之间的关系图

这也是自注意力强大的原因之一。


9. 为什么要除以sqrt(d)

这一句:

scores = ... / math.sqrt(d)

就是缩放点积注意力里的“缩放”。

为什么要做?

因为如果维度d很大,
点积结果的数值幅度通常也会偏大。
而后面还要做 softmax,分数太大容易导致:

  • 分布特别尖锐

  • 梯度变小

  • 训练不稳定

所以除以:

sqrt(d)

就是为了把分数控制在更合适的范围里。

这一步看似小,实际上非常关键。


10.masked_softmax(scores, valid_lens)在这里起什么作用

这一步的作用和前面注意力代码一样:

把 padding 位置屏蔽掉

例如一个 batch 中有些句子长度较短,
后面补了<pad>
那么这些 pad 位置不应该参与自注意力分配。

所以:

  • 真实 token 位置会正常参与 softmax

  • pad 位置会被 mask 成极小值

  • softmax 后权重几乎为 0

也就是说,自注意力虽然是“全局关注”,
但这个“全局”仍然只是在有效长度范围内


11.torch.bmm(attention_weights, values)为什么就是加权和

假设:

attention_weights.shape = (batch_size, num_queries, num_kv_pairs) values.shape = (batch_size, num_kv_pairs, value_dim)

那么矩阵乘法后得到:

(batch_size, num_queries, value_dim)

这意味着:

  • 对于每个 query

  • attention_weights 给出了对所有 value 的权重分布

  • bmm 就是按照这些权重对 value 做加权和

所以最后输出的每个位置表示,本质上就是:

把序列中其他位置的信息,按相关性加权汇总过来。

这一步就是“自注意力后的新表示”。


12. 为什么说自注意力输出是“上下文化表示”

因为原始输入位置x_i只包含:

  • 当前位置本身的信息

而经过自注意力后,第i个位置的新表示包含了:

  • 自己的信息

  • 和自己相关的其他位置的信息

所以它不再是“孤立 token 表示”,
而是:

融合了整句上下文后的表示

这就是“上下文化表示”的含义。

例如同一个词出现在不同句子里,
经过自注意力后的表示会不同,
因为它吸收的上下文不同。


13. 如果是自注意力,num_queriesnum_kv_pairs有什么关系

在最典型的自注意力里,三者都来自同一个序列,
所以通常有:

num_queries = num_kv_pairs = num_steps

也就是说:

  • 序列有几个位置

  • 就有几个 query

  • 也有几个 key/value

于是分数矩阵通常是:

(batch_size, num_steps, num_steps)

这就是著名的:

每个位置和每个位置之间的关系矩阵

所以自注意力的核心其实就是在构造这样一张“全序列关系图”。


14. 一个具体形状例子帮助理解

假设:

  • batch_size = 2

  • num_steps = 4

  • num_hiddens = 8

那么输入X形状是:

(2, 4, 8)

做完Q、K、V线性变换后,形状仍然可能是:

Q.shape = K.shape = V.shape = (2, 4, 8)

然后:

分数矩阵

scores.shape = (2, 4, 4)

表示每个 batch 里 4 个位置两两打分。

注意力输出

output.shape = (2, 4, 8)

表示最后仍然是 4 个位置,每个位置一个新的 8 维表示。

你会发现:

自注意力通常不会改变“位置数”,而是更新“每个位置的表示内容”。


15. 为什么自注意力比 RNN 更容易并行

因为这里所有位置之间的关系,
都可以通过矩阵运算一次性算出来。

例如:

  • QK^T一次就得到所有 query-key 配对分数

  • softmax一次处理完整分数矩阵

  • attention_weights @ V一次得到全部位置输出

不像 RNN 那样必须:

  • 第 1 步算完才能算第 2 步

  • 第 2 步算完才能算第 3 步

所以自注意力特别适合 GPU/TPU 这类擅长大规模矩阵并行的硬件。

这也是它后来能彻底改变 NLP 的一个重要工程原因。


16. 代码里为什么常常保存attention_weights

和前面普通注意力一样,自注意力实现里也常会保存:

self.attention_weights

原因主要有两个:

第一,后续可能需要继续使用

例如调试或某些分析模块。

第二,便于可视化

自注意力特别适合画热力图。
你可以直接看到:

  • 第 i 个位置到底关注了哪些位置

  • 是否关注了长距离词

  • 是否学到了某种句法或语义关系

这也是自注意力特别“可解释”的一个地方。


17. 自注意力代码和前面注意力代码最本质的区别是什么

如果一句话说清楚,就是:

前面普通注意力通常是一个序列看另一个序列,
自注意力则是同一个序列看自己。

所以代码层面最本质的变化就是:

  • query、key、value 不再来自不同模块

  • 而是都由同一个输入X线性变换得到

其他主线其实没变:

  • 还是先算分数

  • 再 softmax

  • 再加权和

所以自注意力不是“完全不同的机制”,
而是注意力机制的一个非常重要的特殊形式。


18. 这一节最该掌握什么

如果从学习重点来看,最关键的是这几件事。

18.1 看懂Q、K、V是怎么从X得来的

知道为什么要做三次线性投影。

18.2 看懂QK^T为什么是分数矩阵

这是自注意力最核心的一步。

18.3 看懂softmax的作用

它把分数变成注意力分布。

18.4 看懂attention_weights @ V为什么是加权和

这是输出形成的关键。

18.5 理解输出形状通常和输入位置数保持一致

只是每个位置的表示被更新了。


19. 这一节和后面 Transformer / BERT 的关系

这一节实际上非常关键,因为:

Transformer 的核心模块,本质上就是自注意力。

后面的:

  • 多头注意力

  • 位置编码

  • Transformer Encoder

  • BERT

都建立在这一节基础上。

所以你一定要意识到:

自注意力代码这一节,不只是“又一个模块”,而是后面整个现代 NLP 架构的地基。

如果这一节真的看懂了,
后面 Transformer 和 BERT 会顺很多。


20. 本节总结

这一节我们学习了自注意力的代码实现,核心内容可以总结为以下几点。

20.1 自注意力通常先把输入序列映射成 Q、K、V 三组表示

虽然都来自同一输入,但角色不同。

20.2QK^T用于计算序列中任意两个位置之间的相关性分数

这是自注意力的核心关系矩阵。

20.3 分数经过缩放和 softmax 后变成注意力权重

用于表示每个位置对其他位置的关注程度。

20.4 输出由注意力权重对 V 做加权和得到

从而形成新的上下文化表示。

20.5 自注意力特别适合并行计算,也是后面 Transformer 和 BERT 的核心基础

这一点非常重要。


21. 学习感悟

这一节特别有价值,因为它让你真正看到:

“一个位置如何去读整个序列”
这件事,原来是可以被精确写成矩阵运算的。

以前我们更多把序列建模理解成:

  • 沿时间递推

  • 慢慢传递信息

而自注意力一下把这个过程改写成了:

我直接和所有位置做匹配,然后把最相关的信息一次性取回来。

这不只是模型结构上的变化,
更是一种完全不同的信息流动方式。

也正因为如此,自注意力才会成为后面 Transformer 和 BERT 的真正起点。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/15 17:42:20

RK3588实战:Qt+OpenCV环境搭建与USB摄像头实时采集全攻略

1. 环境准备&#xff1a;从零搭建RK3588开发环境 第一次拿到RK3588开发板时&#xff0c;我和大多数开发者一样兴奋又忐忑。这款六核ARM处理器在嵌入式视觉领域确实是个狠角色&#xff0c;但要让它的性能真正发挥出来&#xff0c;环境搭建就是第一道门槛。这里分享我反复验证过的…

作者头像 李华
网站建设 2026/4/15 17:41:05

零基础3分钟解锁RPG游戏资源:RPG Maker MV解密终极指南

零基础3分钟解锁RPG游戏资源&#xff1a;RPG Maker MV解密终极指南 【免费下载链接】RPG-Maker-MV-Decrypter You can decrypt RPG-Maker-MV Resource Files with this project ~ If you dont wanna download it, you can use the Script on my HP: 项目地址: https://gitcod…

作者头像 李华
网站建设 2026/4/15 17:38:16

DDrawCompat终极指南:5分钟让Windows老游戏重获新生

DDrawCompat终极指南&#xff1a;5分钟让Windows老游戏重获新生 【免费下载链接】DDrawCompat DirectDraw and Direct3D 1-7 compatibility, performance and visual enhancements for Windows Vista, 7, 8, 10 and 11 项目地址: https://gitcode.com/gh_mirrors/dd/DDrawCom…

作者头像 李华
网站建设 2026/4/15 17:36:34

你的Windows电脑太“胖“了?试试这个一键瘦身神器!

你的Windows电脑太"胖"了&#xff1f;试试这个一键瘦身神器&#xff01; 【免费下载链接】Win11Debloat A simple, lightweight PowerShell script that allows you to remove pre-installed apps, disable telemetry, as well as perform various other changes to …

作者头像 李华
网站建设 2026/4/15 17:36:32

CXPatcher:免费一键解锁CrossOver游戏兼容性的完整指南

CXPatcher&#xff1a;免费一键解锁CrossOver游戏兼容性的完整指南 【免费下载链接】CXPatcher A patcher to upgrade Crossover dependencies and improve compatibility 项目地址: https://gitcode.com/gh_mirrors/cx/CXPatcher 你是否在Mac上使用CrossOver运行Window…

作者头像 李华