news 2026/4/15 17:01:13

ResNet18多分类实战:花卉识别从数据到部署全流程

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
ResNet18多分类实战:花卉识别从数据到部署全流程

ResNet18多分类实战:花卉识别从数据到部署全流程

引言

当你需要让计算机识别不同种类的花卉时,ResNet18就像一位经验丰富的植物学家,能快速准确地告诉你眼前的花朵属于哪一类。这个轻量级神经网络特别适合像大学生竞赛这样的场景,尤其是当你面对200种花卉的大数据集,而笔记本电脑又跑不动的时候。

想象一下,你要参加一个花卉识别比赛,数据集包含200种不同的花卉图片,每类有几十到上百张样本。直接在笔记本上训练可能会遇到内存不足、速度缓慢甚至死机的问题。这时候,云端GPU资源就像给你的电脑外接了一个超级引擎,让训练过程快如闪电。

ResNet18作为经典的轻量级卷积神经网络,在保持较高准确率的同时,对计算资源的需求相对较低。它就像一辆省油但动力十足的小跑车,特别适合学生党在有限预算下完成深度学习任务。接下来,我将带你从数据准备到模型部署,完整走一遍花卉识别的全流程。

1. 环境准备与数据获取

1.1 选择云端GPU环境

对于200类花卉识别任务,建议选择至少8GB显存的GPU。在CSDN算力平台上,你可以找到预装了PyTorch和CUDA的基础镜像,这些环境已经配置好了深度学习所需的各种依赖。

# 检查GPU是否可用 import torch print(torch.cuda.is_available()) # 应该返回True print(torch.cuda.get_device_name(0)) # 显示你的GPU型号

1.2 准备花卉数据集

花卉识别常用的数据集包括Oxford 102 Flowers和Flowers Recognition等。这里我们以Oxford 102 Flowers为例,它包含102类花卉,每类40-258张图片。

# 下载并解压数据集 !wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz !tar -xzf 102flowers.tgz

如果你的比赛要求200类,可能需要自行收集更多数据或合并多个数据集。确保数据集结构如下:

flowers_dataset/ class_1/ img1.jpg img2.jpg ... class_2/ img1.jpg img2.jpg ... ...

2. 数据预处理与增强

2.1 基本数据预处理

花卉图片需要统一尺寸并做归一化处理。ResNet18的标准输入尺寸是224x224像素。

from torchvision import transforms # 定义训练集和验证集的变换 train_transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) val_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])

2.2 数据增强技巧

对于花卉识别,这些增强特别有效:

  • 随机旋转(小角度,因为花朵通常有明确的方向)
  • 颜色抖动(模拟不同光照条件)
  • 随机裁剪(关注花朵的不同部位)
# 更丰富的数据增强 advanced_transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])

3. 构建ResNet18模型

3.1 加载预训练模型

使用在ImageNet上预训练的ResNet18作为基础,替换最后的全连接层以适应我们的200类分类任务。

import torchvision.models as models import torch.nn as nn # 加载预训练模型 model = models.resnet18(pretrained=True) # 修改最后一层全连接层 num_classes = 200 # 根据你的类别数调整 model.fc = nn.Linear(model.fc.in_features, num_classes) # 将模型移到GPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = model.to(device)

3.2 模型结构解析

ResNet18之所以适合这个任务,是因为:

  1. 18层深度:足够捕捉花卉的复杂特征,又不会过度消耗资源
  2. 残差连接:解决了深层网络训练困难的问题
  3. 约1100万参数:相比更大的模型,在200类任务上仍有不错表现

4. 训练模型

4.1 设置训练参数

对于花卉识别,这些参数组合通常效果不错:

import torch.optim as optim criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

4.2 训练循环

下面是标准的训练循环,加入了验证阶段:

def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25): for epoch in range(num_epochs): # 训练阶段 model.train() running_loss = 0.0 for inputs, labels in dataloaders['train']: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() scheduler.step() # 验证阶段 model.eval() val_loss = 0.0 corrects = 0 with torch.no_grad(): for inputs, labels in dataloaders['val']: inputs = inputs.to(device) labels = labels.to(device) outputs = model(inputs) loss = criterion(outputs, labels) val_loss += loss.item() _, preds = torch.max(outputs, 1) corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / len(dataloaders['train'].dataset) epoch_val_loss = val_loss / len(dataloaders['val'].dataset) epoch_acc = corrects.double() / len(dataloaders['val'].dataset) print(f'Epoch {epoch}/{num_epochs-1}') print(f'Train Loss: {epoch_loss:.4f} Val Loss: {epoch_val_loss:.4f} Val Acc: {epoch_acc:.4f}') return model

5. 模型评估与优化

5.1 评估指标

除了准确率,对于多分类问题还应该关注:

  • 混淆矩阵:查看哪些类别容易混淆
  • 每类的精确率、召回率和F1分数
  • Top-k准确率(特别是Top-3或Top-5)
from sklearn.metrics import classification_report, confusion_matrix def evaluate_model(model, dataloader, class_names): model.eval() all_preds = [] all_labels = [] with torch.no_grad(): for inputs, labels in dataloader: inputs = inputs.to(device) labels = labels.to(device) outputs = model(inputs) _, preds = torch.max(outputs, 1) all_preds.extend(preds.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) print(classification_report(all_labels, all_preds, target_names=class_names)) print("Confusion Matrix:") print(confusion_matrix(all_labels, all_preds))

5.2 常见问题与解决方案

问题1:某些花卉类别准确率特别低

  • 解决方案:检查这些类别的样本数量是否过少,考虑数据增强或收集更多样本

问题2:模型过拟合

  • 解决方案:增加Dropout层,或使用更强的数据增强

问题3:训练速度慢

  • 解决方案:增大batch size(根据GPU显存调整),或使用混合精度训练
from torch.cuda.amp import GradScaler, autocast # 混合精度训练示例 scaler = GradScaler() for inputs, labels in dataloaders['train']: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() with autocast(): outputs = model(inputs) loss = criterion(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()

6. 模型部署与应用

6.1 保存训练好的模型

# 保存整个模型 torch.save(model, 'flower_resnet18.pth') # 或者只保存状态字典(推荐) torch.save(model.state_dict(), 'flower_resnet18_state_dict.pth')

6.2 创建简单的推理API

使用Flask创建一个简单的Web服务:

from flask import Flask, request, jsonify from PIL import Image import io app = Flask(__name__) # 加载模型 model = models.resnet18(pretrained=False) model.fc = nn.Linear(model.fc.in_features, 200) model.load_state_dict(torch.load('flower_resnet18_state_dict.pth')) model.eval() @app.route('/predict', methods=['POST']) def predict(): if 'file' not in request.files: return jsonify({'error': 'no file uploaded'}) file = request.files['file'] img_bytes = file.read() img = Image.open(io.BytesIO(img_bytes)) # 应用相同的变换 img = val_transform(img).unsqueeze(0) with torch.no_grad(): output = model(img) _, pred = torch.max(output, 1) return jsonify({'class_id': int(pred), 'class_name': class_names[int(pred)]}) if __name__ == '__main__': app.run(host='0.0.0.0', port=5000)

6.3 性能优化技巧

  1. TorchScript:将模型转换为TorchScript提高推理速度
  2. ONNX:导出为ONNX格式,兼容更多推理引擎
  3. 量化:使用8位整数量化减小模型大小,提高速度
# 量化示例 quantized_model = torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtype=torch.qint8 )

总结

通过本教程,你已经完成了从数据准备到模型部署的花卉识别全流程。以下是核心要点:

  • ResNet18是轻量级但强大的选择,特别适合200类花卉识别这样的多分类任务
  • 云端GPU资源可以轻松应对大数据集训练,避免笔记本性能不足的问题
  • 数据增强对花卉识别特别重要,能显著提高模型泛化能力
  • 模型评估不应只看整体准确率,还要分析各类别的表现
  • 部署时可以灵活选择方案,从简单API到优化后的推理引擎

现在你就可以在CSDN算力平台上尝试这个流程,使用预置的PyTorch镜像快速开始你的花卉识别项目。


💡获取更多AI镜像

想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/10 15:12:03

ResNet18持续集成:云端GitHub Actions自动化测试

ResNet18持续集成:云端GitHub Actions自动化测试 引言 在AI模型开发中,ResNet18作为经典的轻量级卷积神经网络,被广泛应用于图像分类、目标检测等任务。但对于团队协作开发来说,如何确保每次代码提交都能自动完成模型训练和测试…

作者头像 李华
网站建设 2026/4/2 5:34:16

对比传统开发:XPERT如何让字节跳动效率提升300%

快速体验 打开 InsCode(快马)平台 https://www.inscode.net输入框内输入如下内容: 创建一个效率对比工具:1) 传统方式:手动编写一个电商商品详情页(前端后端测试) 2) XPERT方式:通过自然语言描述生成相同…

作者头像 李华
网站建设 2026/4/4 18:41:54

ResNet18图像分类实战:云端GPU 10分钟出结果,2块钱玩转

ResNet18图像分类实战:云端GPU 10分钟出结果,2块钱玩转 1. 为什么设计师需要ResNet18? 作为一名设计师,你可能经常遇到这样的困扰:电脑里存了大量设计素材,却很难快速找到特定类型的图片。比如想找"…

作者头像 李华
网站建设 2026/4/11 18:30:32

福建云安全独角兽估值已近30亿,战略大调整疑冲刺港股IPO

福建云安全独角兽估值已近30亿,战略大调整疑冲刺港股IPO 中国网络安全行业正在经历一次新的周期变化。AI的全面渗透正在重塑安全体系的底层结构,云计算的普及让攻击面迅速扩大,传统防护方式已难以跟上复杂攻击的演进节奏。行业的下一轮竞争焦…

作者头像 李华
网站建设 2026/4/15 9:20:27

电商系统中的RPC实战:从秒杀到分布式事务

快速体验 打开 InsCode(快马)平台 https://www.inscode.net输入框内输入如下内容: 开发一个电商系统RPC调用模拟器,模拟秒杀场景下的高并发RPC调用。要求实现商品库存服务、订单服务和支付服务三个微服务,通过RPC进行通信。包含流量控制、熔…

作者头像 李华