从单卡到多卡:PyTorch分布式训练的核心代码改造指南
当你第一次尝试将PyTorch训练脚本从单卡扩展到多卡时,可能会误以为只需要修改启动命令就万事大吉。然而,真正的挑战在于训练脚本内部的改造。本文将带你深入理解分布式数据并行(DDP)的核心原理,并逐步演示如何将一个典型的单卡训练脚本升级为支持多卡并行的工业级实现。
1. 分布式训练基础概念
在开始代码改造之前,我们需要明确几个关键概念。分布式数据并行(Distributed Data Parallel, DDP)是PyTorch提供的多GPU训练方案,它通过在多个GPU上复制模型,并将数据分片到不同GPU上并行处理来加速训练。
与单卡训练相比,DDP训练有几个显著不同点:
- 模型复制:每个GPU上都有一份完整的模型副本
- 数据分片:数据集被均匀分配到不同GPU上
- 梯度同步:每个GPU独立计算梯度后,通过All-Reduce操作同步梯度
# 单卡训练的基本结构 model = MyModel().to(device) optimizer = torch.optim.Adam(model.parameters()) for epoch in range(epochs): for batch in dataloader: inputs, labels = batch outputs = model(inputs.to(device)) loss = criterion(outputs, labels.to(device)) loss.backward() optimizer.step() optimizer.zero_grad()2. 核心代码改造点
2.1 初始化分布式环境
在DDP训练中,第一步是初始化进程组。这需要在训练脚本的最开始处完成,确保所有进程能够互相通信。
import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP def setup(rank, world_size): # 初始化进程组 dist.init_process_group( backend='nccl', # NVIDIA的通信后端,推荐用于GPU训练 init_method='env://', # 从环境变量获取初始化信息 rank=rank, world_size=world_size ) # 设置当前进程的默认GPU torch.cuda.set_device(rank)注意:
init_method也可以指定为TCP地址(如'tcp://127.0.0.1:1234'),但在torchrun中更推荐使用环境变量方式。
2.2 模型包装为DDP
单卡训练中,我们直接将模型放到GPU上即可。但在DDP中,需要将模型包装为DistributedDataParallel对象。
def prepare_model(model, rank): model = model.to(rank) ddp_model = DDP(model, device_ids=[rank]) return ddp_modelDDP包装器会自动处理:
- 模型参数在进程间的同步
- 前向传播时的数据分发
- 反向传播时的梯度聚合
2.3 数据加载器改造
普通的数据加载器会将完整数据集加载到单个进程中,而DDP需要每个进程只处理数据的一个子集。PyTorch提供了DistributedSampler来实现这一点。
from torch.utils.data.distributed import DistributedSampler def prepare_dataloader(dataset, batch_size, rank, world_size): sampler = DistributedSampler( dataset, num_replicas=world_size, rank=rank, shuffle=True ) loader = DataLoader( dataset, batch_size=batch_size, sampler=sampler, num_workers=4, pin_memory=True ) return loader关键参数说明:
| 参数 | 说明 | 推荐值 |
|---|---|---|
| num_replicas | 参与训练的进程总数 | world_size |
| rank | 当前进程的序号 | 0到world_size-1 |
| shuffle | 是否打乱数据顺序 | True(训练集)/False(验证集) |
2.4 训练循环调整
在DDP训练中,每个epoch开始时需要调用sampler的set_epoch方法,确保每个epoch的数据划分不同。
def train(ddp_model, train_loader, optimizer, criterion, epoch, rank): ddp_model.train() train_loader.sampler.set_epoch(epoch) # 重要! for batch_idx, (inputs, labels) in enumerate(train_loader): inputs = inputs.to(rank, non_blocking=True) labels = labels.to(rank, non_blocking=True) optimizer.zero_grad() outputs = ddp_model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step()2.5 模型保存与加载
在多卡训练中,我们需要避免每个进程都保存一次模型。通常只在rank 0进程上保存即可。
def save_checkpoint(model, optimizer, epoch, filename, rank): if rank == 0: # 只在主进程保存 checkpoint = { 'model_state_dict': model.module.state_dict(), # 注意.module 'optimizer_state_dict': optimizer.state_dict(), 'epoch': epoch } torch.save(checkpoint, filename)加载检查点时,需要先加载到适当的设备上:
def load_checkpoint(filename, model, optimizer, rank): map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} # 将rank 0的参数映射到当前rank checkpoint = torch.load(filename, map_location=map_location) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) return checkpoint['epoch']3. 完整代码对比
让我们看一个完整的单卡与多卡训练脚本的对比。假设我们有一个简单的图像分类任务。
3.1 单卡训练脚本
# train_single_gpu.py import torch from torch.utils.data import DataLoader from torchvision import datasets, transforms # 1. 准备数据 transform = transforms.Compose([...]) train_dataset = datasets.ImageFolder('data/train', transform=transform) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) # 2. 定义模型 model = MyModel().cuda() optimizer = torch.optim.Adam(model.parameters()) criterion = torch.nn.CrossEntropyLoss() # 3. 训练循环 for epoch in range(100): model.train() for inputs, labels in train_loader: inputs, labels = inputs.cuda(), labels.cuda() optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # 保存模型 torch.save(model.state_dict(), f'checkpoint_{epoch}.pth')3.2 多卡训练脚本
# train_multi_gpu.py import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data.distributed import DistributedSampler from torchvision import datasets, transforms def main(rank, world_size): # 1. 初始化分布式环境 setup(rank, world_size) # 2. 准备数据 transform = transforms.Compose([...]) train_dataset = datasets.ImageFolder('data/train', transform=transform) train_loader = prepare_dataloader(train_dataset, 32, rank, world_size) # 3. 定义模型 model = MyModel() ddp_model = prepare_model(model, rank) optimizer = torch.optim.Adam(ddp_model.parameters()) criterion = torch.nn.CrossEntropyLoss() # 4. 训练循环 for epoch in range(100): train(ddp_model, train_loader, optimizer, criterion, epoch, rank) save_checkpoint(ddp_model, optimizer, epoch, f'checkpoint_{epoch}.pth', rank) # 5. 清理 dist.destroy_process_group() if __name__ == '__main__': import os rank = int(os.environ['RANK']) world_size = int(os.environ['WORLD_SIZE']) main(rank, world_size)4. 常见问题与调试技巧
4.1 内存不足问题
当使用多卡时,每个GPU上的batch size会减小,但总的内存消耗会增加。常见的内存问题包括:
- CUDA out of memory:尝试减小每个GPU上的batch size
- CPU内存不足:减少DataLoader的num_workers数量
4.2 性能优化技巧
- 使用pin_memory:在DataLoader中设置
pin_memory=True可以加速CPU到GPU的数据传输 - 重叠计算与通信:DDP默认会重叠反向传播和梯度同步,无需额外设置
- 梯度累积:当GPU内存有限时,可以通过多次小batch的前后向传播累积梯度后再更新参数
accum_steps = 4 # 累积4个batch的梯度 optimizer.zero_grad() for i, (inputs, labels) in enumerate(train_loader): outputs = model(inputs) loss = criterion(outputs, labels) / accum_steps # 注意除以累积步数 loss.backward() if (i + 1) % accum_steps == 0: optimizer.step() optimizer.zero_grad()4.3 调试分布式训练
调试DDP训练可能比较困难,因为错误可能只出现在特定rank上。一些有用的技巧:
- 限制rank 0输出:使用
if rank == 0:包装print语句 - 同步调试:在关键位置添加
dist.barrier()确保所有进程同步 - 单进程调试:可以先在单进程模式下测试脚本是否正确
# 临时禁用DDP进行调试 ddp_model = model if world_size == 1 else DDP(model, device_ids=[rank])5. 进阶话题
5.1 混合精度训练
结合DDP和混合精度训练可以进一步提升训练速度和减少内存占用。
from torch.cuda.amp import GradScaler, autocast scaler = GradScaler() for inputs, labels in train_loader: inputs, labels = inputs.to(rank), labels.to(rank) optimizer.zero_grad() with autocast(): outputs = model(inputs) loss = criterion(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()5.2 梯度裁剪
在分布式训练中,梯度裁剪需要在All-Reduce之后进行。
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)5.3 自定义分布式操作
有时你可能需要自定义跨进程的操作,PyTorch提供了多种集体通信原语:
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)dist.broadcast(tensor, src)dist.all_gather(tensor_list, tensor)
例如,计算所有进程上的平均损失:
def reduce_tensor(tensor, world_size): rt = tensor.clone() dist.all_reduce(rt, op=dist.ReduceOp.SUM) rt /= world_size return rt loss = reduce_tensor(loss, world_size)在实际项目中,我经常发现DDP训练初期最容易出错的地方是数据划分和模型保存。特别是当数据集不能被world_size整除时,DistributedSampler的行为需要特别注意。另一个常见陷阱是忘记从DDP模型中提取原始模型(.module)进行保存或评估。