news 2026/5/30 17:10:40

从0到97%!PyTorch MLP实现MNIST手写数字分类全攻略

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
从0到97%!PyTorch MLP实现MNIST手写数字分类全攻略

本文记录了从PyTorch入门到成功实现MNIST手写数字分类(准确率达97%+)的完整过程,涵盖数据集认知、代码搭建、核心疑问解答及训练优化等关键环节,适合刚接触PyTorch的新手参考学习。

一、开篇:MNIST数据集与核心目标

我们的核心任务是用PyTorch搭建多层感知器(MLP),完成MNIST手写数字的分类任务。首先得明确MNIST数据集的基础属性——这是一个经典的手写数字识别数据集,包含60000张训练图和10000张测试图,所有图片均为28×28像素的灰度图,标签为0-9的整数(对应手写数字)。

新手常见第一个疑问:下载数据集后为什么会有多个文件?其实这些是MNIST的原始压缩包(.gz后缀)和解压后的原始文件,分为训练集/测试集的图片文件和标签文件,属于PyTorch自动下载解压的正常现象,无需手动处理。

二、核心代码搭建:从数据加载到模型训练

我们逐步搭建了完整的训练流程,从库导入到最终训练测试,每一步都针对新手常见疑问做了详细解析。

2.1 完整可运行代码(优化后最终版)

# 导入所需库 import torch from torch import nn from torch.utils.data import DataLoader from torchvision import datasets, transforms # 加载数据时修改transform,增加Normalize transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) # MNIST的均值和标准差(官方推荐值) ]) # 1. 加载MNIST数据集 training_data = datasets.MNIST( root="data", train=True, download=True, # transform=transforms.ToTensor() transform=transform ) test_data = datasets.MNIST( root="data", train=False, download=True, # transform=transforms.ToTensor() transform=transform ) # 2. 创建DataLoader(数据加载器) batch_size = 64 train_dataloader = DataLoader(training_data, batch_size=batch_size,shuffle=True) test_dataloader = DataLoader(test_data, batch_size=batch_size) # 3. 判断设备(CPU/GPU) device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" print(f"Using {device} device") # 4. 定义神经网络模型的类 class NeuralNetwork(nn.Module): def __init__(self): super().__init__() self.flatten = nn.Flatten() # 展开28×28的图片(官方图片集固定的28*28像素)为一维向量 self.hidden1 = nn.Linear(28 * 28, 128) # 输入层→隐藏层1,本质是矩阵运算,这里我们的批次是64,因此实际进行64行28*28列作为矩阵传入第一层计算 self.hidden2 = nn.Linear(128, 256) # 隐藏层1→隐藏层2 self.out = nn.Linear(256, 10) # 隐藏层2→输出层(10类数字) # 前向传播(方法名规范为forward,小写) def forward(self, x): x = self.flatten(x) x = self.hidden1(x) # x = torch.sigmoid(x) # 激活函数 x = torch.relu(x) # 替换sigmoid为relu x = self.hidden2(x) # x = torch.sigmoid(x) # 激活函数 x = torch.relu(x) # 替换sigmoid为relu x = self.out(x) return x '''neural神经 batch批次''' # 初始化类并移动到设备 model = NeuralNetwork().to(device) print(model) # 5. 定义训练函数 def train(dataloader, model, loss_fn, optimizer): model.train() # 切换到训练模式 batch_size_num = 1 # 统计batch数量 for X, y in dataloader: # dataloader是函数形参,调用时传入train_dataloader/test_dataloader;迭代返回批次级数据:X为批次图片张量(shape[batch_size,1,28,28]),y为批次标签张量(shape[batch_size]),对应批次内所有样本的图片和标签 # 数据移动到设备 X, y = X.to(device), y.to(device) # 前向传播计算预测值 pred = model(X) # 可省略.forward,model(X)会自动调用forward loss = loss_fn(pred, y) # 计算损失 # 反向传播更新参数 optimizer.zero_grad() # 梯度清零 loss.backward() # 反向传播计算梯度(原代码中Loss大写,修正为loss) optimizer.step() # 更新模型参数 # 每100个batch打印一次损失 loss_value = loss.item() if batch_size_num % 100 == 0: print(f"Loss: {loss_value:>7f} [batch: {batch_size_num}]") batch_size_num += 1 # 6. 定义测试函数 def test(dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集总样本数 num_batches = len(dataloader) # 测试集batch数量 model.eval() # 切换到测试模式 test_loss, correct = 0, 0 # 测试时关闭梯度计算,节省资源 with torch.no_grad(): for X, y in dataloader: X, y = X.to(device), y.to(device) pred = model(X) test_loss += loss_fn(pred, y).item() # 累计损失 # 计算正确预测数(取预测最大值对应的索引) correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 计算平均损失和准确率 test_loss /= num_batches correct /= size print(f"Test result: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n") # 7. 初始化损失函数和优化器 loss_fn = nn.CrossEntropyLoss() # 交叉熵损失(适用于分类任务) # optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # SGD优化器 # 原代码的optimizer替换为Adam optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Adam默认lr=0.001即可 # 8. 执行训练和测试 # print("开始训练:") # train(train_dataloader, model, loss_fn, optimizer) # 训练一轮 # 原代码只调用了1次train,改成循环10轮 epochs = 10 # 训练10轮 for t in range(epochs): print(f"\n训练轮数 {t+1}/{epochs}") train(train_dataloader, model, loss_fn, optimizer) test(test_dataloader, model, loss_fn) # 每轮训练后测试 print("训练完成!") print("开始测试:") test(test_dataloader, model, loss_fn)

运行结果:

2.2 关键代码解析(新手必看)

(1)datasets.MNIST:数据集加载神器

datasets.MNIST是PyTorch封装好的数据集加载工具,核心作用是跳过“下载-解压-解析”的繁琐步骤,直接得到可使用的Dataset对象。关键参数说明:

参数名

核心作用

常用取值

root

指定数据集保存/读取目录

root="data"(当前目录下创建data文件夹)

train

选择加载训练集/测试集

train=True(训练集)、train=False(测试集)

download

本地无数据时自动下载

首次使用设为True,后续设为False

transform

图片预处理

转Tensor+标准化(提升训练效果)

(2)DataLoader:批量训练的核心

代码:train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

核心作用是将零散的Dataset样本打包成批次张量,适配神经网络批量训练逻辑:

  • batch_size=64:每批包含64个样本,生成形状为[64,1,28,28]的批次图片张量和[64]的批次标签张量;

  • shuffle=True:训练前打乱样本顺序,避免模型记住样本顺序导致泛化能力差;

  • 迭代使用:通过for X, y in dataloader遍历,X为批次图片,y为对应批次标签(形参dataloader可接收train_dataloader或test_dataloader)。

(3)批次张量:高效训练的关键

批次张量是将多个样本的图片/标签合并成的大张量,核心优势是利用矩阵运算实现批量计算:

  • 单张图片张量:[1,28,28](1通道+28×28像素);

  • 批次图片张量(64个样本):[64,1,28,28](新增批次维度);

  • 优势:神经网络通过一次矩阵运算就能完成64张图片的前向传播,效率远高于逐个计算。

(4)nn.Module:神经网络的骨架

自定义模型必须继承nn.Module,它是PyTorch所有神经网络模块的基类,提供参数管理、GPU适配、前向传播支持等核心功能(并非具体模型):

  • __init__方法:初始化网络层(如nn.Flatten展平层、nn.Linear全连接层),需调用super().__init__()启用父类功能;

  • forward方法:定义前向传播逻辑,调用model(X)时会自动执行,无需手动调用。

(5).item():标量张量转普通数值

训练中计算的损失loss是标量张量(shape=[]),无法直接用于Python格式化输出,通过loss.item()可剥离张量的设备、梯度等额外信息,转成普通float数值。

三、训练优化:从30%到97%的关键调整

初始训练准确率仅30%左右,通过以下关键调整,第六轮训练准确率就达到97%+:

  1. 替换激活函数:将sigmoid改为ReLU,解决深层网络梯度消失问题,让模型快速收敛;

  2. 增加训练轮数:从1轮改为10轮,让模型充分学习数据特征;

  3. 训练数据打乱:在DataLoader中添加shuffle=True,提升模型泛化能力;

  4. 优化器调整:从SGD改为Adam(自适应学习率),加速收敛;

  5. 数据标准化:添加transforms.Normalize,让模型更快收敛。

四、核心疑问解答(新手避坑指南)

Q1:为什么img, label = training_data[i]能将一个值拆成两个变量?

training_data[i]返回的是包含两个元素的元组(图片数据+标签),这是Python的元组解包语法,会自动将元组中的两个值按顺序赋值给img和label,等价于img = training_data[i][0]label = training_data[i][1]

Q2:怎么确认MNIST图片是28×28像素?

28×28是MNIST数据集的官方固有属性,可通过代码验证:取单个样本img = training_data[0][0],打印img.shape会输出torch.Size([1,28,28]),明确显示图片尺寸为28×28。

Q3:输入层接收28×28特征,怎么同时训练64张图片?

通过nn.Flatten()将64张图片的[64,1,28,28]张量展平为[64,784](784=28×28),输入层nn.Linear(784,128)通过矩阵运算(64×784的输入矩阵 × 784×128的权重矩阵),一次性完成64张图片的特征转换,实现批量训练。

Q4:nn.Module和感知器的关系?

感知器是单个神经元的计算逻辑(如nn.Linear层的一个神经元),而nn.Module是承载感知器/网络层的骨架,提供参数管理、GPU计算等运行能力,两者是“被承载者”和“承载框架”的关系,并非等价关系。

五、总结与成果

本次从MNIST数据集认知开始,逐步完成了数据加载、模型搭建、训练测试全流程,解决了新手常见的语法疑问和训练优化问题,最终通过多层感知器实现了97%+的分类准确率。核心收获:

  • 掌握PyTorch核心工具(datasets.MNIST、DataLoader、nn.Module)的使用逻辑;

  • 理解批量训练、批次张量的核心原理;

  • 掌握神经网络训练优化的关键技巧(激活函数、优化器、数据预处理等)。

对于刚接触PyTorch的新手来说,从疑问重重到实现高准确率训练,每一步都是成长。后续可尝试调整batch_size、学习率等参数,进一步优化模型性能,或探索CNN等更复杂的网络结构在MNIST上的表现。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/28 15:31:30

Open-AutoGLM微信自动化实战教程(从零到精通必看)

第一章:Open-AutoGLM微信自动化实战概述Open-AutoGLM 是一个基于大语言模型与自动化控制技术的开源框架,专为实现微信客户端的智能化操作而设计。它结合了自然语言理解能力与图形用户界面(GUI)自动化技术,能够在无需人…

作者头像 李华
网站建设 2026/5/28 19:45:57

Java毕设项目推荐-基于javaweb校园兼职招聘系统的设计与实现基于JavaWeb的校园招聘管理系统简历投递管理【附源码+文档,调试定制服务】

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

作者头像 李华
网站建设 2026/5/29 21:43:48

算法题 括号的分数

856. 括号的分数 问题描述 给定一个平衡括号字符串 s,按下述规则计算该字符串的分数: () 得 1 分AB 得 A B 分,其中 A 和 B 是平衡括号字符串(A) 得 2 * A 分,其中 A 是平衡括号字符串 返回字符串 s 的分数。 示例&#xff…

作者头像 李华
网站建设 2026/5/28 23:33:23

Java计算机毕设之基于JavaWeb的校园招聘管理系统高校校园招聘信息服务系统 (完整前后端代码+说明文档+LW,调试定制等)

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

作者头像 李华
网站建设 2026/5/29 0:28:40

温度影响精度?高精度模拟量采集模块适配攻略来了

高精度模拟量采集模块的核心功能是将温度、湿度、压力、电流等物理量转换为标准模拟信号(如4-20mA、0-10V)并精准采集,其应用环境温度直接影响采集精度、稳定性和使用寿命。一、应用场景 高精度模拟量采集模块的应用环境温度需与模块自身工作温度范围匹配&#xff0…

作者头像 李华
网站建设 2026/5/29 22:08:58

ERP企业资源管理系统代码(Java)

1. 仓库管理模块物料分类:采用ABC/XYZ多维分类法,结合物料属性与消耗规律,建立动态管理档案。追溯管控:通过条码/RFID技术实现全生命周期追溯,支持批次号与保质期管理。作业优化:WMS系统智能分配库位&#…

作者头像 李华