news 2026/4/14 18:13:08

动手学深度学习——GRU代码

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
动手学深度学习——GRU代码

1. 前言

上一篇我们已经从原理上认识了GRU(门控循环单元)

  • 它是对基础 RNN 的改进

  • 它引入了门控机制

  • 它通过更新门和重置门来控制信息流

  • 它更擅长处理长期依赖问题

但是,只理解公式还不够。
和前面 RNN 一样,真正把 GRU 学扎实,最好的方式还是:

把公式一步一步写成代码。

这一节的任务就是把 GRU 真正落到实现层面。
你会看到:

  • GRU 比 RNN 多了哪些参数

  • 更新门、重置门在代码里怎么写

  • 候选隐藏状态如何计算

  • 最终隐藏状态如何更新

  • 简洁实现和从零实现分别怎么对应

这一篇本质上就是:

把“门控记忆”变成可运行的程序。


2. GRU 代码实现要解决什么

如果把这一节拆开看,核心其实就 4 件事:

2.1 初始化更多参数

相比基础 RNN,GRU 不再只有一组隐藏状态更新参数,
而是需要分别为:

  • 更新门

  • 重置门

  • 候选隐藏状态

各自准备参数。

2.2 写新的状态更新公式

也就是把上一篇的四条核心公式真正变成代码。

2.3 保持语言模型训练接口一致

GRU 虽然内部更复杂,但对外仍然要能接:

  • 输入序列

  • 初始状态

  • 输出预测

  • 最终状态

2.4 对照简洁实现

看 PyTorch 里的nn.GRU到底帮我们封装了哪些内容。


3. 先回顾 GRU 的核心公式

写代码前,先把最关键的四条公式再捋清楚。

更新门

Z_t = σ(X_t W_xz + H_{t-1} W_hz + b_z)

重置门

R_t = σ(X_t W_xr + H_{t-1} W_hr + b_r)

候选隐藏状态

H_t_tilde = tanh(X_t W_xh + (R_t ⊙ H_{t-1}) W_hh + b_h)

最终隐藏状态

H_t = Z_t ⊙ H_{t-1} + (1 - Z_t) ⊙ H_t_tilde

其中:

  • σ是 sigmoid

  • 是按元素乘法

你可以看到,GRU 和基础 RNN 最大的不同就在于:

更新隐藏状态之前,先算门。


4. GRU 从零实现:先初始化参数

基础 RNN 只需要一套隐藏更新参数。
而 GRU 至少要准备三套。

常见写法如下:

def get_params(vocab_size, num_hiddens, device): num_inputs = num_outputs = vocab_size def normal(shape): return torch.randn(size=shape, device=device) * 0.01 def three(): return (normal((num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)), torch.zeros(num_hiddens, device=device)) W_xz, W_hz, b_z = three() # 更新门 W_xr, W_hr, b_r = three() # 重置门 W_xh, W_hh, b_h = three() # 候选隐藏状态 W_hq = normal((num_hiddens, num_outputs)) b_q = torch.zeros(num_outputs, device=device) params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q] for param in params: param.requires_grad_(True) return params

这段代码就是 GRU 从零实现的起点。


5. 为什么这里有三组“输入-隐藏-偏置”参数

因为 GRU 要分别计算三类东西:

第一组:更新门参数

W_xz, W_hz, b_z

它们负责控制“旧状态保留多少”。

第二组:重置门参数

W_xr, W_hr, b_r

它们负责控制“旧状态在候选状态里参与多少”。

第三组:候选隐藏状态参数

W_xh, W_hh, b_h

它们负责生成新的候选状态。

所以,GRU 比 RNN 参数更多,不是因为它“乱加复杂度”,
而是因为它确实要分别处理三种不同功能。


6. 隐藏状态初始化和 RNN 一样吗

基本一样。

因为 GRU 最终对外仍然只有一个隐藏状态H_t
不像 LSTM 还会多一个单独记忆单元。

所以状态初始化通常仍然写成:

def init_gru_state(batch_size, num_hiddens, device): return (torch.zeros((batch_size, num_hiddens), device=device), )

也就是说:

  • 每个样本一份隐藏状态

  • 初始时全零

  • 返回成元组形式,方便接口统一


7. GRU 的前向传播是这一节最核心的代码

常见从零实现写法如下:

def gru(inputs, state, params): W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params H, = state outputs = [] for X in inputs: Z = torch.sigmoid(torch.mm(X, W_xz) + torch.mm(H, W_hz) + b_z) R = torch.sigmoid(torch.mm(X, W_xr) + torch.mm(H, W_hr) + b_r) H_tilde = torch.tanh(torch.mm(X, W_xh) + torch.mm(R * H, W_hh) + b_h) H = Z * H + (1 - Z) * H_tilde Y = torch.mm(H, W_hq) + b_q outputs.append(Y) return torch.cat(outputs, dim=0), (H,)

如果你把这段代码真正看明白,GRU 就已经学通了大半。


8. 更新门这两行代码怎么理解

先看:

Z = torch.sigmoid(torch.mm(X, W_xz) + torch.mm(H, W_hz) + b_z)

它对应的就是更新门公式:

Z_t = σ(X_t W_xz + H_{t-1} W_hz + b_z)

含义是:

  • 当前输入X

  • 和上一隐藏状态H

  • 共同决定当前时间步“该保留多少旧状态”

Z的形状通常是:

(batch_size, num_hiddens)

也就是说,对每个样本、每个隐藏单元,都有一个门控值。

这意味着:

GRU 不是粗暴地对整个状态“一刀切”,而是按隐藏单元逐个控制。


9. 重置门代码怎么理解

再看:

R = torch.sigmoid(torch.mm(X, W_xr) + torch.mm(H, W_hr) + b_r)

它对应:

R_t = σ(X_t W_xr + H_{t-1} W_hr + b_r)

它决定的是:

在计算候选隐藏状态时,旧隐藏状态该参与多少。

你可以把它理解成一种“历史清洗器”:

  • R大,说明历史信息还很重要

  • R小,说明历史信息该弱化一些


10. 候选隐藏状态代码怎么理解

这是 GRU 最关键的一步之一:

H_tilde = torch.tanh(torch.mm(X, W_xh) + torch.mm(R * H, W_hh) + b_h)

它对应:

H_t_tilde = tanh(X_t W_xh + (R_t ⊙ H_{t-1}) W_hh + b_h)

要点在于:

不是直接用H

而是先做:

R * H

这表示:

旧隐藏状态先被重置门筛一遍,再参与候选状态生成。

这就是 GRU 相比基础 RNN 非常关键的精细化控制。


11. 最终隐藏状态更新代码怎么理解

看这一句:

H = Z * H + (1 - Z) * H_tilde

这就是:

H_t = Z_t ⊙ H_{t-1} + (1 - Z_t) ⊙ H_t_tilde

直观上它就是在做:

旧状态和新候选状态的加权平均

如果:

  • Z大,则更偏向旧状态

  • Z小,则更偏向新候选状态

所以最终隐藏状态不是“全用旧的”也不是“全用新的”,
而是模型自己学出来的动态折中。

这正是 GRU 最大的魅力所在。


12. 为什么说 GRU 的记忆更可控

从这几行代码里你就能直接看出来:

基础 RNN 的状态更新基本是一条路:

  • 输入来

  • 历史来

  • 一起过tanh

  • 更新完成

而 GRU 明显多了控制环节:

  • 先决定历史该参与多少

  • 再生成候选状态

  • 再决定最终保留多少旧状态

所以 GRU 本质上是一种:

可学习的信息流控制机制

而不是单纯“算出一个新状态”而已。


13. 输出层为什么和 RNN 一样

注意这一句:

Y = torch.mm(H, W_hq) + b_q

这和基础 RNN 完全一样。

为什么?

因为 GRU 改进的是:

隐藏状态的内部更新机制

而不是语言模型最终输出的形式。

对于字符级语言模型来说,最后仍然是:

  • 当前隐藏状态

  • 映射到词表空间

  • 得到对每个字符的打分

所以输出层部分不需要大改。


14. 从零实现的 GRU 如何封装成模型类

和上一节 RNN 一样,通常也会封装一个“手写模型容器”。

例如继续用之前那种思路:

net = d2l.RNNModelScratch(vocab_size, num_hiddens, device, get_params, init_gru_state, gru)

注意这非常漂亮。

你会发现:

  • 模型容器类几乎不用改

  • 只需要把

    • 参数初始化函数

    • 状态初始化函数

    • 前向传播函数
      换成 GRU 版本

整个模型就从 RNN 变成了 GRU。

这说明什么?

说明 GRU 和 RNN 在“接口层面”是一致的,
真正变化的是内部单元


15. 简洁实现:PyTorch 里的nn.GRU

从零实现看懂以后,简洁实现就很好接受了。

PyTorch 已经封装好了:

nn.GRU

基本用法和nn.RNN非常相似。

例如:

gru_layer = nn.GRU(input_size=vocab_size, hidden_size=num_hiddens)

这里:

  • input_size仍然是每个时间步输入向量维度

  • hidden_size仍然是隐藏状态维度

然后前向传播接口也很像:

Y, state = gru_layer(X, state)

输出逻辑和nn.RNN一样:

  • Y:所有时间步输出

  • state:最后隐藏状态


16. 为什么nn.GRU用法和nn.RNN这么像

因为从更高层视角看,它们都属于:

循环序列建模模块

它们对外接口基本一致:

  • 输入序列

  • 初始状态

  • 输出序列

  • 最终状态

只不过内部单元计算方式不同:

nn.RNN

内部是基础循环单元。

nn.GRU

内部是门控循环单元。

所以你可以把 GRU 看成是:

RNN API 体系下的一个更强单元版本

这对工程使用非常友好。


17. 简洁实现里的模型封装怎么写

和简洁版 RNN 一样,通常可以继续复用同样的模型外壳:

  • GRU 层负责序列递推

  • 线性层负责把隐藏状态映射到词表空间

也就是说,模型代码结构上几乎不用大动,
只要把:

rnn_layer = nn.RNN(...)

换成:

gru_layer = nn.GRU(...)

就能得到 GRU 版本模型。

这说明:

GRU 的升级主要体现在循环单元内部,不在外围框架。


18. GRU 代码和 RNN 代码最本质的区别在哪里

这是这一节最该点破的地方。

RNN 从零实现

核心只有一条隐藏状态更新公式。

GRU 从零实现

核心变成了:

  • 先算更新门

  • 再算重置门

  • 再算候选状态

  • 再融合新旧状态

所以如果你问:

GRU 相比 RNN,代码上本质多了什么?

答案就是:

多了门控变量和基于门控的状态融合机制

其他外部流程,例如:

  • one-hot 输入

  • 初始状态

  • 输出层

  • 文本生成

其实都还是同一套套路。


19. GRU 训练时需要额外特殊处理吗

整体训练流程和 RNN 基本一致。

仍然是:

  • 输入 token 序列

  • 输出对下一个 token 的预测

  • 用交叉熵损失

  • 反向传播

  • 梯度裁剪

  • 参数更新

所以从训练框架角度,GRU 并没有引入额外特别陌生的流程。
变化的是内部状态更新更智能了。

这也是为什么在工程里,GRU 往往可以很平滑地替代基础 RNN。


20. 为什么 GRU 常常比基础 RNN 更实用

从代码层面你已经能看出来原因:

第一,它可以显式保留旧状态

更新门允许旧信息直接穿过去。

第二,它可以有选择地忽略部分历史

重置门让模型不会被无关旧信息拖累。

第三,它改进了梯度传播路径

虽然不能彻底消除所有问题,但比基础 RNN 更容易训练稳定。

所以很多时候,GRU 是一种:

在复杂度和性能之间比较均衡的循环结构


21. 这一节最该掌握什么

如果从学习重点来看,这一节最关键的是下面几件事。

21.1 看懂参数初始化比 RNN 多在哪

知道为什么 GRU 至少需要三组核心参数。

21.2 看懂四条核心公式如何一一落到代码

尤其是:

  • Z

  • R

  • H_tilde

  • H

之间的关系。

21.3 理解最终状态更新是“新旧信息加权融合”

这是 GRU 相比 RNN 的本质增强。

21.4 知道nn.GRUnn.RNN的接口非常接近

方便后面工程应用。

21.5 明白 GRU 的改进点主要发生在单元内部

外围训练逻辑其实变化不大。


22. 本节总结

这一节我们学习了 GRU 的代码实现,核心内容可以总结为以下几点。

22.1 GRU 从零实现比 RNN 多了门控参数

主要包括:

  • 更新门参数

  • 重置门参数

  • 候选状态参数

22.2 GRU 前向传播的关键是先算门,再更新状态

这让信息流更加可控。

22.3 最终隐藏状态是旧状态和候选状态的门控融合

而不是像基础 RNN 那样一次性混合更新。

22.4 简洁实现里nn.GRUnn.RNN用法高度相似

只不过内部计算更强。

22.5 GRU 是基础 RNN 向更强序列建模迈出的重要一步

也是后面 LSTM 的前置基础。


23. 学习感悟

这一节很有意思,因为你会第一次非常明显地感受到:

一个模型的提升,不一定来自“推翻重来”,也可能来自对信息流路径做精细控制。

GRU 其实没有把循环神经网络完全重写,
它只是问了两个更聪明的问题:

  • 以前的信息还值不值得保留?

  • 新的信息该不该立刻写进去?

就是这两个问题,让它比基础 RNN 更能“记事”。

从这个角度看,GRU 的优雅之处,不在于公式复杂,
而在于它让记忆管理第一次变得真正有策略。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/14 18:12:20

机器学习模型部署专家:职业蓝海揭秘

当测试遇见模型部署在AI工业化落地的浪潮中,机器学习模型部署正成为技术生态的关键枢纽。据《2025全球AI工程化报告》显示,85%的AI项目因部署环节失效未能产生商业价值,而精通部署的工程师缺口高达72万人。对于软件测试从业者而言&#xff0c…

作者头像 李华
网站建设 2026/4/14 18:12:11

第8篇:嵌入式芯片内存架构详解:SRAM_Flash_Cache与外部存储的层级设计

引言:内存架构是决定嵌入式芯片实时性能与功耗的关键瓶颈 嵌入式系统广泛应用于可穿戴设备、工业机器人、车载控制、物联网终端等各类场景,其核心需求高度聚焦于实时响应、低功耗与高可靠性,而内存架构正是决定这些核心指标的关键瓶颈。与通用…

作者头像 李华
网站建设 2026/4/14 18:06:09

Ubuntu桌面应用开机自启动终极指南:从.desktop配置到环境变量设置

Ubuntu桌面应用开机自启动终极指南:从.desktop配置到环境变量设置 在Ubuntu桌面环境中,让应用程序随系统启动自动运行是提升工作效率的常见需求。无论是开发工具、监控程序还是日常生产力软件,合理的自启动配置都能让我们省去每次手动打开的麻…

作者头像 李华
网站建设 2026/4/14 18:05:33

APKMirror客户端:3个理由让你告别繁琐的安卓应用下载

APKMirror客户端:3个理由让你告别繁琐的安卓应用下载 【免费下载链接】APKMirror 项目地址: https://gitcode.com/gh_mirrors/ap/APKMirror 你是否曾因Google Play商店的版本延迟而苦恼?是否在寻找某个应用的历史版本时感到无从下手?…

作者头像 李华