news 2026/4/27 6:02:23

Pytorch基础——(3)神经网络工具箱

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Pytorch基础——(3)神经网络工具箱

文章目录

  • 一、基础知识
  • 二、构建模型
    • 1.1 方法1:继承nn.Model基类构建模型
    • 1.2 方法2:使用 nn.Sequential 容器
      • 1.2.1 添加参数
      • 1.2.2 add_module可指定名称
      • 1.2.2 orderedDict可指定名称
    • 1.3 结合1和2,集成基类并使用模拟容器
      • 1.3.1 使用nn.Sequential()
      • 1.3.2 使用ModuleList
      • 1.3.3 使用ModuleDict
  • 三、训练和评估模型
    • 3.1 训练和评估步骤
    • 3.2 Python代码

一、基础知识

torch.nn:nn 是 Neural Networks(神经网络) 的缩写。它就是一个工具库,里面包含了深度学习常用的所有零件建造神经网络的工具箱。nn 就是给你提供积木,让你搭神经网络。最常用的:

  • 层结构
    • nn.Linear(5, 1) 线性层(全连接层)
    • nn.Conv2d 卷积层
    • nn.ReLU / nn.Sigmoid 激活函数
  • 损失函数
    • nn.MSELoss() 回归损失
    • nn.CrossEntropyLoss() 分类损失
  • 其他工具
    • nn.Sequential 快速堆叠网络
    • nn.BatchNorm2d 归一化

nn.model:是所有模型的骨架 / 模板 / 父类,模型都必须继承 nn.Module,因为它帮你自动完成:

  • 自动管理参数 w、b
  • 自动求梯度(backward)
  • 自动更新参数
  • 自动保存 / 加载模型
  • 自动把模型搬到 GPU

nn.functional:nn.functional(简称 F),是 PyTorch 里的数学函数工具箱它里面全是无参数、纯计算的函数,专门用来做:

  • 激活计算(relu、sigmoid)
  • 损失计算(mse_loss)
  • 池化计算(max_pool)
  • 归一化、softmax、dropout 等

nn.model 里面的 nn.Xxx(nn.Linear)

  • nn.Xxx 继承于 nn.model,需要先实例化并传入参数,然后以函数调用的方式,调用实例对象,并传入输入数据
  • nn.Xxx不需要自己定义和管理 weight 和 bias 参数
importtorchimporttorch.nnasnn# 1. 实例化(创建层,并传入参数 w 和 b)layer=nn.Linear(5,1)# 2. 构造输入数据x=torch.randn(10,5)# 3. 把层当函数用,函数调用实例对象 layer, 传入输入数据y=layer(x)# 等价于 y = layer.forward(x)

nn.functinal里面的函数 nn.functional.xxx(nn.functional.linear)

  • nn.functional.xxx需要自己定义和管理 weight 和 bias 参数,每次调用的时候需要手动传入。

建议:

  • 具有学习参数的:必须用 nn.XXX,如 Linear, Conv2d, BatchNorm,不能用 nn.functional 代替,因为权重不会被自动管理。
  • 没有学习参数的参数,如 relu, pool, softmax,推荐用 nn.functional
importtorch.nn.functionalasF F.relu(x)# 激活函数F.sigmoid(x)F.tanh(x)F.max_pool2d(x)# 池化F.avg_pool2d(x)F.dropout(x)# 随机失活F.softmax(x,dim=1)

二、构建模型

如下图,采用不同的方式构建一下神经网络:

1.1 方法1:继承nn.Model基类构建模型

nn.Module是 PyTorch 中所有神经网络模块的基类,它提供了三大核心功能:

  • 参数管理:自动跟踪和管理网络中的可学习参数(权重、偏置等),方便优化器更新。
  • 子模块管理:自动注册和管理网络中的子层(如 nn.Linear、nn.Conv2d 等),支持递归遍历。
  • 前向传播接口:规定必须实现 forward 方法,定义数据在网络中的流动逻辑。

完整步骤:

  • 导入必要的库;
  • 定义一个类继承 nn.Module 并实现2个核心方法:
    • 在_init_ 方法中定义需要用到的层·
      • 调用父类 nn.Module 的初始化方法
      • 定义网络中用到的所有层(如全连接层、批归一化层等),并绑定为类的属性(self.xxx)。
    • 在 forward 方法中手动编写数据流动逻辑(前向传播)
      • 模型的输入数据为x
      • 定义数据 x 如何从输入经过各层处理,最终输出结果
importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassMyModule(nn.Module):def__init__(self,indim,h1,h2,outdim):super(MyModule,self).__init__()self.flatten=nn.Flatten()self.Linear1=nn.Linear(indim,h1)self.bn1=nn.BatchNorm1d(h1)self.Linear2=nn.Linear(h1,h2)self.bn2=nn.BatchNorm1d(h2)self.out=nn.Linear(h2,outdim)defforward(self,x):x=self.flatten(x)x=self.Linear1(x)x=self.bn1(x)x=F.relu(x)# 用nn.functionalx=F.relu(self.bn2(x=self.Linear2(x)))# 全连接层2+批归一化2+激活层2x=self.out(x)x=F.softmax(x,dim=1)# 用nn.functionalreturnx model=MyModule(28*28,300,100,10)print(model)
MyModule((flatten): Flatten(start_dim=1,end_dim=-1)(layer1): Linear(in_features=784,out_features=300,bias=True)(bn1): BatchNorm1d(300,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)(layer2): Linear(in_features=300,out_features=100,bias=True)(bn2): BatchNorm1d(100,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)(out): Linear(in_features=100,out_features
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/27 6:01:51

React-antd-admin-template实战:如何快速定制个性化后台界面

React-antd-admin-template实战:如何快速定制个性化后台界面 【免费下载链接】react-antd-admin-template 一个基于ReactAntd的后台管理模版,在线预览https://nlrx-wjc.github.io/react-antd-admin-template/ 项目地址: https://gitcode.com/gh_mirror…

作者头像 李华
网站建设 2026/4/27 5:59:38

Phi-3-mini-4k-instruct-gguf惊艳效果展示:10个真实Prompt生成结果全公开

Phi-3-mini-4k-instruct-gguf惊艳效果展示:10个真实Prompt生成结果全公开 1. 模型简介 Phi-3-Mini-4K-Instruct是一个38亿参数的轻量级开源模型,采用GGUF格式提供。作为Phi-3系列的一员,这个模型经过精心训练,专注于高质量内容和…

作者头像 李华
网站建设 2026/4/27 5:59:37

SageMath开发环境搭建:从源码编译到自定义构建

SageMath开发环境搭建:从源码编译到自定义构建 【免费下载链接】sage Main repository of SageMath 项目地址: https://gitcode.com/gh_mirrors/sag/sage SageMath是一个功能强大的开源数学软件系统,集成了众多数学计算库和工具。本文将详细介绍如…

作者头像 李华
网站建设 2026/4/27 5:59:28

Venera漫画阅读器:打造你的跨平台数字漫画图书馆

Venera漫画阅读器:打造你的跨平台数字漫画图书馆 还在为分散在不同设备和平台的漫画资源而烦恼吗?Venera漫画阅读器正是你需要的解决方案!这款基于Flutter开发的跨平台应用,能够完美整合本地与网络漫画资源,为你提供一…

作者头像 李华
网站建设 2026/4/27 5:58:34

DeOldify开发者体验优化:CLI命令行工具封装与Tab补全支持

DeOldify开发者体验优化:CLI命令行工具封装与Tab补全支持 1. 项目背景与价值 如果你是一名开发者,经常需要处理黑白照片上色的任务,可能会遇到这样的困扰:每次都要打开浏览器、上传图片、等待处理、下载结果,这样的流…

作者头像 李华