往期相关文章:
vllm prefill 和 decode 的kernel代码解读
在11月16日,sglang 合并了对MLA的dp支持,工作还是很棒的,pr链接为:https://github.com/sgl-project/sglang/pull/1970,在12月4日,sglang发布了V0.4.0版本,并发了一篇博客,其中对dp的工作原理进行了详细的介绍。
https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models
会在 mla 之前去把数据进行拆分,然后在完成mla 之后在moe之前进行allgather,moe 使用tp并行,文中也提到了ep的思路,但pr还没合并。整体效果还是非常不错的。
为什么要用dp呢,其实是 mla 如果不用dp,那么似乎就会退化成 MQA,做tp会在每个rank上都有一个kv,那这样其实是浪费了一些显存,那与其这样,莫不如就让它在每个rank上有个replica,然后把数据切分开,这样不就ok了吗。
接下来,我门就看看这个dp的代码是怎样的吧,pr链接为:https://github.com/sgl-project/sglang/pull/1970。
1、数据拆分与结果合并
结果合并的代码比较好找,在python/sglang/srt/models/deepseek_v2.py 文件中
if self.enable_dp_attention:
hidden_states, start_idx, end_idx = all_gather(
hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states[start_idx:end_idx]
这里调用 all_gather 来将各个 dp rank 做 mla 之后的结果 allgather过来来做后边的 moe ,同时会记住本rank负责的请求的 start_idx 和 end_idx。
在完成 moe 的 mlp 层之后,会通过 hidden_states[start_idx:end_idx]
来拿到本rank的 请求,执行下一层的 mla。
all_gather 的实现如下,
def all_gather(
input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
):
if world_size == 1:
return input_tensor
all_lens = forward_batch.global_num_tokens
max_len = max(forward_batch.global_num_tokens)
padded_tensor = torch.nn.functional.pad(
input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
)
torch.distributed.all_gather_into_tensor(
forward_batch.gathered_buffer, padded_tensor, group=group
)
gathered_tensors = torch.concat(
[
forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
for i in range(world_size)
]
)
start_index = 0 if rank == 0 else sum(all_lens[:rank])
end_index = start_index + all_lens[rank]
return gathered_tensors, start_index, end_index
这里有一些注意点:
- global_num_tokens 是一个 list,收集了每个 rank 的 token 数量。然后会根据最大token数对数据进行 pad,
- all_gather要求所有进程的tensor大小必须相同, 所以需要将每个进程的tensor都pad到最大长度。
input_tensor 的 shape 为 [num_tokens, hidden_dim]
pad = (0, 0, # 最后一维(hidden_dim)不做padding
0, max_len - input_tensor.shape[0]) # 第一维(seq_len)只在末尾padding
-
padding后的shape会变成:
[max_len, hidden_dim]
-
gather 之后要先通过
[i * max_len : i * max_len + all_lens[i]]
切片来切出有效的数据部分,扔掉padding的部分。取出有效部分之后再做 concat。 -
返回的时候会把 本rank token 开始和结束的行号也返回,方便昨晚moe MLP 之后还原来做后边层的MLA。
2、数据是怎么切分的
在 python/sglang/srt/managers/data_parallel_controller.py 文件中,每次拿到数据都会通过 round_robin 的方式轮循到下一个worker,然后通过zmq 发给它
def round_robin_scheduler(self, req):
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
def event_loop(self):
while True:
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
if isinstance(
recv_req,
(
TokenizedGenerateReqInput,
TokenizedEmbeddingReqInput,
),
):
self.dispatching(recv_req)
else:
# Send other control messages to all workers
for worker in self.workers:
worker.send_pyobj(recv_req)
在 python/sglang/srt/managers/scheduler.py 文件中,会拿到分发来的数据,然后进行调度
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name
)
需要注意的是 在准备下一个batch的时候,会调用prepare_dp_attn_batch
if self.server_args.enable_dp_attention:
ret = self.prepare_dp_attn_batch(ret)
return ret
而第一部分说的 allgather 中的 global_num_tokens 就是从这里来的。
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
else:
num_tokens = local_batch.extend_num_tokens
local_num_tokens = torch.tensor(
num_tokens, dtype=torch.int64, device=self.device
)
global_num_tokens = torch.empty(
self.tp_size, dtype=torch.int64, device=self.device
)
torch.distributed.all_gather_into_tensor(
global_num_tokens,
local_num_tokens,
group=self.tp_worker.get_tp_device_group(),
)
....
从宏观来看,在整个模型进行一次forward 之前,会获得每个rank上拿到的 token数,然后来进程N层的mla 和 moe 时会反复复用这个 token 数,所以会把对token数对 allgather 放在这里,跟实际数据的 allgather 不在一块。