news 2026/5/26 22:42:46

MNIST 入门实战:从数据流到模型训练与评估(含完整代码与流程图)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
MNIST 入门实战:从数据流到模型训练与评估(含完整代码与流程图)

📺B站:博主个人介绍

📘博主书籍-京东购买链接*:Yocto项目实战教程

📘加博主微信,进技术交流群jerrydev


MNIST 入门实战:从数据流到模型训练与评估(含完整代码与流程图)

目标:用 MNIST 把深度学习最核心的一条主线走通——数据 → 模型 → 损失 → 反向传播 → 更新参数 → 评估 → 保存/加载 → 推理。读完你会对“训练到底在做什么”“评估到底评估了什么”“.pth 到底是什么”形成清晰、稳定的概念框架,并具备继续突破到更复杂任务(CIFAR、检测、分割、YOLO、Transformer)的基础。


1. 你已经做到了什么?为什么这条路径很重要

你已经跑通了一个完整的 CNN 分类项目:

  • 数据:MNIST
  • 模型:两层卷积 + 两层全连接(带 Dropout)
  • 训练:CrossEntropyLoss + SGD
  • 评估:test loss + test acc
  • 保存:生成mnist_cnn.pth

并且得到了非常合理的结果:

  • test acc ≈ 97%+
  • test loss 逐步下降

这说明:

  1. 数据流(DataLoader)没问题;
  2. 模型结构(forward)没问题;
  3. 训练闭环(loss → backward → step)成立;
  4. 评估逻辑(eval + no_grad + 统计)正确;
  5. 参数存档(.pth)可复用。

这条路径是所有 AI 任务的“地基”。以后你换数据集、换模型、换任务,本质上还是这条主线,只是每一段更复杂。


2. 全流程总览:一张逻辑图把它串起来

下面这张“主线流程图”非常建议你记住:

┌──────────────┐ │ 数据集 Dataset │ MNIST(train/test) └──────┬───────┘ │ transform(ToTensor/Normalize) ▼ ┌──────────────┐ │ DataLoader │ batch化、shuffle、并行加载 └──────┬───────┘ │ (data:[B,1,28,28], target:[B]) ▼ ┌──────────────┐ │ 模型 Net │ forward: [B,1,28,28] → logits[B,10] └──────┬───────┘ │ ▼ ┌──────────────┐n│ Loss 函数 │ CrossEntropyLoss(logits, target) └──────┬───────┘ │ ▼ ┌──────────────┐ │ backward() │ 自动求梯度: p.grad └──────┬───────┘ │ ▼ ┌──────────────┐ │ optimizer.step│ 更新参数: θ ← θ - lr * grad └──────┬───────┘ │ ▼ ┌──────────────┐ │ 评估 Evaluate │ net.eval + no_grad + 指标统计 └──────┬───────┘ │ ▼ ┌──────────────┐ │ 保存/加载 .pth │ state_dict() / load_state_dict() └──────────────┘

你可以把它理解为:

  • Net是“会计算的结构”;
  • Loss把“对不对”变成“可优化的数字”;
  • Backward给每个参数算出“该往哪改”;
  • Optimizer真正“改参数”;
  • Evaluate用测试集验证“改得值不值”。

3. 数据部分:Dataset / Transform / DataLoader 到底做了什么

3.1 MNIST 是什么数据?

MNIST 是手写数字识别数据集:

  • 图片:28×28 灰度图(单通道)
  • 标签:0~9 十个类别
  • 训练集:60000
  • 测试集:10000

在 PyTorch 里,一条样本通常表现为:

  • image:torch.Tensor,形状[1,28,28]
  • label:int,范围 0…9

3.2 为什么 DataLoader 出来是[B,1,28,28]

你训练时拿到的是一个 batch:

  • data.shape == [B, 1, 28, 28]
  • target.shape == [B]

四维的含义非常固定:

维度含义MNIST 示例
Bbatch_size32
C通道数1(灰度)
H高度28
W宽度28

只要你看到[B,C,H,W],你就知道它是“图片批次输入”。

3.3 transform:ToTensor + Normalize 为什么必做?

你的 transform 典型是:

transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])
  • ToTensor():把图像从0~255的像素转换为0~1的 float 张量
  • Normalize(mean, std):做标准化:

[
x’ = \frac{x - mean}{std}
]

为什么要标准化?

  • 让输入分布更稳定;
  • 梯度更容易优化;
  • 训练更快、收敛更稳定。

你看到min/max出现负数,就是 Normalize 生效的直接证据。

3.4 你本地的 data/ 目录意味着什么?

你看到的:

data/MNIST/raw/ train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte

说明:

  • 数据源来自torchvision.datasets.MNIST(root='data/', download=True, ...)
  • 评估时使用train=False对应t10k-*(测试集);
  • 训练时使用train=True对应train-*(训练集)。

4. 模型部分:Net 是“结构”,参数是“记忆”

4.1 你写的 Net 到底是什么?

在 PyTorch 里:

  • class Net(nn.Module)定义了网络结构(哪些层、如何连接);
  • forward(self, x)定义了前向计算(输入如何流过这些层);
  • 模型参数(权重、偏置)被nn.Conv2d / nn.Linear等层自动创建。

一句话:

模型 = 结构(Net) + 参数(weights/bias)

4.2 CNN 为什么有效?卷积、池化、全连接各自负责什么

可以把 CNN 分为两段:

  • 特征提取器:Conv/Pool
  • 分类器:Linear

直觉理解:

  • Conv:在局部区域找“笔画、边缘、拐角”
  • Pool:让特征更稳、更省计算
  • FC:把多个特征组合,输出 10 类得分

4.3 形状追踪:你必须会“追输出形状”

以你常用结构为例:

输入:[B,1,28,28]

  • conv1(1→8, k=3):[B,8,26,26]
  • maxpool(2):[B,8,13,13]
  • conv2(8→16, k=3):[B,16,11,11]
  • maxpool(2):[B,16,5,5]
  • flatten:[B,400](因为 1655=400)
  • fc1:[B,40]
  • fc2:[B,10](logits)

只要你追得出 400 的来源,你对 CNN 的理解就非常稳。

4.4 logits 是什么?为什么不是直接输出概率?

模型输出[B,10],这不是概率,而是 logits(打分)。

预测类别通常是:

pred=logits.argmax(dim=1)# [B]

训练时用CrossEntropyLoss(logits, target),它内部会做 softmax 相关处理,所以你不需要手写 softmax。


5. 训练部分:三件套把“计算结构”变成“会学习的系统”

5.1 训练三件套:loss / backward / optimizer

训练的核心闭环:

  1. forward:模型输出 logits
  2. loss:把“错得程度”变成一个数字
  3. backward:求每个参数的梯度(该往哪改)
  4. step:更新参数(真的改)

最简骨架:

optimizer.zero_grad()logits=net(data)loss=criterion(logits,target)loss.backward()optimizer.step()

5.2 为什么一定要 zero_grad()

PyTorch 默认会把梯度累加;如果不清零,上一个 batch 的梯度会污染当前 batch。

你可以把它理解为:

  • 每个 batch 都要单独算一次“该怎么改”;
  • 不能把历史梯度一直加下去。

5.3 net.train() 与 net.eval() 的区别(非常关键)

这不是“可有可无”。尤其你用了 Dropout。

  • net.train():训练模式,Dropout 生效
  • net.eval():评估模式,Dropout 关闭(输出更稳定)

你出现 test acc 比 train acc 高一点,在带 Dropout 的训练里不罕见:训练时更难,评估时更容易。

5.4 为什么 loss 会从 2.3 降下来?

在 10 类分类里,模型完全随机时:

  • 每类概率大约 1/10
  • 交叉熵大约 (\ln(10) ≈ 2.3026)

所以你最开始看到 loss ≈ 2.3 非常合理。

训练后,模型对正确类别给更高分,loss 就下降。


6. 评估部分:你到底评估了什么?关键函数和关键变量

你目前做的基础评估非常正确:

  • 平均 loss(test avg loss)
  • 准确率(test accuracy)

评估的关键点:

6.1 评估数据从哪里来?

评估使用的是MNIST 测试集

test_ds=torchvision.datasets.MNIST('data/',train=False,...)

对应你目录里的:

  • t10k-images-idx3-ubyte
  • t10k-labels-idx1-ubyte

6.2 评估为什么要 no_grad()?

评估只需要 forward,不需要 backward。

torch.no_grad()的价值:

  • 更快
  • 更省显存
  • 更不容易内存爆

6.3 评估最核心的统计逻辑是什么?

  • logits:net(data)
  • loss:criterion(logits, target)
  • pred:logits.argmax(dim=1)
  • correct:(pred == target).sum()

最终:

  • avg_loss = total_loss / total_samples
  • acc = total_correct / total_samples

7. 训练日志该怎么读?你这份结果说明了什么

你的输出:

Epoch 1: train loss=0.6006, train acc=80.70% test loss=0.1068, test acc=96.75% Epoch 2: train loss=0.3271, train acc=89.95% test loss=0.0856, test acc=97.34% Epoch 3: train loss=0.2842, train acc=91.26% test loss=0.0718, test acc=97.60%

可以简单做三点结论:

  1. loss 在下降:说明优化器在持续把模型往正确方向推;
  2. acc 在上升:说明预测正确率提升;
  3. test acc 达到 97%+:说明模型已经学到 MNIST 的主特征;

你的模型在 3 个 epoch 已经接近“够用”。如果继续训练可能还会涨,但收益会变小。


8.mnist_cnn.pth到底是什么?怎么理解它的“类型”

8.1.pth不是模型结构,而是模型参数

你保存的是:

torch.save(net.state_dict(),'mnist_cnn.pth')
  • state_dict()是一个“参数字典”:键是参数名,值是张量
  • .pth只是常用后缀名,表示 PyTorch 的保存文件

一句话:

mnist_cnn.pth保存的是你训练得到的“记忆”(权重/偏置)

8.2 为什么必须同时有 Net 才能用.pth

.pth本身不包含forward的代码逻辑。

可用模型 =Net()+load_state_dict(mnist_cnn.pth)

8.3 你可以用一行看它内部是什么

importtorch sd=torch.load('mnist_cnn.pth',map_location='cpu')print(type(sd))print(list(sd.keys())[:10])

你会看到:它是 dict,里面有conv1.weightfc2.bias等键。


9. 一份“初学者最推荐”的工程组织方式

你现在的结构已经很健康:

jerry_mnist/ create_model.py # 模型结构 get_data.py # 数据准备(可选) train_model.py # 训练 + 评估 + 保存 eval_basic.py # (建议新增)单独评估 mnist_cnn.pth # 训练结果参数 data/ # MNIST 数据缓存

推荐原则:

  • 模型结构独立文件(便于复用)
  • 训练脚本只管训练
  • 评估脚本只管评估
  • 推理脚本只管推理

这样你后面做 CIFAR10、做猫狗分类、做 YOLO,都可以复用思路。


10. 代码模板:最简训练脚本(带注释)

下面是一份“看一眼就懂”的训练框架(与你现在的逻辑一致):

# train_model.py(结构示意)# 1) 准备 DataLoader(train_loader / test_loader)# 2) 创建模型 net = Net().to(device)# 3) 定义 loss + optimizer# 4) 循环 epoch:# 4.1 net.train()# 4.2 对 train_loader:zero_grad → forward → loss → backward → step# 4.3 net.eval() + no_grad()# 4.4 对 test_loader:forward → 统计 loss/acc# 5) 保存参数 torch.save(net.state_dict(), 'mnist_cnn.pth')

你要掌握的是“骨架”,以后换任何任务都能套进去。


11. 评估进阶:从两个指标走向“知道错在哪”

你现在评估了 loss/acc,这已经是基础合格。

接下来想提高理解和实战能力,建议加三种评估:

  1. 混淆矩阵:看哪些数字互相混淆
  2. 每类准确率:哪个类最弱
  3. 错误样本可视化:错的样本长什么样

这三件事会让你从“知道结果不错”升级到“知道怎么继续提升”。


12. 从 MNIST 走向更大突破:下一步练什么最有效

当你把 MNIST 这一套跑稳后,推荐你按这个顺序升级:

12.1 升级数据:CIFAR-10

  • 32×32 彩色图(3 通道)
  • 更贴近真实视觉任务
  • 你会遇到:数据增强、过拟合、模型更深

12.2 升级模型:更规范的 CNN(BatchNorm、更多层)

  • 加 BatchNorm 稳定训练
  • 更深的网络、残差结构(ResNet)

12.3 升级任务:目标检测(YOLO)

  • 输入输出不再是 10 类得分
  • 变为:框位置 + 类别 + 置信度

但注意:无论怎么升级,“主线流程图”依旧成立。


13. 关键知识点清单(建议你定期扫一遍)

数据

  • Dataset / DataLoader
  • [B,C,H,W]含义
  • transform:ToTensor / Normalize

模型

  • nn.Module/forward
  • Conv / Pool / Flatten / Linear
  • logits 与 argmax

训练

  • criterion(CrossEntropyLoss)
  • optimizer(SGD/Adam)
  • zero_grad → forward → loss → backward → step
  • train()vseval()

评估

  • no_grad()
  • avg loss / accuracy
  • 错误分析(混淆矩阵、错例)

保存

  • state_dict()/load_state_dict()
  • .pth是参数存档,不是结构

14. 你目前这份项目的一句话“专业总结”

你已经完成了一个可复用的 PyTorch 图像分类最小工程:

  • 使用torchvision.datasets.MNIST构建训练/测试数据流;
  • 使用自定义 CNN(两层卷积 + 两层全连接 + Dropout)进行分类;
  • 使用CrossEntropyLoss + SGD完成训练闭环并达到97%+测试准确率;
  • 使用state_dict()将训练得到的参数保存为mnist_cnn.pth,可在任意环境中通过load_state_dict()复现推理效果。

15. 附:你可以直接复制的“推理脚本”骨架(可选)

如果你想快速验证.pth的价值:

# predict_one_batch.pyimporttorchimporttorchvisionfromtorch.utils.dataimportDataLoaderfromtorchvisionimporttransformsfromcreate_modelimportNet device='cuda'iftorch.cuda.is_available()else'cpu'# 1) load modelnet=Net().to(device)net.load_state_dict(torch.load('mnist_cnn.pth',map_location=device))net.eval()# 2) test loadertransform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])test_ds=torchvision.datasets.MNIST('data/',train=False,download=True,transform=transform)test_loader=DataLoader(test_ds,batch_size=32,shuffle=True)data,target=next(iter(test_loader))data=data.to(device)withtorch.no_grad():logits=net(data)pred=logits.argmax(dim=1).cpu()print('pred[:10] =',pred[:10].tolist())print('target[:10]=',target[:10].tolist())

结尾:你下一步应该做什么(最实用)

如果你想把基础打得更牢,建议你立刻做两个小任务:

  1. 写一个eval_basic.py:只输出 test loss / test acc(你已经理解清楚)
  2. 写一个eval_confusion.py:输出混淆矩阵 + 错例图(知道错在哪)

做完这两步,你会对“评估”不再停留在“跑出一个数字”,而是能解释为什么这样、如何继续提升


如果你希望我继续按“最小步骤”推进:下一步我会在不增加太多代码复杂度的前提下,带你实现混淆矩阵 + 错例可视化(这一步对形成实战直觉特别有帮助)。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/21 15:02:44

基于51单片机的智能浇花系统

基于51单片机的智能浇花系统设计 一、设计背景与意义 家庭园艺、阳台种植已成为日常休闲的重要方式,但传统人工浇花存在浇水时机不精准、外出无人照料、水量控制不当等问题,易导致花卉缺水枯萎或积水烂根。现有智能浇花系统多依赖物联网平台与高端控制…

作者头像 李华
网站建设 2026/5/22 18:14:50

大数据深度学习|计算机毕设项目|计算机毕设答辩|大数据多因子模型在股票投资策略中的实现

一、项目介绍 随着金融市场的发展和信息技术的进步,大数据多因子模型在股票投资策略中的应用日益广泛,为投资者提供了更为科学、高效的投资决策依据。本研究聚焦于大数据多因子模型在股票投资策略中的实现过程,旨在深入剖析该模型如何精准挖…

作者头像 李华
网站建设 2026/5/23 12:11:28

DNS劫持全解析:原理、危害与终极防护指南

一、核心定义:互联网的“电话簿”被篡改了 想象一下,互联网就像一本巨大的电话簿(DNS)。你想访问“百度”,不是直接输入复杂的IP地址(如 39.156.66.10),而是输入好记的域名 www.bai…

作者头像 李华
网站建设 2026/5/11 4:18:21

基于51单片机的门禁系统的研究与设计

基于51单片机的门禁系统的研究与设计 一、设计背景与意义 门禁系统是楼宇、办公区、小区等场景的核心安防设施,传统机械门禁存在易复制、安全性低、无使用记录等问题,而高端智能门禁系统依赖复杂的嵌入式平台与网络架构,成本高、部署难度大&a…

作者头像 李华
网站建设 2026/5/16 8:18:27

计算机毕业设计springboot4S店管理系统设计与实现 基于SpringBoot的汽车销售与售后服务一体化平台设计与实现

计算机毕业设计springboot4S店管理系统设计与实现gn093018 (配套有源码 程序 mysql数据库 论文) 本套源码可以在文本联xi,先看具体系统功能演示视频领取,可分享源码参考。 随着汽车产业的蓬勃发展和消费市场的持续升级,汽车4S店作…

作者头像 李华