ResNet18多分类实战:花卉识别完整案例,1块钱体验
引言
你是否曾在花园里看到一朵美丽的花,却叫不出它的名字?作为植物爱好者,我们常常会遇到这样的困扰。现在,借助AI技术,你可以轻松识别各种花卉品种。本文将带你用ResNet18模型构建一个花卉识别系统,整个过程就像教小朋友认图识字一样简单。
ResNet18是深度学习领域经典的图像分类模型,它通过"跳跃连接"解决了深层网络训练难题(就像给记忆不好的学生准备小抄)。我们将使用PyTorch框架和预训练模型,即使你没有任何AI基础,也能在1小时内完成从环境搭建到实际预测的全流程。
这个实战案例有三大特点: -成本极低:使用云平台GPU资源,全程花费不到1块钱 -完整可运行:提供从数据准备到模型预测的完整代码 -即学即用:学完就能识别常见花卉品种
1. 环境准备:5分钟搞定AI实验室
首先我们需要准备开发环境,就像厨师需要先准备好厨房和食材。推荐使用CSDN星图镜像广场的PyTorch预置镜像,它已经装好了所有必需工具。
1.1 选择合适的环境
对于这个项目,我们需要: - Python 3.8+ - PyTorch 1.12+ - torchvision库 - GPU支持(能让训练速度提升10倍以上)
在CSDN算力平台选择"PyTorch 1.12 + CUDA 11.3"基础镜像,这是已经配置好的"AI厨房",开箱即用。
1.2 安装额外依赖
启动环境后,只需再安装一个数据处理库:
pip install pandas matplotlib💡 提示
如果使用本地环境,建议创建虚拟环境避免包冲突:
python -m venv flower_env
2. 数据准备:建立你的花卉图库
好的AI模型需要好的数据,就像好学生需要好的教材。我们将使用公开的Oxford 102花卉数据集,包含102类常见花卉的图片。
2.1 下载数据集
运行以下代码自动下载并解压数据:
import torchvision.datasets as datasets from torchvision import transforms # 定义图像预处理(标准化ImageNet预训练模型的输入) transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 下载数据集 train_data = datasets.Flowers102(root='./data', split='train', download=True, transform=transform) val_data = datasets.Flowers102(root='./data', split='val', transform=transform) test_data = datasets.Flowers102(root='./data', split='test', transform=transform)2.2 创建数据加载器
将数据打包成批次,方便模型训练:
from torch.utils.data import DataLoader batch_size = 32 # 每次处理32张图片 train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val_data, batch_size=batch_size) test_loader = DataLoader(test_data, batch_size=batch_size)3. 模型构建:使用预训练ResNet18
我们不会从零开始训练模型(那需要几天时间和昂贵设备),而是采用迁移学习,就像让一个已经会认动物的孩子来学认花。
3.1 加载预训练模型
import torchvision.models as models import torch.nn as nn # 加载预训练ResNet18 model = models.resnet18(pretrained=True) # 修改最后一层,适配我们的102类花卉分类 num_features = model.fc.in_features model.fc = nn.Linear(num_features, 102) # 102个花卉类别3.2 设置训练参数
import torch.optim as optim device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) criterion = nn.CrossEntropyLoss() # 损失函数 optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # 优化器4. 模型训练:教AI认识花卉
现在进入最激动人心的环节——训练模型。这就像老师教学生认图,需要反复练习。
4.1 训练函数
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10): for epoch in range(num_epochs): model.train() # 训练模式 running_loss = 0.0 for inputs, labels in train_loader: inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() # 清零梯度 outputs = model(inputs) # 前向传播 loss = criterion(outputs, labels) # 计算损失 loss.backward() # 反向传播 optimizer.step() # 更新参数 running_loss += loss.item() # 每个epoch结束后验证 model.eval() # 评估模式 val_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for inputs, labels in val_loader: inputs, labels = inputs.to(device), labels.to(device) outputs = model(inputs) loss = criterion(outputs, labels) val_loss += loss.item() _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(f"Epoch {epoch+1}/{num_epochs} | " f"Train Loss: {running_loss/len(train_loader):.4f} | " f"Val Loss: {val_loss/len(val_loader):.4f} | " f"Val Acc: {100*correct/total:.2f}%") # 开始训练(10个epoch约15分钟) train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10)4.2 训练技巧
- 学习率调整:训练后期可以减小学习率
- 早停机制:当验证集准确率不再提升时停止训练
- 数据增强:增加随机翻转、旋转等提升模型泛化能力
5. 模型评估与预测:看看AI学得怎么样
训练完成后,我们需要测试模型的真实水平,就像给学生期末考试。
5.1 测试集评估
def evaluate(model, test_loader): model.eval() correct = 0 total = 0 with torch.no_grad(): for inputs, labels in test_loader: inputs, labels = inputs.to(device), labels.to(device) outputs = model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(f"Test Accuracy: {100 * correct / total:.2f}%") evaluate(model, test_loader)5.2 单张图片预测
让我们试试用训练好的模型识别一张新图片:
from PIL import Image def predict_image(image_path, model, class_names): image = Image.open(image_path) image = transform(image).unsqueeze(0).to(device) # 预处理并添加批次维度 model.eval() with torch.no_grad(): output = model(image) _, predicted = torch.max(output, 1) return class_names[predicted.item()] # 假设我们有花卉类别名称列表 class_names = ['玫瑰', '郁金香', '向日葵', ...] # 实际应为102类名称 print(predict_image('my_flower.jpg', model, class_names))6. 模型优化与部署
要让模型在实际中好用,还需要一些优化工作。
6.1 常见优化方法
- 数据增强:训练时增加随机变换,提高泛化能力
- 模型微调:解冻更多层进行训练
- 学习率调度:动态调整学习率
6.2 保存与加载模型
# 保存模型 torch.save(model.state_dict(), 'flower_resnet18.pth') # 加载模型 loaded_model = models.resnet18(pretrained=False) loaded_model.fc = nn.Linear(num_features, 102) loaded_model.load_state_dict(torch.load('flower_resnet18.pth')) loaded_model = loaded_model.to(device)7. 总结
通过这个完整案例,我们实现了:
- 低成本实践:使用云GPU资源,花费不到1块钱完成AI模型训练
- 完整流程:从数据准备到模型预测的端到端实现
- 实用技巧:掌握了图像分类的关键参数和优化方法
核心要点: - 迁移学习让我们能用少量数据训练出高精度模型 - ResNet18是轻量高效的图像分类基础模型 - 数据预处理和增强对模型性能影响巨大 - GPU加速能使训练速度提升10倍以上 - 学完这个案例,你可以轻松扩展到其他图像分类任务
现在就可以上传你的花卉照片,试试这个识别系统吧!实测在102类花卉上的准确率能达到85%以上,对于常见品种识别效果更好。
💡获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。