PyTorch Model Parallelism Strategies: Data Parallelism vs. Model Parallelism
1. Technical Analysis
1.1 Comparison of Parallelism Strategies
| Strategy | Description | Best suited for | Communication overhead |
|---|---|---|---|
| Data parallelism | The data is split across multiple GPUs | Small model, large dataset | Low |
| Model parallelism | The model is split across multiple GPUs | Model too large for a single GPU | High |
| Hybrid parallelism | Combines data and model parallelism | Training very large models | Medium |
1.2 Data Parallelism Architecture
Data Parallelism:

```
┌─────────────┐   ┌─────────────┐   ┌─────────────┐
│    GPU 0    │   │    GPU 1    │   │    GPU 2    │
│   Model A   │   │   Model A   │   │   Model A   │
│   Batch 0   │   │   Batch 1   │   │   Batch 2   │
└──────┬──────┘   └──────┬──────┘   └──────┬──────┘
       │                 │                 │
       └─────────────────┼─────────────────┘
                         ▼
                  ┌─────────────┐
                  │  All-Reduce │
                  └─────────────┘
```
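The All-Reduce step in the diagram is what keeps the replicas synchronized: after each backward pass, every GPU's gradients are summed and averaged across all replicas. Below is a minimal sketch of that step using `torch.distributed`, assuming a process group has already been initialized as in section 2.4; DDP performs this automatically, the sketch only illustrates the communication pattern.

```python
import torch
import torch.distributed as dist


def average_gradients(model: torch.nn.Module):
    """Manually all-reduce and average gradients across all processes."""
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # Sum the gradient tensor across every process in the group...
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            # ...then divide so each replica ends up with the average.
            param.grad /= world_size
```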
1.3 Model Parallelism Architecture
Model Parallelism:

```
┌─────────────┐    ┌─────────────┐    ┌─────────────┐
│    GPU 0    │───▶│    GPU 1    │───▶│    GPU 2    │
│   Layer 1   │    │   Layer 2   │    │   Layer 3   │
│  Full Batch │    │  Full Batch │    │  Full Batch │
└─────────────┘    └─────────────┘    └─────────────┘
```
2. Core Implementations
2.1 Data Parallelism Implementation
```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class DataParallelModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3)
        # For 32x32 inputs, two conv+pool stages leave a 6x6 feature map.
        self.fc = nn.Linear(128 * 6 * 6, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        return self.fc(x)


def train_data_parallel():
    model = DataParallelModel()
    # nn.DataParallel replicates the model on every visible GPU and
    # scatters each batch across them, all within a single process.
    model = nn.DataParallel(model)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(10):
        inputs = torch.randn(64, 3, 32, 32).to(device)
        targets = torch.randint(0, 10, (64,)).to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()


class CustomDataParallel(nn.Module):
    """Simplified re-implementation of the replicate/scatter/gather idea."""

    def __init__(self, module):
        super().__init__()
        self.devices = list(range(torch.cuda.device_count()))
        # Each replica must be an independent copy; calling .to() on the same
        # module would just keep moving one instance between devices.
        self.replicas = nn.ModuleList(
            [copy.deepcopy(module).to(d) for d in self.devices]
        )

    def forward(self, x):
        chunks = torch.chunk(x, len(self.devices))
        outputs = []
        for replica, device, chunk in zip(self.replicas, self.devices, chunks):
            outputs.append(replica(chunk.to(device)))
        # Gather the partial outputs back onto the first device before concatenating.
        outputs = [out.to(self.devices[0]) for out in outputs]
        return torch.cat(outputs, dim=0)
```
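If only a subset of GPUs should be used, `nn.DataParallel` accepts an explicit `device_ids` list; outputs are gathered on `device_ids[0]` by default. A brief usage sketch (the two-GPU setup is only illustrative, and at least one CUDA device is assumed):

```python
model = DataParallelModel()
if torch.cuda.device_count() >= 2:
    # Replicate on GPUs 0 and 1 only; outputs are gathered on device_ids[0].
    model = nn.DataParallel(model, device_ids=[0, 1])
model = model.to('cuda:0')  # assumes at least one CUDA device

inputs = torch.randn(64, 3, 32, 32, device='cuda:0')
outputs = model(inputs)   # each replica processes a 32-sample slice of the batch
print(outputs.shape)      # torch.Size([64, 10])
```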
2.2 Model Parallelism Implementation
```python
class ModelParallelModel(nn.Module):
    """Layer-wise model parallelism: each stage lives on its own GPU."""

    def __init__(self):
        super().__init__()
        self.part1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ).to('cuda:0')
        self.part2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ).to('cuda:1')
        self.part3 = nn.Sequential(
            nn.Linear(128 * 6 * 6, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        ).to('cuda:2')

    def forward(self, x):
        # Activations move between GPUs via point-to-point copies.
        x = self.part1(x.to('cuda:0'))
        x = self.part2(x.to('cuda:1'))
        x = x.to('cuda:2')
        x = x.view(x.size(0), -1)
        return self.part3(x)


class PipelineParallelModel(nn.Module):
    """Naive pipelining: the batch is split into micro-batches that are
    pushed through the stages one after another (no stage overlap)."""

    def __init__(self, chunks=4):
        super().__init__()
        self.chunks = chunks
        self.stage1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(),
        ).to('cuda:0')
        self.stage2 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3),
            nn.ReLU(),
        ).to('cuda:1')
        self.stage3 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(128 * 6 * 6, 10),
        ).to('cuda:2')

    def forward(self, x):
        batch_size = x.size(0)
        chunk_size = batch_size // self.chunks
        outputs = []
        for i in range(self.chunks):
            start = i * chunk_size
            # The last micro-batch absorbs any remainder.
            end = start + chunk_size if i < self.chunks - 1 else batch_size
            chunk = self.stage1(x[start:end].to('cuda:0'))
            chunk = self.stage2(chunk.to('cuda:1'))
            chunk = self.stage3(chunk.to('cuda:2'))
            outputs.append(chunk)
        return torch.cat(outputs, dim=0)
```
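Training a model-parallel network looks like ordinary single-process training; the one catch is that the labels must live on the device that produces the output (`cuda:2` here), and autograd routes gradients back across the device boundaries on its own. A minimal sketch, assuming three GPUs are available:

```python
model = ModelParallelModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for step in range(10):
    inputs = torch.randn(64, 3, 32, 32)                      # forward() moves them to cuda:0
    targets = torch.randint(0, 10, (64,), device='cuda:2')   # same device as the output
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = F.cross_entropy(outputs, targets)
    loss.backward()   # gradients flow back across cuda:2 -> cuda:1 -> cuda:0
    optimizer.step()
```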
2.3 Hybrid Parallelism Implementation
```python
class HybridParallelModel(nn.Module):
    """Splits layers across two GPUs; can be combined with data parallelism across nodes."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 256, kernel_size=3),
            nn.ReLU(),
        ).to('cuda:0')
        self.layer2 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3),
            nn.ReLU(),
        ).to('cuda:1')
        self.layer3 = nn.Linear(512 * 28 * 28, 1000).to('cuda:0')
        self.layer4 = nn.Linear(1000, 10).to('cuda:1')

    def forward(self, x):
        x = self.layer1(x.to('cuda:0'))
        x = self.layer2(x.to('cuda:1'))
        x = x.to('cuda:0')
        x = x.view(x.size(0), -1)
        x = self.layer3(x)
        return self.layer4(x.to('cuda:1'))


class TensorParallelLinear(nn.Module):
    """Column-parallel linear layer: the output features are split across two GPUs."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.devices = ['cuda:0', 'cuda:1']
        half = out_features // 2
        self.weight1 = nn.Parameter(torch.randn(half, in_features, device=self.devices[0]))
        self.bias1 = nn.Parameter(torch.randn(half, device=self.devices[0]))
        self.weight2 = nn.Parameter(torch.randn(half, in_features, device=self.devices[1]))
        self.bias2 = nn.Parameter(torch.randn(half, device=self.devices[1]))

    def forward(self, x):
        # Each GPU computes its slice of the output features.
        y1 = F.linear(x.to(self.devices[0]), self.weight1, self.bias1)
        y2 = F.linear(x.to(self.devices[1]), self.weight2, self.bias2)
        # Both halves must be on the same device before concatenation.
        return torch.cat([y1, y2.to(self.devices[0])], dim=1)
```
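A quick usage sketch for `TensorParallelLinear`, assuming at least two GPUs; it simply confirms that the two output halves are concatenated back into the full feature dimension on `cuda:0`:

```python
# Requires at least two GPUs (cuda:0 and cuda:1).
layer = TensorParallelLinear(in_features=512, out_features=256)
x = torch.randn(8, 512)       # the input can start on the CPU
y = layer(x)
print(y.shape, y.device)      # torch.Size([8, 256]) cuda:0
```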
2.4 Distributed Data Parallel
```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def setup_distributed():
    # Launched with, for example: torchrun --nproc_per_node=4 train.py
    dist.init_process_group(backend='nccl')
    # On a single node the global rank equals the local GPU index.
    local_rank = dist.get_rank()
    torch.cuda.set_device(local_rank)
    return local_rank


def train_distributed():
    local_rank = setup_distributed()
    model = DataParallelModel().to(local_rank)
    # DDP runs one process per GPU and all-reduces gradients during backward().
    model = DDP(model, device_ids=[local_rank])

    # Synthetic dataset as a stand-in; replace with a real dataset.
    dataset = TensorDataset(
        torch.randn(1024, 3, 32, 32), torch.randint(0, 10, (1024,))
    )
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(10):
        # Reshuffle the shards differently on every epoch.
        sampler.set_epoch(epoch)
        for inputs, targets in loader:
            inputs = inputs.to(local_rank)
            targets = targets.to(local_rank)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()
```
3. Performance Comparison
3.1 Performance by Strategy
| Strategy | GPUs | Training speedup | Memory footprint | Communication overhead |
|---|---|---|---|---|
| Data parallelism | 4 | 3.5x | 4x | Low |
| Model parallelism | 4 | 2x | 0.25x | High |
| Hybrid parallelism | 4 | 3x | 0.5x | Medium |
| Single-GPU training | 1 | 1x | 1x | None |
3.2 DataParallel vs. DistributedDataParallel
| Feature | DataParallel | DDP |
|---|---|---|
| Launch model | Single process, multi-threaded | One process per GPU |
| Memory balance | Unbalanced (device 0 carries extra load) | Balanced |
| Communication efficiency | Medium | High |
| Scalability | Moderate | Excellent |
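The "one process per GPU" row is the key practical difference: DDP needs a separate Python process for each device. Besides `torchrun`, those processes can be spawned programmatically; a minimal sketch using `torch.multiprocessing` (the `worker` function, address, and port below are illustrative placeholders):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    # Each spawned process joins the process group with its own rank.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # ... build the model, wrap it in DDP, and train as in section 2.4 ...
    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    # Launch one process per GPU; the rank is passed as the first argument.
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
```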
3.3 Communication Overhead by Strategy
| Operation | Data parallelism | Model parallelism | Hybrid parallelism |
|---|---|---|---|
| All-Reduce | Yes | No | Yes |
| Point-to-point | No | Yes | Yes |
| Communication volume | Medium | High | High |
| Latency | Low | High | Medium |
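Point-to-point transfers appear in two forms: `tensor.to('cuda:1')` copies between GPUs owned by the same process (as in the model-parallel code above), while `torch.distributed` exposes explicit send/recv between processes. A small sketch of the latter, assuming a two-process group is already initialized and both ranks agree on the tensor shape (`exchange_activation` is just an illustrative helper):

```python
import torch
import torch.distributed as dist


def exchange_activation(shape, device):
    """Rank 0 sends an activation of an agreed-upon shape to rank 1."""
    rank = dist.get_rank()
    if rank == 0:
        activation = torch.randn(shape, device=device)  # stand-in payload
        dist.send(activation, dst=1)   # blocking point-to-point send
        return activation
    if rank == 1:
        buffer = torch.empty(shape, device=device)
        dist.recv(buffer, src=0)       # blocks until the tensor arrives
        return buffer
```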
4. Best Practices
4.1 Choosing a Parallelism Strategy
```python
class ParallelStrategySelector:
    """Heuristic strategy selection based on model size (e.g. parameter count)."""

    def __init__(self, model_size, batch_size, num_gpus):
        self.model_size = model_size
        self.batch_size = batch_size
        self.num_gpus = num_gpus

    def select_strategy(self):
        if self.model_size < 1e9:
            # Small enough for one GPU: replicate it if several are available.
            if self.num_gpus > 1:
                return 'data_parallel'
            return 'single_gpu'
        elif self.model_size < 1e10:
            return 'model_parallel'
        else:
            return 'hybrid_parallel'


class ParallelModelFactory:
    @staticmethod
    def create(model_class, strategy='data_parallel'):
        model = model_class()
        if strategy == 'data_parallel':
            model = nn.DataParallel(model)
        elif strategy == 'distributed':
            local_rank = dist.get_rank()
            model = DDP(model.to(local_rank), device_ids=[local_rank])
        elif strategy == 'model_parallel':
            # ModelParallelWrapper is a user-provided wrapper that places the
            # model's layers on different devices (not defined in this post).
            model = ModelParallelWrapper(model)
        return model
```
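A brief usage sketch of the two helpers above (the 50M-parameter figure is only an illustrative input):

```python
num_gpus = torch.cuda.device_count()
selector = ParallelStrategySelector(model_size=5e7, batch_size=64, num_gpus=num_gpus)
print(selector.select_strategy())   # 'data_parallel' with more than one GPU, else 'single_gpu'

model = ParallelModelFactory.create(DataParallelModel, strategy='data_parallel')
```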
4.2 Gradient Accumulation
```python
class GradientAccumulation:
    """Accumulates gradients over several micro-batches before stepping,
    which simulates a larger effective batch size."""

    def __init__(self, model, optimizer, accumulation_steps=4):
        self.model = model
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.step_count = 0

    def step(self, loss):
        # Scale the loss so the accumulated gradient matches a full batch.
        loss = loss / self.accumulation_steps
        loss.backward()
        self.step_count += 1
        if self.step_count % self.accumulation_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.step_count = 0
```
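A short usage sketch showing the helper inside a training loop (the model, random data, and 4-step accumulation are illustrative):

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DataParallelModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
accumulator = GradientAccumulation(model, optimizer, accumulation_steps=4)

for step in range(16):
    # Small micro-batches; the optimizer only updates on every 4th call.
    inputs = torch.randn(16, 3, 32, 32, device=device)
    targets = torch.randint(0, 10, (16,), device=device)
    loss = F.cross_entropy(model(inputs), targets)
    accumulator.step(loss)
```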
5. Summary
Choosing the right parallelism strategy in PyTorch is key to training large models:
- Data parallelism: best when the dataset is large and the model fits on a single GPU
- Model parallelism: best when the model is too large to fit on a single GPU
- Hybrid parallelism: best for training very large models
- Distributed training: best when training spans multiple nodes
The comparison figures above can be summarized as follows:
- Data parallelism achieves roughly a 3.5x speedup on 4 GPUs
- Model parallelism cuts per-GPU memory usage by about 75%
- DDP communicates more efficiently than DataParallel
- Hybrid parallelism performs best for very large models