专家并行
核心思想:专家并行 = 专家分布 + 动态路由 + All2All 通信
1 专家并行概述
专家并行的目标是将一个 MoE 层中的众多专家分布到不同的设备上,每个设备负责一部分专家。如果某个设备上的计算需要其他设备的专家,可以通过All2All通信实现。
专家并行的思想来自于论文:GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding,如下图所示。

具体来说,MoE模型通常使用 Gating 模块来决定每个输入数据样本应该由哪些专家来处理。假设有一个输入数据样本位于设备 A 上,而 Gating 模块决定该样本应该由设备 B 和设备 C 上的专家来处理,那么就需要将该数据样本从设备 A 传输到设备 B 和设备 C。
2 All2All通信
All2All是分布式计算、并行计算和高性能计算中的一种核心通信模式。
在一个由 N 个节点组成的群体中,每一个节点都需要向其他 所有 N-1 个节点发送一份不同的数据,同时也需要从其他 所有 N-1 个节点接收一份不同的数据。
即“所有人给所有人发消息”。如下图所示,看起来很像一个矩阵的转置操作。

1.1 标准 All2All
标准All2All即均匀发送和接受数据,发送到和接受自不同设备的数据量相同。
可以使用 torch.distributed
实现。
import torch
import torch.distributed as dist
def run_standard_all2all():
rank = dist.get_rank() # rank代表进程
size = dist.get_world_size() # size代表分布式组中参与进程的总数
device = torch.device(f"cuda:{rank}")
# 在每个rank上创建对应的input_tensor
input_tensor = torch.ones(size, dtype=torch.int32, device=device) * rank
print(f"Rank {rank} before all2all, input_tensor: {input_tensor.tolist()}")
output_tensor = torch.empty(size, dtype=torch.int32, device=device)
dist.all_to_all_single(output_tensor, input_tensor)
print(f"Rank {rank} after all2all, output_tensor: {output_tensor.tolist()}")
def main():
dist.init_process_group(backend='nccl') # 初始化分布式环境
run_standard_all2all()
dist.destroy_process_group() # 销毁分布式环境
if __name__ == "__main__":
main()
假设文件名为 test.py
,启动命令如下:
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=127.0.0.1 --master_port=29500 test.py
输出结果如下:
Rank 0 before all2all, input_tensor: [0, 0, 0, 0]
Rank 2 before all2all, input_tensor: [2, 2, 2, 2]
Rank 1 before all2all, input_tensor: [1, 1, 1, 1]
Rank 3 before all2all, input_tensor: [3, 3, 3, 3]
Rank 1 after all2all, output_tensor: [0, 1, 2, 3]
Rank 2 after all2all, output_tensor: [0, 1, 2, 3]
Rank 3 after all2all, output_tensor: [0, 1, 2, 3]
Rank 0 after all2all, output_tensor: [0, 1, 2, 3]
1.2 非标准 All2All
实际上有些场景并非均匀发送和接收,有可能发送到不同设备的数据量不同,从不同设备接收的数据量也可能不同。Pytorch 的 torch.distributed.all_to_all_single
提供了 input_split_sizes
和 output_split_sizes
参数来支持:
input_split_sizes
表示向每个设备发送的数据量。output_split_sizes
表示从每个设备接收的数据量。
假设有4个GPU,每个GPU包含10个数据:
4 个 GPU 都向 GPU k 发送 k+1 个数据
即,都向 GPU 0 发送 1 条数据,向 GPU 3 发送 4 条数据
GPU k 从其余每个 GPU 都接收 k+1 个数据
即,GPU 0 从其余每个 GPU 接收 1 条数据,共接收 3 条
如下图所示:

代码实现如下:
import torch
import torch.distributed as dist
def run_nonstandard_all2all():
rank = dist.get_rank()
size = dist.get_world_size()
device = torch.device(f"cuda:{rank}")
input_splits = [i+1 for i in range(size)]
input_tensor = torch.ones(sum(input_splits), dtype=torch.int32, device=device) * rank
print(f"Rank {rank} before all2all, input_tensor: {input_tensor.tolist()}")
output_splits = [rank + 1] * size
output_tensor = torch.empty(sum(output_splits), dtype=torch.int32, device=device)
dist.all_to_all_single(output_tensor, input_tensor, output_splits, input_splits)
print(f"Rank {rank} after all2all, output_tensor: {output_tensor.tolist()}")
def main():
dist.init_process_group(backend='nccl') # 初始化分布式环境
run_nonstandard_all2all()
dist.destroy_process_group() # 销毁分布式环境
if __name__ == "__main__":
main()
输出如下:
Rank 3 before all2all, input_tensor: [3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
Rank 1 before all2all, input_tensor: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Rank 2 before all2all, input_tensor: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
Rank 0 before all2all, input_tensor: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Rank 2 after all2all, output_tensor: [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
Rank 3 after all2all, output_tensor: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
Rank 0 after all2all, output_tensor: [0, 1, 2, 3]
Rank 1 after all2all, output_tensor: [0, 0, 1, 1, 2, 2, 3, 3]
1.3 两次 All2All
上述非标准 All2All 中有个问题:有些时候当前设备只知道要向其他设备发送多少数据,而并不知道需要从其他设备接收多少数据。
这个问题可以通过 2 次 all2all 来解决:
第一次 all2all 交换要传输的数据量信息,这是一个标准的 all2all 操作。
第二次 all2all 根据上述获取的数据量信息来执行真正的数据传输,此时是一个非标准 all2all 操作。
代码如下:
def run_all2al_twice():
rank = dist.get_rank()
size = dist.get_world_size()
device = torch.device(f"cuda:{rank}")
# 第一次 all_to_all
input_splits = [i + 1 for i in range(size)]
input_tensor = torch.ones(sum(input_splits), dtype=torch.int32, device=device) * rank
print(f"Rank {rank} before first all2all, input_tensor: {input_tensor.tolist()}")
input_splits_pt = torch.tensor(input_splits, dtype=torch.int32, device=device)
output_splits_pt = torch.empty(size, dtype=torch.int32, device=device)
dist.all_to_all_single(output_splits_pt, input_splits_pt)
output_splits = output_splits_pt.tolist()
print(f"Rank {rank} after first all2all, output_splits: {output_splits}")
output_tensor = torch.empty(sum(output_splits), dtype=torch.int32, device=device)
dist.all_to_all_single(output_tensor, input_tensor, output_splits, input_splits)
print(f"Rank {rank} after first all2all, output_tensor: {output_tensor.tolist()}")
def main():
dist.init_process_group(backend='nccl') # 初始化分布式环境
run_all2al_twice()
dist.destroy_process_group() # 销毁分布式环境
if __name__ == "__main__":
main()
输出如下:
Rank 3 before first all2all, input_tensor: [3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
Rank 1 before first all2all, input_tensor: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Rank 0 before first all2all, input_tensor: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Rank 2 before first all2all, input_tensor: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
Rank 0 after first all2all, output_splits: [1, 1, 1, 1]
Rank 1 after first all2all, output_splits: [2, 2, 2, 2]
Rank 3 after first all2all, output_splits: [4, 4, 4, 4]
Rank 2 after first all2all, output_splits: [3, 3, 3, 3]
Rank 3 after first all2all, output_tensor: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
Rank 0 after first all2all, output_tensor: [0, 1, 2, 3]
Rank 1 after first all2all, output_tensor: [0, 0, 1, 1, 2, 2, 3, 3]
Rank 2 after first all2all, output_tensor: [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
3 专家并行分片示例
参考DeepSeek的PEFT相关工作: ESFT/train_ep.py at main · deepseek-ai/ESFT
并行组划分:
def init_parallel_groups(ep_size=1):
dist.init_process_group("nccl")
world_size = int(os.getenv("WORLD_SIZE", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)
ep_group = edp_group = None
for i in range(0, world_size, ep_size):
ranks = list(range(i, i + ep_size))
group = dist.new_group(ranks)
if local_rank in ranks:
ep_group = group
edp_group = None
for i in range(ep_size):
ranks = list(range(i, world_size, ep_size))
group = dist.new_group(ranks)
if local_rank in ranks:
edp_group = group
dist.all_reduce(torch.zeros(1, device="cuda"), group=ep_group)
dist.all_reduce(torch.zeros(1, device="cuda"), group=edp_group)
return world_size, local_rank, ep_group, edp_group
world_size
:全局 GPU 总数(所有节点)local_rank
:当前 GPU 在节点内的本地编号(0~N-1)ep_group
:专家并行组(Expert Parallelism Group)edp_group
:专家数据并行组(Expert Data Parallelism Group)
通信组划分逻辑
假设有 8 个 GPU(world_size=8),ep_size=2(每个专家组包含 2 个 GPU):
专家并行组(ep_group)划分
for i in range(0, world_size, ep_size): ranks = list(range(i, i + ep_size))
组0:GPU [0, 1] → 共同处理专家A
组1:GPU [2, 3] → 共同处理专家B
组2:GPU [4, 5] → 共同处理专家C
组3:GPU [6, 7] → 共同处理专家D
专家数据并行组(edp_group)划分
for i in range(ep_size): ranks = list(range(i, world_size, ep_size))
组0:GPU [0, 2, 4, 6] → 专家A的不同数据分片
组1:GPU [1, 3, 5, 7] → 专家B的不同数据分片
通信验证
dist.all_reduce(torch.zeros(1), group=ep_group) dist.all_reduce(torch.zeros(1), group=edp_group)
测试两个通信组是否正常工作
确保所有 GPU 都能在各自组内通信
4 发展
时间
研究团队
技术名称
主要改进点
适用场景
参考文献
参考
Last updated