ResNet18持续学习:云端保存进度,随时继续训练
引言
作为一名业余AI爱好者,你是否遇到过这样的困扰:每天只有1-2小时可以投入模型训练,但每次中断后又要从头开始?ResNet18作为经典的图像分类模型,训练过程往往需要数小时甚至更长时间。本文将介绍如何利用云端GPU资源和PyTorch的模型保存机制,实现训练进度的随时保存与恢复,让你的每一分钟训练时间都不浪费。
ResNet18是残差网络(Residual Network)的18层版本,通过引入"跳跃连接"解决了深层网络训练中的梯度消失问题。它特别适合中小型图像分类任务,如识别猫狗、花卉或医学影像等。通过本文,你将学会:
- 如何设置云端训练环境
- 保存和加载模型检查点的正确方法
- 从任意中断点恢复训练的完整流程
- 优化训练效率的实用技巧
1. 环境准备与镜像选择
1.1 选择预置镜像
在CSDN星图镜像广场中,搜索包含PyTorch和CUDA环境的镜像。推荐选择以下配置的基础镜像:
- PyTorch 1.12+版本
- CUDA 11.3及以上
- 预装常用库(torchvision, numpy等)
这类镜像通常已经配置好了GPU驱动和深度学习框架,省去了繁琐的环境搭建过程。
1.2 数据准备
训练ResNet18需要一个图像分类数据集。以下是几个常用选项:
- CIFAR-10:10类物体的小尺寸图像(32x32),适合快速验证
- 自定义数据集:按类别组织文件夹结构,如:
dataset/ ├── train/ │ ├── class1/ │ ├── class2/ │ └── ... └── val/ ├── class1/ ├── class2/ └── ...
2. 模型训练与进度保存
2.1 基础训练代码
以下是使用PyTorch训练ResNet18的基本框架:
import torch import torch.nn as nn import torch.optim as optim from torchvision import models, transforms, datasets # 初始化模型 model = models.resnet18(pretrained=True) num_classes = 10 # 根据你的数据集调整 model.fc = nn.Linear(model.fc.in_features, num_classes) # 数据预处理 transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) # 加载数据集 train_dataset = datasets.ImageFolder('dataset/train', transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)2.2 添加检查点保存功能
关键是在训练循环中加入模型保存逻辑。以下是修改后的训练代码片段:
def train_model(model, train_loader, criterion, optimizer, num_epochs=10): for epoch in range(num_epochs): model.train() running_loss = 0.0 for i, (inputs, labels) in enumerate(train_loader): # 前向传播 outputs = model(inputs) loss = criterion(outputs, labels) # 反向传播和优化 optimizer.zero_grad() loss.backward() optimizer.step() running_loss += loss.item() # 每100个batch保存一次检查点 if i % 100 == 99: checkpoint = { 'epoch': epoch, 'batch': i, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': running_loss / 100 } torch.save(checkpoint, 'checkpoint.pth') print(f'Checkpoint saved at epoch {epoch}, batch {i}') print(f'Epoch {epoch} completed') # 训练完成后保存最终模型 torch.save(model.state_dict(), 'final_model.pth')3. 从检查点恢复训练
3.1 加载保存的进度
当需要继续训练时,可以使用以下代码从检查点恢复:
def load_checkpoint(model, optimizer, checkpoint_path): checkpoint = torch.load(checkpoint_path) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] batch = checkpoint['batch'] loss = checkpoint['loss'] return model, optimizer, epoch, batch, loss # 初始化模型和优化器(必须与保存时相同) model = models.resnet18(pretrained=False) model.fc = nn.Linear(model.fc.in_features, num_classes) optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # 加载检查点 model, optimizer, start_epoch, start_batch, _ = load_checkpoint( model, optimizer, 'checkpoint.pth' ) # 继续训练 for epoch in range(start_epoch, num_epochs): model.train() # 如果是恢复的epoch,跳过已经处理过的batch if epoch == start_epoch: for i, (inputs, labels) in enumerate(train_loader): if i <= start_batch: continue # 正常训练流程... else: # 正常训练流程...3.2 云端训练的最佳实践
- 定时保存:除了手动保存,可以设置定时保存(如每小时自动保存)
- 版本控制:保存多个检查点,避免单一文件损坏
python torch.save(checkpoint, f'checkpoint_epoch{epoch}_batch{i}.pth') - 云存储同步:将检查点定期备份到云存储(如CSDN提供的存储服务)
- 资源监控:训练前检查GPU内存,避免因内存不足中断
4. 常见问题与优化技巧
4.1 常见问题解决
- 检查点加载失败:确保模型结构和优化器与保存时完全一致
GPU内存不足:减小batch size或使用梯度累积 ```python # 梯度累积示例 accumulation_steps = 4 for i, (inputs, labels) in enumerate(train_loader): outputs = model(inputs) loss = criterion(outputs, labels) loss = loss / accumulation_steps # 平均损失 loss.backward()
if (i+1) % accumulation_steps == 0: optimizer.step() optimizer.zero_grad() ``` -训练不稳定:尝试降低学习率或使用学习率调度器
4.2 效率优化技巧
- 混合精度训练:减少显存占用,加快训练速度 ```python from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler() for inputs, labels in train_loader: optimizer.zero_grad()
with autocast(): outputs = model(inputs) loss = criterion(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()2. **数据加载优化**:使用多线程数据加载python train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)3. **学习率调整**:使用余弦退火等策略python scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=num_epochs ) ```
总结
通过本文的实践方案,即使是每天只有1-2小时训练时间的业余爱好者,也能高效完成ResNet18模型的训练:
- 云端环境:利用预置镜像快速搭建训练环境,省去配置时间
- 检查点机制:通过保存模型状态和优化器状态,实现训练进度随时保存与恢复
- 高效训练:采用混合精度、梯度累积等技术最大化有限训练时间的价值
- 灵活中断:不再担心训练中断,可以随时暂停和继续
现在你就可以尝试在自己的项目中实现这一方案,充分利用碎片化时间进行模型训练。
💡获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。