news 2025/12/22 0:59:35

第P2周:CIFAR10彩色图片识别

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
第P2周:CIFAR10彩色图片识别
  • 🍨本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖原作者:K同学啊

目录

一、 前期准备

1. 设置GPU

2. 导入数据

3. 数据可视化

二、构建简单的CNN网络

三、 训练模型

1. 设置超参数

2. 编写训练函数

3. 编写测试函数

4. 正式训练

四、 结果可视化

五、个人总结

一、 前期准备

1. 设置GPU

import torch import torch.nn as nn import matplotlib.pyplot as plt import torchvision device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device

2. 导入数据

train_ds = torchvision.datasets.CIFAR10('data', train=True, transform=torchvision.transforms.ToTensor(), download=True) test_ds = torchvision.datasets.CIFAR10('data', train=False, transform=torchvision.transforms.ToTensor(), download=True) batch_size = 32 train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True) test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size) imgs, labels = next(iter(train_dl)) imgs.shape

3. 数据可视化

import numpy as np plt.figure(figsize=(20, 5)) for i, imgs in enumerate(imgs[:20]): npimg = imgs.numpy().transpose((1, 2, 0)) plt.subplot(2, 10, i+1) plt.imshow(npimg, cmap=plt.cm.binary) plt.axis('off')

二、构建简单的CNN网络

import torch.nn.functional as F num_classes = 10 class Model(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=3) self.pool1 = nn.MaxPool2d(kernel_size=2) self.conv2 = nn.Conv2d(64, 64, kernel_size=3) self.pool2 = nn.MaxPool2d(kernel_size=2) self.conv3 = nn.Conv2d(64, 128, kernel_size=3) self.pool3 = nn.MaxPool2d(kernel_size=2) self.fc1 = nn.Linear(512, 256) self.fc2 = nn.Linear(256, num_classes) def forward(self, x): x = self.pool1(F.relu(self.conv1(x))) x = self.pool2(F.relu(self.conv2(x))) x = self.pool3(F.relu(self.conv3(x))) x = torch.flatten(x, start_dim=1) x = F.relu(self.fc1(x)) x = self.fc2(x) return x from torchinfo import summary # 将模型转移到GPU中(我们模型运行均在GPU中进行) model = Model().to(device) summary(model)

三、 训练模型

1. 设置超参数

loss_fn = nn.CrossEntropyLoss() # 创建损失函数 learn_rate = 1e-2 # 学习率 opt = torch.optim.SGD(model.parameters(),lr=learn_rate)

2. 编写训练函数

# 训练循环 def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 训练集的大小,一共60000张图片 num_batches = len(dataloader) # 批次数目,1875(60000/32) train_loss, train_acc = 0, 0 # 初始化训练损失和正确率 for X, y in dataloader: # 获取图片及其标签 X, y = X.to(device), y.to(device) # 计算预测误差 pred = model(X) # 网络输出 loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失 # 反向传播 optimizer.zero_grad() # grad属性归零 loss.backward() # 反向传播 optimizer.step() # 每一步自动更新 # 记录acc与loss train_acc += (pred.argmax(1) == y).type(torch.float).sum().item() train_loss += loss.item() train_acc /= size train_loss /= num_batches return train_acc, train_loss

3. 编写测试函数

def test (dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集的大小,一共10000张图片 num_batches = len(dataloader) # 批次数目,313(10000/32=312.5,向上取整) test_loss, test_acc = 0, 0 # 当不进行训练时,停止梯度更新,节省计算内存消耗 with torch.no_grad(): for imgs, target in dataloader: imgs, target = imgs.to(device), target.to(device) # 计算loss target_pred = model(imgs) loss = loss_fn(target_pred, target) test_loss += loss.item() test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item() test_acc /= size test_loss /= num_batches return test_acc, test_loss

4. 正式训练

epochs = 10 train_loss = [] train_acc = [] test_loss = [] test_acc = [] for epoch in range(epochs): model.train() epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt) model.eval() epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn) train_acc.append(epoch_train_acc) train_loss.append(epoch_train_loss) test_acc.append(epoch_test_acc) test_loss.append(epoch_test_loss) template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}') print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss)) print('Done')

四、 结果可视化

import matplotlib.pyplot as plt #隐藏警告 import warnings warnings.filterwarnings("ignore") #忽略警告信息 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 plt.rcParams['figure.dpi'] = 100 #分辨率 from datetime import datetime current_time = datetime.now() # 获取当前时间 epochs_range = range(epochs) plt.figure(figsize=(12, 3)) plt.subplot(1, 2, 1) plt.plot(epochs_range, train_acc, label='Training Accuracy') plt.plot(epochs_range, test_acc, label='Test Accuracy') plt.legend(loc='lower right') plt.title('Training and Validation Accuracy') plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效 plt.subplot(1, 2, 2) plt.plot(epochs_range, train_loss, label='Training Loss') plt.plot(epochs_range, test_loss, label='Test Loss') plt.legend(loc='upper right') plt.title('Training and Validation Loss') plt.show()

五、个人总结

逐渐熟悉CNN模型构建过程,并逐步理解其原理。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2025/12/17 11:00:41

jQuery EasyUI 菜单与按钮 - 创建分割按钮(Split Button)

jQuery EasyUI 菜单与按钮 - 创建分割按钮(Split Button) jQuery EasyUI 的 splitbutton 组件是一种特殊的菜单按钮,它将按钮分为两个部分: 左侧主体:可点击执行默认主要操作(例如“保存”)。…

作者头像 李华
网站建设 2025/12/17 11:00:30

彼得林奇对公司管理层薪酬结构的合理性评估

彼得林奇对公司管理层薪酬结构的合理性评估关键词:彼得林奇、公司管理层、薪酬结构、合理性评估、企业管理摘要:本文深入探讨了彼得林奇对公司管理层薪酬结构合理性评估的相关内容。从背景介绍入手,阐述了研究目的、预期读者等信息。接着详细…

作者头像 李华
网站建设 2025/12/17 10:59:37

jQuery EasyUI 数据网格 - 转换 HTML 表格为数据网格

jQuery EasyUI 数据网格 - 转换 HTML 表格为数据网格&#xff08;DataGrid&#xff09; jQuery EasyUI 的 datagrid 组件最强大的功能之一就是可以直接将现有的 HTML <table> 表格转换为功能丰富的 DataGrid&#xff0c;而无需重新定义列或数据源。这非常适合快速升级传…

作者头像 李华
网站建设 2025/12/17 10:58:39

FreePBX 修复多个严重漏洞

聚焦源代码安全&#xff0c;网罗国内外最新资讯&#xff01;编译&#xff1a;代码卫士开源的 PBX 平台 FreePBX 上存在多个漏洞&#xff0c;其中一个严重漏洞在某些配置下课导致认证绕过漏洞。这些漏洞由 Horizon3.ai 团队发现并在2025年9月15日报送给项目维护人员。这些漏洞如…

作者头像 李华
网站建设 2025/12/17 10:58:13

解码企业管理新范式:如何以技术驱动实现企业全周期价值跃升

在数字化浪潮与产业升级的双重变革下&#xff0c;企业管理的核心已从传统的流程管控&#xff0c;演进为以资本化加速、合规化运营、精益化增长为目标的战略赋能。选择一家真正具备深厚实力、技术底蕴与全景服务能力的合作伙伴&#xff0c;已成为企业在激烈竞争中构筑护城河的关…

作者头像 李华
网站建设 2025/12/17 10:58:05

EmotiVoice使用指南:快速上手高表现力TTS模型

EmotiVoice使用指南&#xff1a;快速上手高表现力TTS模型 在虚拟助手越来越“懂人心”、游戏角色开始“真情流露”的今天&#xff0c;语音合成早已不再是简单地把文字念出来。用户期待的是有温度、有情绪、像真人一样的声音——而这一点&#xff0c;正是传统TTS系统的短板。 机…

作者头像 李华