ResNet18多分类实战:花卉识别从数据到部署全流程
引言
当你需要让计算机识别不同种类的花卉时,ResNet18就像一位经验丰富的植物学家,能快速准确地告诉你眼前的花朵属于哪一类。这个轻量级神经网络特别适合像大学生竞赛这样的场景,尤其是当你面对200种花卉的大数据集,而笔记本电脑又跑不动的时候。
想象一下,你要参加一个花卉识别比赛,数据集包含200种不同的花卉图片,每类有几十到上百张样本。直接在笔记本上训练可能会遇到内存不足、速度缓慢甚至死机的问题。这时候,云端GPU资源就像给你的电脑外接了一个超级引擎,让训练过程快如闪电。
ResNet18作为经典的轻量级卷积神经网络,在保持较高准确率的同时,对计算资源的需求相对较低。它就像一辆省油但动力十足的小跑车,特别适合学生党在有限预算下完成深度学习任务。接下来,我将带你从数据准备到模型部署,完整走一遍花卉识别的全流程。
1. 环境准备与数据获取
1.1 选择云端GPU环境
对于200类花卉识别任务,建议选择至少8GB显存的GPU。在CSDN算力平台上,你可以找到预装了PyTorch和CUDA的基础镜像,这些环境已经配置好了深度学习所需的各种依赖。
# 检查GPU是否可用 import torch print(torch.cuda.is_available()) # 应该返回True print(torch.cuda.get_device_name(0)) # 显示你的GPU型号1.2 准备花卉数据集
花卉识别常用的数据集包括Oxford 102 Flowers和Flowers Recognition等。这里我们以Oxford 102 Flowers为例,它包含102类花卉,每类40-258张图片。
# 下载并解压数据集 !wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz !tar -xzf 102flowers.tgz如果你的比赛要求200类,可能需要自行收集更多数据或合并多个数据集。确保数据集结构如下:
flowers_dataset/ class_1/ img1.jpg img2.jpg ... class_2/ img1.jpg img2.jpg ... ...2. 数据预处理与增强
2.1 基本数据预处理
花卉图片需要统一尺寸并做归一化处理。ResNet18的标准输入尺寸是224x224像素。
from torchvision import transforms # 定义训练集和验证集的变换 train_transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) val_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])2.2 数据增强技巧
对于花卉识别,这些增强特别有效:
- 随机旋转(小角度,因为花朵通常有明确的方向)
- 颜色抖动(模拟不同光照条件)
- 随机裁剪(关注花朵的不同部位)
# 更丰富的数据增强 advanced_transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])3. 构建ResNet18模型
3.1 加载预训练模型
使用在ImageNet上预训练的ResNet18作为基础,替换最后的全连接层以适应我们的200类分类任务。
import torchvision.models as models import torch.nn as nn # 加载预训练模型 model = models.resnet18(pretrained=True) # 修改最后一层全连接层 num_classes = 200 # 根据你的类别数调整 model.fc = nn.Linear(model.fc.in_features, num_classes) # 将模型移到GPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = model.to(device)3.2 模型结构解析
ResNet18之所以适合这个任务,是因为:
- 18层深度:足够捕捉花卉的复杂特征,又不会过度消耗资源
- 残差连接:解决了深层网络训练困难的问题
- 约1100万参数:相比更大的模型,在200类任务上仍有不错表现
4. 训练模型
4.1 设置训练参数
对于花卉识别,这些参数组合通常效果不错:
import torch.optim as optim criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)4.2 训练循环
下面是标准的训练循环,加入了验证阶段:
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25): for epoch in range(num_epochs): # 训练阶段 model.train() running_loss = 0.0 for inputs, labels in dataloaders['train']: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() scheduler.step() # 验证阶段 model.eval() val_loss = 0.0 corrects = 0 with torch.no_grad(): for inputs, labels in dataloaders['val']: inputs = inputs.to(device) labels = labels.to(device) outputs = model(inputs) loss = criterion(outputs, labels) val_loss += loss.item() _, preds = torch.max(outputs, 1) corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / len(dataloaders['train'].dataset) epoch_val_loss = val_loss / len(dataloaders['val'].dataset) epoch_acc = corrects.double() / len(dataloaders['val'].dataset) print(f'Epoch {epoch}/{num_epochs-1}') print(f'Train Loss: {epoch_loss:.4f} Val Loss: {epoch_val_loss:.4f} Val Acc: {epoch_acc:.4f}') return model5. 模型评估与优化
5.1 评估指标
除了准确率,对于多分类问题还应该关注:
- 混淆矩阵:查看哪些类别容易混淆
- 每类的精确率、召回率和F1分数
- Top-k准确率(特别是Top-3或Top-5)
from sklearn.metrics import classification_report, confusion_matrix def evaluate_model(model, dataloader, class_names): model.eval() all_preds = [] all_labels = [] with torch.no_grad(): for inputs, labels in dataloader: inputs = inputs.to(device) labels = labels.to(device) outputs = model(inputs) _, preds = torch.max(outputs, 1) all_preds.extend(preds.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) print(classification_report(all_labels, all_preds, target_names=class_names)) print("Confusion Matrix:") print(confusion_matrix(all_labels, all_preds))5.2 常见问题与解决方案
问题1:某些花卉类别准确率特别低
- 解决方案:检查这些类别的样本数量是否过少,考虑数据增强或收集更多样本
问题2:模型过拟合
- 解决方案:增加Dropout层,或使用更强的数据增强
问题3:训练速度慢
- 解决方案:增大batch size(根据GPU显存调整),或使用混合精度训练
from torch.cuda.amp import GradScaler, autocast # 混合精度训练示例 scaler = GradScaler() for inputs, labels in dataloaders['train']: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() with autocast(): outputs = model(inputs) loss = criterion(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()6. 模型部署与应用
6.1 保存训练好的模型
# 保存整个模型 torch.save(model, 'flower_resnet18.pth') # 或者只保存状态字典(推荐) torch.save(model.state_dict(), 'flower_resnet18_state_dict.pth')6.2 创建简单的推理API
使用Flask创建一个简单的Web服务:
from flask import Flask, request, jsonify from PIL import Image import io app = Flask(__name__) # 加载模型 model = models.resnet18(pretrained=False) model.fc = nn.Linear(model.fc.in_features, 200) model.load_state_dict(torch.load('flower_resnet18_state_dict.pth')) model.eval() @app.route('/predict', methods=['POST']) def predict(): if 'file' not in request.files: return jsonify({'error': 'no file uploaded'}) file = request.files['file'] img_bytes = file.read() img = Image.open(io.BytesIO(img_bytes)) # 应用相同的变换 img = val_transform(img).unsqueeze(0) with torch.no_grad(): output = model(img) _, pred = torch.max(output, 1) return jsonify({'class_id': int(pred), 'class_name': class_names[int(pred)]}) if __name__ == '__main__': app.run(host='0.0.0.0', port=5000)6.3 性能优化技巧
- TorchScript:将模型转换为TorchScript提高推理速度
- ONNX:导出为ONNX格式,兼容更多推理引擎
- 量化:使用8位整数量化减小模型大小,提高速度
# 量化示例 quantized_model = torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtype=torch.qint8 )总结
通过本教程,你已经完成了从数据准备到模型部署的花卉识别全流程。以下是核心要点:
- ResNet18是轻量级但强大的选择,特别适合200类花卉识别这样的多分类任务
- 云端GPU资源可以轻松应对大数据集训练,避免笔记本性能不足的问题
- 数据增强对花卉识别特别重要,能显著提高模型泛化能力
- 模型评估不应只看整体准确率,还要分析各类别的表现
- 部署时可以灵活选择方案,从简单API到优化后的推理引擎
现在你就可以在CSDN算力平台上尝试这个流程,使用预置的PyTorch镜像快速开始你的花卉识别项目。
💡获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。