1. 前言
上一篇我们已经从原理上认识了GRU(门控循环单元):
它是对基础 RNN 的改进
它引入了门控机制
它通过更新门和重置门来控制信息流
它更擅长处理长期依赖问题
但是,只理解公式还不够。
和前面 RNN 一样,真正把 GRU 学扎实,最好的方式还是:
把公式一步一步写成代码。
这一节的任务就是把 GRU 真正落到实现层面。
你会看到:
GRU 比 RNN 多了哪些参数
更新门、重置门在代码里怎么写
候选隐藏状态如何计算
最终隐藏状态如何更新
简洁实现和从零实现分别怎么对应
这一篇本质上就是:
把“门控记忆”变成可运行的程序。
2. GRU 代码实现要解决什么
如果把这一节拆开看,核心其实就 4 件事:
2.1 初始化更多参数
相比基础 RNN,GRU 不再只有一组隐藏状态更新参数,
而是需要分别为:
更新门
重置门
候选隐藏状态
各自准备参数。
2.2 写新的状态更新公式
也就是把上一篇的四条核心公式真正变成代码。
2.3 保持语言模型训练接口一致
GRU 虽然内部更复杂,但对外仍然要能接:
输入序列
初始状态
输出预测
最终状态
2.4 对照简洁实现
看 PyTorch 里的nn.GRU到底帮我们封装了哪些内容。
3. 先回顾 GRU 的核心公式
写代码前,先把最关键的四条公式再捋清楚。
更新门
Z_t = σ(X_t W_xz + H_{t-1} W_hz + b_z)重置门
R_t = σ(X_t W_xr + H_{t-1} W_hr + b_r)候选隐藏状态
H_t_tilde = tanh(X_t W_xh + (R_t ⊙ H_{t-1}) W_hh + b_h)最终隐藏状态
H_t = Z_t ⊙ H_{t-1} + (1 - Z_t) ⊙ H_t_tilde其中:
σ是 sigmoid⊙是按元素乘法
你可以看到,GRU 和基础 RNN 最大的不同就在于:
更新隐藏状态之前,先算门。
4. GRU 从零实现:先初始化参数
基础 RNN 只需要一套隐藏更新参数。
而 GRU 至少要准备三套。
常见写法如下:
def get_params(vocab_size, num_hiddens, device): num_inputs = num_outputs = vocab_size def normal(shape): return torch.randn(size=shape, device=device) * 0.01 def three(): return (normal((num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)), torch.zeros(num_hiddens, device=device)) W_xz, W_hz, b_z = three() # 更新门 W_xr, W_hr, b_r = three() # 重置门 W_xh, W_hh, b_h = three() # 候选隐藏状态 W_hq = normal((num_hiddens, num_outputs)) b_q = torch.zeros(num_outputs, device=device) params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q] for param in params: param.requires_grad_(True) return params这段代码就是 GRU 从零实现的起点。
5. 为什么这里有三组“输入-隐藏-偏置”参数
因为 GRU 要分别计算三类东西:
第一组:更新门参数
W_xz, W_hz, b_z它们负责控制“旧状态保留多少”。
第二组:重置门参数
W_xr, W_hr, b_r它们负责控制“旧状态在候选状态里参与多少”。
第三组:候选隐藏状态参数
W_xh, W_hh, b_h它们负责生成新的候选状态。
所以,GRU 比 RNN 参数更多,不是因为它“乱加复杂度”,
而是因为它确实要分别处理三种不同功能。
6. 隐藏状态初始化和 RNN 一样吗
基本一样。
因为 GRU 最终对外仍然只有一个隐藏状态H_t,
不像 LSTM 还会多一个单独记忆单元。
所以状态初始化通常仍然写成:
def init_gru_state(batch_size, num_hiddens, device): return (torch.zeros((batch_size, num_hiddens), device=device), )也就是说:
每个样本一份隐藏状态
初始时全零
返回成元组形式,方便接口统一
7. GRU 的前向传播是这一节最核心的代码
常见从零实现写法如下:
def gru(inputs, state, params): W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params H, = state outputs = [] for X in inputs: Z = torch.sigmoid(torch.mm(X, W_xz) + torch.mm(H, W_hz) + b_z) R = torch.sigmoid(torch.mm(X, W_xr) + torch.mm(H, W_hr) + b_r) H_tilde = torch.tanh(torch.mm(X, W_xh) + torch.mm(R * H, W_hh) + b_h) H = Z * H + (1 - Z) * H_tilde Y = torch.mm(H, W_hq) + b_q outputs.append(Y) return torch.cat(outputs, dim=0), (H,)如果你把这段代码真正看明白,GRU 就已经学通了大半。
8. 更新门这两行代码怎么理解
先看:
Z = torch.sigmoid(torch.mm(X, W_xz) + torch.mm(H, W_hz) + b_z)它对应的就是更新门公式:
Z_t = σ(X_t W_xz + H_{t-1} W_hz + b_z)含义是:
当前输入
X和上一隐藏状态
H共同决定当前时间步“该保留多少旧状态”
Z的形状通常是:
(batch_size, num_hiddens)也就是说,对每个样本、每个隐藏单元,都有一个门控值。
这意味着:
GRU 不是粗暴地对整个状态“一刀切”,而是按隐藏单元逐个控制。
9. 重置门代码怎么理解
再看:
R = torch.sigmoid(torch.mm(X, W_xr) + torch.mm(H, W_hr) + b_r)它对应:
R_t = σ(X_t W_xr + H_{t-1} W_hr + b_r)它决定的是:
在计算候选隐藏状态时,旧隐藏状态该参与多少。
你可以把它理解成一种“历史清洗器”:
R大,说明历史信息还很重要R小,说明历史信息该弱化一些
10. 候选隐藏状态代码怎么理解
这是 GRU 最关键的一步之一:
H_tilde = torch.tanh(torch.mm(X, W_xh) + torch.mm(R * H, W_hh) + b_h)它对应:
H_t_tilde = tanh(X_t W_xh + (R_t ⊙ H_{t-1}) W_hh + b_h)要点在于:
不是直接用H
而是先做:
R * H这表示:
旧隐藏状态先被重置门筛一遍,再参与候选状态生成。
这就是 GRU 相比基础 RNN 非常关键的精细化控制。
11. 最终隐藏状态更新代码怎么理解
看这一句:
H = Z * H + (1 - Z) * H_tilde这就是:
H_t = Z_t ⊙ H_{t-1} + (1 - Z_t) ⊙ H_t_tilde直观上它就是在做:
旧状态和新候选状态的加权平均
如果:
Z大,则更偏向旧状态Z小,则更偏向新候选状态
所以最终隐藏状态不是“全用旧的”也不是“全用新的”,
而是模型自己学出来的动态折中。
这正是 GRU 最大的魅力所在。
12. 为什么说 GRU 的记忆更可控
从这几行代码里你就能直接看出来:
基础 RNN 的状态更新基本是一条路:
输入来
历史来
一起过
tanh更新完成
而 GRU 明显多了控制环节:
先决定历史该参与多少
再生成候选状态
再决定最终保留多少旧状态
所以 GRU 本质上是一种:
可学习的信息流控制机制
而不是单纯“算出一个新状态”而已。
13. 输出层为什么和 RNN 一样
注意这一句:
Y = torch.mm(H, W_hq) + b_q这和基础 RNN 完全一样。
为什么?
因为 GRU 改进的是:
隐藏状态的内部更新机制
而不是语言模型最终输出的形式。
对于字符级语言模型来说,最后仍然是:
当前隐藏状态
映射到词表空间
得到对每个字符的打分
所以输出层部分不需要大改。
14. 从零实现的 GRU 如何封装成模型类
和上一节 RNN 一样,通常也会封装一个“手写模型容器”。
例如继续用之前那种思路:
net = d2l.RNNModelScratch(vocab_size, num_hiddens, device, get_params, init_gru_state, gru)注意这非常漂亮。
你会发现:
模型容器类几乎不用改
只需要把
参数初始化函数
状态初始化函数
前向传播函数
换成 GRU 版本
整个模型就从 RNN 变成了 GRU。
这说明什么?
说明 GRU 和 RNN 在“接口层面”是一致的,
真正变化的是内部单元。
15. 简洁实现:PyTorch 里的nn.GRU
从零实现看懂以后,简洁实现就很好接受了。
PyTorch 已经封装好了:
nn.GRU基本用法和nn.RNN非常相似。
例如:
gru_layer = nn.GRU(input_size=vocab_size, hidden_size=num_hiddens)这里:
input_size仍然是每个时间步输入向量维度hidden_size仍然是隐藏状态维度
然后前向传播接口也很像:
Y, state = gru_layer(X, state)输出逻辑和nn.RNN一样:
Y:所有时间步输出state:最后隐藏状态
16. 为什么nn.GRU用法和nn.RNN这么像
因为从更高层视角看,它们都属于:
循环序列建模模块
它们对外接口基本一致:
输入序列
初始状态
输出序列
最终状态
只不过内部单元计算方式不同:
nn.RNN
内部是基础循环单元。
nn.GRU
内部是门控循环单元。
所以你可以把 GRU 看成是:
RNN API 体系下的一个更强单元版本
这对工程使用非常友好。
17. 简洁实现里的模型封装怎么写
和简洁版 RNN 一样,通常可以继续复用同样的模型外壳:
GRU 层负责序列递推
线性层负责把隐藏状态映射到词表空间
也就是说,模型代码结构上几乎不用大动,
只要把:
rnn_layer = nn.RNN(...)换成:
gru_layer = nn.GRU(...)就能得到 GRU 版本模型。
这说明:
GRU 的升级主要体现在循环单元内部,不在外围框架。
18. GRU 代码和 RNN 代码最本质的区别在哪里
这是这一节最该点破的地方。
RNN 从零实现
核心只有一条隐藏状态更新公式。
GRU 从零实现
核心变成了:
先算更新门
再算重置门
再算候选状态
再融合新旧状态
所以如果你问:
GRU 相比 RNN,代码上本质多了什么?
答案就是:
多了门控变量和基于门控的状态融合机制
其他外部流程,例如:
one-hot 输入
初始状态
输出层
文本生成
其实都还是同一套套路。
19. GRU 训练时需要额外特殊处理吗
整体训练流程和 RNN 基本一致。
仍然是:
输入 token 序列
输出对下一个 token 的预测
用交叉熵损失
反向传播
梯度裁剪
参数更新
所以从训练框架角度,GRU 并没有引入额外特别陌生的流程。
变化的是内部状态更新更智能了。
这也是为什么在工程里,GRU 往往可以很平滑地替代基础 RNN。
20. 为什么 GRU 常常比基础 RNN 更实用
从代码层面你已经能看出来原因:
第一,它可以显式保留旧状态
更新门允许旧信息直接穿过去。
第二,它可以有选择地忽略部分历史
重置门让模型不会被无关旧信息拖累。
第三,它改进了梯度传播路径
虽然不能彻底消除所有问题,但比基础 RNN 更容易训练稳定。
所以很多时候,GRU 是一种:
在复杂度和性能之间比较均衡的循环结构
21. 这一节最该掌握什么
如果从学习重点来看,这一节最关键的是下面几件事。
21.1 看懂参数初始化比 RNN 多在哪
知道为什么 GRU 至少需要三组核心参数。
21.2 看懂四条核心公式如何一一落到代码
尤其是:
ZRH_tildeH
之间的关系。
21.3 理解最终状态更新是“新旧信息加权融合”
这是 GRU 相比 RNN 的本质增强。
21.4 知道nn.GRU和nn.RNN的接口非常接近
方便后面工程应用。
21.5 明白 GRU 的改进点主要发生在单元内部
外围训练逻辑其实变化不大。
22. 本节总结
这一节我们学习了 GRU 的代码实现,核心内容可以总结为以下几点。
22.1 GRU 从零实现比 RNN 多了门控参数
主要包括:
更新门参数
重置门参数
候选状态参数
22.2 GRU 前向传播的关键是先算门,再更新状态
这让信息流更加可控。
22.3 最终隐藏状态是旧状态和候选状态的门控融合
而不是像基础 RNN 那样一次性混合更新。
22.4 简洁实现里nn.GRU和nn.RNN用法高度相似
只不过内部计算更强。
22.5 GRU 是基础 RNN 向更强序列建模迈出的重要一步
也是后面 LSTM 的前置基础。
23. 学习感悟
这一节很有意思,因为你会第一次非常明显地感受到:
一个模型的提升,不一定来自“推翻重来”,也可能来自对信息流路径做精细控制。
GRU 其实没有把循环神经网络完全重写,
它只是问了两个更聪明的问题:
以前的信息还值不值得保留?
新的信息该不该立刻写进去?
就是这两个问题,让它比基础 RNN 更能“记事”。
从这个角度看,GRU 的优雅之处,不在于公式复杂,
而在于它让记忆管理第一次变得真正有策略。