Tensor Parallelism基础:模型切分原理
在大语言模型参数量突破千亿的今天,一个典型的LLM推理任务可能需要超过300GB显存——这几乎是8张NVIDIA A100的总和。面对这种现实挑战,单卡训练早已成为过去式。如何让模型“跨设备生长”,而不是被显存墙拦腰截断?答案正是张量并行(Tensor Parallelism)。
它不像数据并行那样简单复制模型,而是深入到矩阵乘法内部,把一个原本无法容纳的运算拆解成多个小块,在多张GPU上协同完成。这种细粒度的模型切分方式,是当前百亿、千亿级模型能够落地的核心支撑技术之一。
从矩阵乘法说起:张量并行的本质
Transformer架构中,最消耗显存和计算资源的操作集中在两个地方:注意力机制中的QKV投影和前馈网络(FFN)中的升维/降维层。这些操作本质上都是大规模矩阵乘法,例如 $ Y = X \cdot W $。
假设输入 $ X \in \mathbb{R}^{b \times d} $,权重 $ W \in \mathbb{R}^{d \times h} $,输出 $ Y \in \mathbb{R}^{b \times h} $。当 $ h $ 达到上万甚至数十万时,单个全连接层的参数就能轻松超过10GB。此时,我们不能再指望整块计算由一张卡独立完成。
张量并行的关键思路是:将这个矩阵$W$按维度切开,让每张卡只处理一部分,再通过通信合并结果。根据切分方向不同,主要分为两种模式:
列切分(Column-wise Splitting)
将 $ W $ 按列拆成 $ W_1, W_2 $,即每个设备负责输出的一部分维度:
- GPU0: $ Y_1 = X \cdot W_1 $
- GPU1: $ Y_2 = X \cdot W_2 $
由于输入 $ X $ 在所有设备上相同,局部输出 $ Y_i $ 只包含最终输出的部分维度。要得到完整 $ Y $,需要通过AllGather将各卡输出拼接起来:
$$ Y = [Y_1, Y_2] $$
这种方式常用于 FFN 层的升维操作(up-projection),比如将隐藏维度从 d 扩展到 4d。Megatron-LM 中称之为Column Linear。
行切分(Row-wise Splitting)
将 $ W $ 按行拆分,等价于对输入空间进行划分:
- GPU0: $ Y_1 = X \cdot W_1^T $
- GPU1: $ Y_2 = X \cdot W_2^T $
注意,这里的每一部分输出都已经是完整维度的向量。因此不需要拼接,而是通过AllReduce(SUM)直接求和即可:
$$ Y = Y_1 + Y_2 $$
这种方法避免了高维输出的传输与拼接,特别适合降维层或注意力输出合并后的线性变换。在 Megatron 中被称为Row Linear。
这两种策略往往组合使用,形成所谓的“2D张量并行”结构。例如在一个FFN层中:
- 升维用列并行→ 输出分散在各卡;
- 激活函数后,降维用行并行→ 各卡独立计算后再 AllReduce 求和。
这样既实现了参数切分,又控制了通信开销。
实现细节:不只是算子拆分
虽然听起来只是“切矩阵+通信”,但实际实现远比想象复杂。尤其是在反向传播阶段,梯度也需要正确地分布和同步。
以下是一个简化版的 PyTorch 实现示例,展示了基本通信逻辑:
import torch import torch.distributed as dist from torch.nn import functional as F def column_parallel_linear(x, weight_chunk, bias_chunk, rank, world_size): """ 列并行线性层:每个设备持有 weight 的一部分(按列) 输入 x 在所有设备上相同 输出需 AllGather 拼接 """ # 局部计算: x @ w_i partial_output = F.linear(x, weight_chunk, bias_chunk) # 收集所有设备的输出 [out0, out1] output_list = [torch.zeros_like(partial_output) for _ in range(world_size)] dist.all_gather(output_list, partial_output) # 拼接输出 full_output = torch.cat(output_list, dim=-1) return full_output def row_parallel_linear(x, weight_chunk, bias, rank, world_size): """ 行并行线性层:weight 按行切分 各设备独立计算,最后 AllReduce 求和 """ # 局部计算: x @ w_i^T partial_output = F.linear(x, weight_chunk) # 全局求和 dist.all_reduce(partial_output, op=dist.ReduceOp.SUM) # 加偏置(仅一次) if bias is not None: partial_output += bias return partial_output这里有几个关键点值得强调:
- AllGather vs AllReduce:前者用于拼接(保留所有数据副本),后者用于聚合(如梯度求和)。选择不当会导致显存爆炸或结果错误。
- 偏置项处理:在行并行中,bias 只应在最终求和后加一次;若每张卡都加,会造成重复叠加。
- 通信时机:前向传播结束后必须立即通信,否则后续层无法继续计算。
- 反向传播一致性:本地梯度计算完成后,仍需跨设备同步参数梯度,确保优化器更新一致。
在真实框架(如ms-swift或 Megatron-LM)中,这些逻辑都被封装进自定义Linear层中,用户无需手动管理通信流程。
融入Megatron:构建高效混合并行系统
单纯使用张量并行仍有局限。当设备数超过一定规模(如 >8卡),TP带来的通信开销会显著影响效率。为此,现代训练框架普遍采用三维并行策略—— 将 TP 与 Pipeline Parallelism(PP)、Data Parallelism(DP)结合使用。
以ms-swift框架为例,其底层支持完整的 Megatron 风格并行架构:
# 启动脚本示例:使用 8 卡进行张量并行训练 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 swift train \ --model_type llama3-8b-instruct \ --tensor_parallel_size 8 \ --pipeline_parallel_size 1 \ --data_parallel_size 1 \ --train_dataset alpaca-en \ --max_steps 1000 \ --lora_rank 8 \ --use_megatron true在这个配置中:
--tensor_parallel_size 8表示启用8路张量并行,每张卡承载1/8的模型权重;--use_megatron true触发内置的 Megatron 式算子重写与通信优化;- 即使启用了 LoRA 微调,原始模型仍可通过 TP 切分,兼顾显存节省与微调灵活性。
更进一步,对于超大规模模型(如70B以上),可以开启 PP 分层调度:
--tensor_parallel_size 4 \ --pipeline_parallel_size 2 \ --data_parallel_size 1这意味着整个集群被划分为两组,每组内部做4路TP,两组之间形成流水线阶段。Micro-batch 流水执行可大幅提升设备利用率。
此外,ms-swift还集成了一系列性能增强特性:
- Sequence Parallelism:将序列维度也进行切分,减少 TP 中的 AllReduce 数据量;
- Fused Kernels:使用融合内核(如 FusedMLP、FlashAttention)降低内存访问次数;
- Zero Redundancy Optimizer (ZeRO):结合 DeepSpeed 技术,进一步切分优化器状态和梯度,释放更多显存。
这些技术共同作用,使得千亿级模型可以在数百张GPU上稳定高效训练。
实际部署中的权衡与考量
尽管张量并行功能强大,但在真实场景中仍需谨慎设计。以下是几个常见工程问题及其应对策略:
显存 vs 通信带宽的平衡
TP 虽然降低了单卡显存占用,但引入了频繁的集合通信。如果设备间互联带宽不足(如仅靠 PCIe),通信将成为瓶颈。
建议:
- 使用 NVLink 或 InfiniBand 网络,至少保证设备间双向带宽 ≥150 GB/s;
- 控制 TP 规模:PCIe 环境下建议 ≤4,NVLink 环境下可扩展至8~16。
切分粒度的选择
并非所有模型都需要高倍率TP。盲目切分反而可能导致负载不均或通信开销过大。
| 模型规模 | 推荐并行策略 |
|---|---|
| <13B 参数 | 优先使用 DP + LoRA/Q-LoRA |
| 13B~70B | TP=4~8,视硬件而定 |
| >70B | 必须结合 TP+PP 混合并行 |
例如,在8*A100上微调 Llama3-8B,虽然单卡勉强能跑,但激活值和优化器状态会让显存吃紧。此时使用 TP=8 不仅提升稳定性,还能利用全部显存资源。
多模态模型的支持
随着图文音联合建模兴起,张量并行还需适配非Transformer模块(如视觉编码器、音频卷积层)。幸运的是,ms-swift已实现统一并行化接口,自动识别模型结构并注入切分逻辑,支持包括 CLIP、BLIP、Qwen-VL 等在内的300+多模态模型。
推理阶段的延续性
训练时用了TP,推理也不能掉链子。好在主流推理引擎如vLLM、LmDeploy、SGLang均支持张量并行推理。它们通过连续批处理(continuous batching)和 PagedAttention 技术,在保持低延迟的同时充分利用多卡算力。
更重要的是,量化也不缺席。即使将模型导出为 GPTQ 或 AWQ 格式,TP 推理依然可用,进一步压低部署成本。
架构视角下的协同生态
在ms-swift的整体架构中,张量并行并非孤立存在,而是嵌入在一个完整的分布式训练闭环中:
graph TD A[用户接口层] --> B[训练任务管理器] B --> C[分布式并行执行引擎] C --> D[模型加载与切分模块] D --> E[推理 & 评测 & 量化] subgraph "核心能力" C -->|TP + PP + DP + ZeRO| C D -->|自动识别结构并切分| D end style A fill:#f9f,stroke:#333 style E fill:#bbf,stroke:#333每一环都在为TP服务:
- 任务管理器解析用户命令,决定是否启用 Megatron 模式;
- 并行引擎调度 NCCL/HCCl 通信原语,协调多卡协作;
- 模型切分模块自动分析 Hugging Face 格式的模型结构,精准定位 QKV、FFN 等可切分层;
- 推理模块支持 OpenAI 兼容 API 输出,无缝对接应用系统。
这种端到端的整合能力,大大降低了开发者使用门槛。你不需要懂 CUDA kernel 如何融合,也不必手写通信逻辑,只需一条命令即可启动大规模训练。
写在最后:理解轮子,才能驾驭巨兽
张量并行从来不是一项炫技式的黑科技,它是面对物理限制时的必然选择。当我们谈论“百亿模型能否跑起来”时,本质是在问:“它的计算图能不能被合理切开,并在多设备间高效流转。”
掌握 Tensor Parallelism 的原理,意味着你能判断:
- 什么时候该用 TP,什么时候该用 DP?
- 为什么某些情况下 TP=8 反而比 TP=4 慢?
- 如何排查 AllReduce 死锁或显存溢出?
这些问题的答案,决定了你是被动调参的使用者,还是能主动优化系统的工程师。
正如ms-swift所倡导的:“站在巨人的肩上,走得更远。” 我们不必重复造轮子,但必须理解轮子如何转动——尤其是在驱动千卡集群的时候。