往期相关文章:
vllm prefill 和 decode 的kernel代码解读
sglang 在 12月6日终于合入了ep功能,pr 连接:MoE Expert Parallel,可以拿来学习一下。
开启了一个 enable_ep_moe
的开关,在模型中通过判断是否开启来确定使用什么moe实现
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
具体 ep 实现是在 python/sglang/srt/layers/moe/ep_moe/layer.py,看这代码不知道为啥还有点激动。
EPMoE的核心是将专家分布在不同的 tensor parallel rank上,然后在 token 推理时每个rank 会过滤出属于自己负责的 expert 的 token,处理完之后各rank通过 allreduce 拿到全量的部分。
1、通信部分
ep 前,如果 attention 做了 dp,就会做 allgather 让每个rank把数据都拿到;
代码位于 python/sglang/srt/models/deepseek_v2.py
# Fully Connected
if self.enable_dp_attention:
hidden_states, start_idx, end_idx = all_gather(
hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states[start_idx:end_idx]
else:
hidden_states = self.mlp(hidden_states)
ep 每个rank会保留属于自己负责的expert 的数据;
ep 后每个rank会通过 allreduce 来拿到所有数据的处理结果
代码位于:python/sglang/srt/models/deepseek_v2.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# ...
final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits)
if self.tp_size > 1:
# MoE输出需要在不同GPU间做all_reduce
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
核心的处理在 python/sglang/srt/layers/moe/ep_moe/layer.py
2、expert 分配和数据划分
EPMoE将experts平均分配到不同的GPU上
代码位于 python/sglang/srt/layers/moe/ep_moe/layer.py
self.num_experts = num_experts
assert self.num_experts % self.tp_size == 0 # 确保experts可以平均分配
self.num_experts_per_partition = self.num_experts // self.tp_size
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
首先通过router_logits计算每个token应该被哪些experts处理:
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
...
)
对expert ids排序,构建数据结构:
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, self.num_experts)
生成了三个重要的数据结构:
- reorder_topk_ids: 排序后的expert ids
- src2dst: 原始位置到重排序后位置的映射
- seg_indptr: 每个expert处理的token范围的索引
这里跟稀疏架构的请求ps 或alltoall 通信前的操作有点类似,不过稀疏架构更多是为了去重。
重排输入数据:
pre_reorder_triton_kernel(
hidden_states,
gateup_input, # 重排后的输入
src2dst,
topk_ids,
...
)
这个 kernel 完成了关键的数据重排序工作:
- 将输入 token 按照它们请求的专家 ID 重新排序
- 只保留当前 rank 负责的专家对应的请求,过滤掉其他 rank 的数据
3、expert 计算与结果恢复
gateup_output = self.grouped_gemm_runner(
a=gateup_input,
b=self.w13_weight,
c=gateup_output,
batch_size=self.num_experts_per_partition,
使用分组矩阵乘法,每个专家单独处理自己负责的 token:
- seg_indptr_cur_rank 指示每个专家负责的 token 范围
- 每个专家只处理分配给自己的 token
post_reorder_triton_kernel(hidden_states.size(0),),
BLOCK_SIZE=512,
)
这个 kernel 完成了最终的输出重排序:
- 将专家计算的结果重新排回原始 token 顺序
- 应用专家权重进行加权求和
- 只处理当前 rank 负责的专家输出,其他 rank 的专家输出会在各自的 rank 上处理