3个ResNet18实战项目:从入门到精通
引言
对于想要转行AI领域的朋友来说,最头疼的问题莫过于"没有实际项目经验"。而ResNet18作为计算机视觉领域的经典模型,是构建AI项目经验的绝佳起点。但很多初学者都会遇到一个现实问题:本地电脑配置太低,完全跑不动深度学习模型。
别担心,本文将带你用ResNet18完成3个实用项目,从图像分类到目标检测,再到工业质检应用。即使你只有一台普通笔记本,也能通过云GPU资源快速上手实践。每个项目都包含完整代码和详细说明,确保你能真正掌握ResNet18的应用技巧。
1. 项目一:花卉分类系统(入门级)
1.1 项目简介
花卉分类是计算机视觉最经典的入门项目。我们将使用ResNet18在Oxford 102 Flowers数据集上训练一个能识别102种花卉的分类器。这个项目特别适合初学者,因为:
- 数据集规范且易于获取
- 分类任务直观易懂
- ResNet18在这个规模的数据集上表现优异
1.2 环境准备
首先我们需要准备GPU环境。如果你本地没有显卡,可以使用CSDN星图镜像广场提供的PyTorch预置镜像,它已经配置好了CUDA和必要的深度学习库。
# 基础环境安装(如果使用云镜像可跳过) conda create -n flower python=3.8 conda activate flower pip install torch torchvision torchaudio1.3 数据准备
下载并解压Oxford 102 Flowers数据集:
import torchvision.datasets as datasets from torchvision import transforms # 数据预处理 transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 加载数据集 train_data = datasets.Flowers102(root='./data', split='train', download=True, transform=transform) val_data = datasets.Flowers102(root='./data', split='val', transform=transform) test_data = datasets.Flowers102(root='./data', split='test', transform=transform)1.4 模型训练
使用预训练的ResNet18进行微调:
import torch.nn as nn import torch.optim as optim from torchvision import models # 加载预训练模型 model = models.resnet18(pretrained=True) # 修改最后一层全连接层 num_ftrs = model.fc.in_features model.fc = nn.Linear(num_ftrs, 102) # 102个花卉类别 # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # 训练循环(简化版) for epoch in range(10): model.train() for inputs, labels in train_loader: optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step()1.5 效果评估
训练完成后,我们可以在测试集上评估模型性能:
correct = 0 total = 0 model.eval() with torch.no_grad(): for inputs, labels in test_loader: outputs = model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(f'Accuracy: {100 * correct / total}%')典型情况下,经过10个epoch的训练,模型在测试集上的准确率可以达到85%左右。
2. 项目二:实时目标检测系统(进阶级)
2.1 项目简介
在这个项目中,我们将ResNet18作为Backbone,构建一个轻量级的实时目标检测系统。与单纯的分类不同,目标检测需要同时识别物体类别和位置,更具挑战性也更有实用价值。
2.2 模型架构
我们采用Faster R-CNN框架,但用ResNet18替换原本的ResNet50/101,以提升推理速度:
from torchvision.models.detection import fasterrcnn_resnet50_fpn from torchvision.models.detection.faster_rcnn import FastRCNNPredictor # 加载预训练模型 model = fasterrcnn_resnet50_fpn(pretrained=True) # 替换Backbone为ResNet18 backbone = models.resnet18(pretrained=True) backbone.out_channels = 512 # ResNet18最后一层的输出通道数 # 修改分类头 num_classes = 2 # 假设我们只检测人和车 in_features = model.roi_heads.box_predictor.cls_score.in_features model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)2.3 数据准备
使用COCO或VOC格式的数据集,这里以PASCAL VOC为例:
from torchvision.datasets import VOCDetection # 自定义转换函数 def transform_fn(target): # 将VOC标注格式转换为模型需要的格式 boxes = [] labels = [] for obj in target['annotation']['object']: bbox = obj['bndbox'] boxes.append([float(bbox['xmin']), float(bbox['ymin']), float(bbox['xmax']), float(bbox['ymax'])]) labels.append(1 if obj['name'] == 'person' else 2) # 假设1是人,2是车 return {'boxes': torch.as_tensor(boxes, dtype=torch.float32), 'labels': torch.as_tensor(labels, dtype=torch.int64)} # 加载数据集 train_dataset = VOCDetection(root='./data', year='2012', image_set='train', transforms=transform_fn, download=True)2.4 实时检测实现
使用OpenCV实现摄像头实时检测:
import cv2 import numpy as np # 加载训练好的模型 model.eval() # 初始化摄像头 cap = cv2.VideoCapture(0) while True: ret, frame = cap.read() if not ret: break # 预处理 frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) image_tensor = transform(frame_rgb).unsqueeze(0) # 推理 with torch.no_grad(): predictions = model(image_tensor) # 绘制检测框 for box, label in zip(predictions[0]['boxes'], predictions[0]['labels']): x1, y1, x2, y2 = map(int, box) cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2) cv2.putText(frame, f'{"person" if label==1 else "car"}', (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0,255,0), 2) cv2.imshow('Real-time Detection', frame) if cv2.waitKey(1) & 0xFF == ord('q'): break cap.release() cv2.destroyAllWindows()2.5 性能优化技巧
- 输入尺寸调整:减小输入图像尺寸可以显著提升FPS
- 半精度推理:使用FP16可以加速计算且几乎不影响精度
- ONNX转换:将模型导出为ONNX格式并使用ONNX Runtime推理
3. 项目三:工业缺陷检测系统(专业级)
3.1 项目简介
工业质检是计算机视觉的重要应用场景。我们将使用ResNet18构建一个表面缺陷检测系统,适用于电子产品、纺织品等工业场景。
3.2 数据准备
使用MVTec AD数据集或自建数据集:
from torch.utils.data import Dataset import os from PIL import Image class DefectDataset(Dataset): def __init__(self, root_dir, transform=None): self.root_dir = root_dir self.transform = transform self.images = [] self.labels = [] # 假设目录结构:root_dir/class_name/*.jpg for class_name in os.listdir(root_dir): class_dir = os.path.join(root_dir, class_name) if os.path.isdir(class_dir): label = 0 if class_name == 'good' else 1 # 0表示正常,1表示缺陷 for img_name in os.listdir(class_dir): self.images.append(os.path.join(class_dir, img_name)) self.labels.append(label) def __len__(self): return len(self.images) def __getitem__(self, idx): image = Image.open(self.images[idx]).convert('RGB') label = self.labels[idx] if self.transform: image = self.transform(image) return image, label3.3 模型改进
针对缺陷检测任务,我们对ResNet18进行改进:
class DefectDetector(nn.Module): def __init__(self): super(DefectDetector, self).__init__() self.backbone = models.resnet18(pretrained=True) # 移除最后的全连接层 self.features = nn.Sequential(*list(self.backbone.children())[:-2]) # 添加自定义头 self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.classifier = nn.Sequential( nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.5), nn.Linear(256, 1), nn.Sigmoid() ) def forward(self, x): x = self.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.classifier(x) return x3.4 训练技巧
数据增强:针对工业图像特点使用特定增强
python train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.RandomRotation(10), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ])损失函数选择:使用Focal Loss处理类别不平衡 ```python class FocalLoss(nn.Module): definit(self, alpha=0.25, gamma=2): super(FocalLoss, self).init() self.alpha = alpha self.gamma = gamma
def forward(self, inputs, targets): BCE_loss = nn.BCELoss(reduction='none')(inputs, targets) pt = torch.exp(-BCE_loss) F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss return F_loss.mean() ```
3.5 部署优化
模型量化:减小模型大小,提升推理速度
python quantized_model = torch.quantization.quantize_dynamic( model, {nn.Linear}, dtype=torch.qint8 )TensorRT加速:将模型转换为TensorRT引擎
python # 需要安装torch2trt from torch2trt import torch2trt model_trt = torch2trt(model, [dummy_input], fp16_mode=True)
4. 总结
通过这三个项目,你已经掌握了ResNet18在不同场景下的应用方法:
- 花卉分类项目:最基础的图像分类任务,适合理解ResNet18的基本用法
- 目标检测系统:展示了如何将ResNet18作为Backbone构建更复杂的视觉系统
- 工业缺陷检测:演示了针对特定任务的模型改进和优化技巧
关键收获:
- ResNet18虽然结构简单,但通过合理的微调和改进,可以胜任多种计算机视觉任务
- 对于计算资源有限的场景,ResNet18是平衡性能和效率的绝佳选择
- 模型优化技巧(如量化、剪枝)可以显著提升推理速度,使模型更适合工业部署
- 数据预处理和增强对模型性能的影响不亚于模型架构本身
现在你就可以选择一个最感兴趣的项目开始实践了!记住,在AI领域,动手实践比理论学习更重要。
💡获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。