本文记录了从PyTorch入门到成功实现MNIST手写数字分类(准确率达97%+)的完整过程,涵盖数据集认知、代码搭建、核心疑问解答及训练优化等关键环节,适合刚接触PyTorch的新手参考学习。
一、开篇:MNIST数据集与核心目标
我们的核心任务是用PyTorch搭建多层感知器(MLP),完成MNIST手写数字的分类任务。首先得明确MNIST数据集的基础属性——这是一个经典的手写数字识别数据集,包含60000张训练图和10000张测试图,所有图片均为28×28像素的灰度图,标签为0-9的整数(对应手写数字)。
新手常见第一个疑问:下载数据集后为什么会有多个文件?其实这些是MNIST的原始压缩包(.gz后缀)和解压后的原始文件,分为训练集/测试集的图片文件和标签文件,属于PyTorch自动下载解压的正常现象,无需手动处理。
二、核心代码搭建:从数据加载到模型训练
我们逐步搭建了完整的训练流程,从库导入到最终训练测试,每一步都针对新手常见疑问做了详细解析。
2.1 完整可运行代码(优化后最终版)
# 导入所需库 import torch from torch import nn from torch.utils.data import DataLoader from torchvision import datasets, transforms # 加载数据时修改transform,增加Normalize transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) # MNIST的均值和标准差(官方推荐值) ]) # 1. 加载MNIST数据集 training_data = datasets.MNIST( root="data", train=True, download=True, # transform=transforms.ToTensor() transform=transform ) test_data = datasets.MNIST( root="data", train=False, download=True, # transform=transforms.ToTensor() transform=transform ) # 2. 创建DataLoader(数据加载器) batch_size = 64 train_dataloader = DataLoader(training_data, batch_size=batch_size,shuffle=True) test_dataloader = DataLoader(test_data, batch_size=batch_size) # 3. 判断设备(CPU/GPU) device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" print(f"Using {device} device") # 4. 定义神经网络模型的类 class NeuralNetwork(nn.Module): def __init__(self): super().__init__() self.flatten = nn.Flatten() # 展开28×28的图片(官方图片集固定的28*28像素)为一维向量 self.hidden1 = nn.Linear(28 * 28, 128) # 输入层→隐藏层1,本质是矩阵运算,这里我们的批次是64,因此实际进行64行28*28列作为矩阵传入第一层计算 self.hidden2 = nn.Linear(128, 256) # 隐藏层1→隐藏层2 self.out = nn.Linear(256, 10) # 隐藏层2→输出层(10类数字) # 前向传播(方法名规范为forward,小写) def forward(self, x): x = self.flatten(x) x = self.hidden1(x) # x = torch.sigmoid(x) # 激活函数 x = torch.relu(x) # 替换sigmoid为relu x = self.hidden2(x) # x = torch.sigmoid(x) # 激活函数 x = torch.relu(x) # 替换sigmoid为relu x = self.out(x) return x '''neural神经 batch批次''' # 初始化类并移动到设备 model = NeuralNetwork().to(device) print(model) # 5. 定义训练函数 def train(dataloader, model, loss_fn, optimizer): model.train() # 切换到训练模式 batch_size_num = 1 # 统计batch数量 for X, y in dataloader: # dataloader是函数形参,调用时传入train_dataloader/test_dataloader;迭代返回批次级数据:X为批次图片张量(shape[batch_size,1,28,28]),y为批次标签张量(shape[batch_size]),对应批次内所有样本的图片和标签 # 数据移动到设备 X, y = X.to(device), y.to(device) # 前向传播计算预测值 pred = model(X) # 可省略.forward,model(X)会自动调用forward loss = loss_fn(pred, y) # 计算损失 # 反向传播更新参数 optimizer.zero_grad() # 梯度清零 loss.backward() # 反向传播计算梯度(原代码中Loss大写,修正为loss) optimizer.step() # 更新模型参数 # 每100个batch打印一次损失 loss_value = loss.item() if batch_size_num % 100 == 0: print(f"Loss: {loss_value:>7f} [batch: {batch_size_num}]") batch_size_num += 1 # 6. 定义测试函数 def test(dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集总样本数 num_batches = len(dataloader) # 测试集batch数量 model.eval() # 切换到测试模式 test_loss, correct = 0, 0 # 测试时关闭梯度计算,节省资源 with torch.no_grad(): for X, y in dataloader: X, y = X.to(device), y.to(device) pred = model(X) test_loss += loss_fn(pred, y).item() # 累计损失 # 计算正确预测数(取预测最大值对应的索引) correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 计算平均损失和准确率 test_loss /= num_batches correct /= size print(f"Test result: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n") # 7. 初始化损失函数和优化器 loss_fn = nn.CrossEntropyLoss() # 交叉熵损失(适用于分类任务) # optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # SGD优化器 # 原代码的optimizer替换为Adam optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Adam默认lr=0.001即可 # 8. 执行训练和测试 # print("开始训练:") # train(train_dataloader, model, loss_fn, optimizer) # 训练一轮 # 原代码只调用了1次train,改成循环10轮 epochs = 10 # 训练10轮 for t in range(epochs): print(f"\n训练轮数 {t+1}/{epochs}") train(train_dataloader, model, loss_fn, optimizer) test(test_dataloader, model, loss_fn) # 每轮训练后测试 print("训练完成!") print("开始测试:") test(test_dataloader, model, loss_fn)运行结果:
2.2 关键代码解析(新手必看)
(1)datasets.MNIST:数据集加载神器
datasets.MNIST是PyTorch封装好的数据集加载工具,核心作用是跳过“下载-解压-解析”的繁琐步骤,直接得到可使用的Dataset对象。关键参数说明:
参数名 | 核心作用 | 常用取值 |
|---|---|---|
root | 指定数据集保存/读取目录 | root="data"(当前目录下创建data文件夹) |
train | 选择加载训练集/测试集 | train=True(训练集)、train=False(测试集) |
download | 本地无数据时自动下载 | 首次使用设为True,后续设为False |
transform | 图片预处理 | 转Tensor+标准化(提升训练效果) |
(2)DataLoader:批量训练的核心
代码:train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
核心作用是将零散的Dataset样本打包成批次张量,适配神经网络批量训练逻辑:
batch_size=64:每批包含64个样本,生成形状为[64,1,28,28]的批次图片张量和[64]的批次标签张量;
shuffle=True:训练前打乱样本顺序,避免模型记住样本顺序导致泛化能力差;
迭代使用:通过
for X, y in dataloader遍历,X为批次图片,y为对应批次标签(形参dataloader可接收train_dataloader或test_dataloader)。
(3)批次张量:高效训练的关键
批次张量是将多个样本的图片/标签合并成的大张量,核心优势是利用矩阵运算实现批量计算:
单张图片张量:[1,28,28](1通道+28×28像素);
批次图片张量(64个样本):[64,1,28,28](新增批次维度);
优势:神经网络通过一次矩阵运算就能完成64张图片的前向传播,效率远高于逐个计算。
(4)nn.Module:神经网络的骨架
自定义模型必须继承nn.Module,它是PyTorch所有神经网络模块的基类,提供参数管理、GPU适配、前向传播支持等核心功能(并非具体模型):
__init__方法:初始化网络层(如nn.Flatten展平层、nn.Linear全连接层),需调用super().__init__()启用父类功能;
forward方法:定义前向传播逻辑,调用model(X)时会自动执行,无需手动调用。
(5).item():标量张量转普通数值
训练中计算的损失loss是标量张量(shape=[]),无法直接用于Python格式化输出,通过loss.item()可剥离张量的设备、梯度等额外信息,转成普通float数值。
三、训练优化:从30%到97%的关键调整
初始训练准确率仅30%左右,通过以下关键调整,第六轮训练准确率就达到97%+:
替换激活函数:将sigmoid改为ReLU,解决深层网络梯度消失问题,让模型快速收敛;
增加训练轮数:从1轮改为10轮,让模型充分学习数据特征;
训练数据打乱:在DataLoader中添加shuffle=True,提升模型泛化能力;
优化器调整:从SGD改为Adam(自适应学习率),加速收敛;
数据标准化:添加transforms.Normalize,让模型更快收敛。
四、核心疑问解答(新手避坑指南)
Q1:为什么img, label = training_data[i]能将一个值拆成两个变量?
training_data[i]返回的是包含两个元素的元组(图片数据+标签),这是Python的元组解包语法,会自动将元组中的两个值按顺序赋值给img和label,等价于img = training_data[i][0]和label = training_data[i][1]。
Q2:怎么确认MNIST图片是28×28像素?
28×28是MNIST数据集的官方固有属性,可通过代码验证:取单个样本img = training_data[0][0],打印img.shape会输出torch.Size([1,28,28]),明确显示图片尺寸为28×28。
Q3:输入层接收28×28特征,怎么同时训练64张图片?
通过nn.Flatten()将64张图片的[64,1,28,28]张量展平为[64,784](784=28×28),输入层nn.Linear(784,128)通过矩阵运算(64×784的输入矩阵 × 784×128的权重矩阵),一次性完成64张图片的特征转换,实现批量训练。
Q4:nn.Module和感知器的关系?
感知器是单个神经元的计算逻辑(如nn.Linear层的一个神经元),而nn.Module是承载感知器/网络层的骨架,提供参数管理、GPU计算等运行能力,两者是“被承载者”和“承载框架”的关系,并非等价关系。
五、总结与成果
本次从MNIST数据集认知开始,逐步完成了数据加载、模型搭建、训练测试全流程,解决了新手常见的语法疑问和训练优化问题,最终通过多层感知器实现了97%+的分类准确率。核心收获:
掌握PyTorch核心工具(datasets.MNIST、DataLoader、nn.Module)的使用逻辑;
理解批量训练、批次张量的核心原理;
掌握神经网络训练优化的关键技巧(激活函数、优化器、数据预处理等)。
对于刚接触PyTorch的新手来说,从疑问重重到实现高准确率训练,每一步都是成长。后续可尝试调整batch_size、学习率等参数,进一步优化模型性能,或探索CNN等更复杂的网络结构在MNIST上的表现。