往期相关文章:

sglang的模型执行

sglang 的 MLA 代码跟踪

deepseek V2 MLA 的理解

vllm prefill 和 decode 的kernel代码解读

在11月16日,sglang 合并了对MLA的dp支持,工作还是很棒的,pr链接为:https://github.com/sgl-project/sglang/pull/1970,在12月4日,sglang发布了V0.4.0版本,并发了一篇博客,其中对dp的工作原理进行了详细的介绍。

https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models

image

会在 mla 之前去把数据进行拆分,然后在完成mla 之后在moe之前进行allgather,moe 使用tp并行,文中也提到了ep的思路,但pr还没合并。整体效果还是非常不错的。

image

为什么要用dp呢,其实是 mla 如果不用dp,那么似乎就会退化成 MQA,做tp会在每个rank上都有一个kv,那这样其实是浪费了一些显存,那与其这样,莫不如就让它在每个rank上有个replica,然后把数据切分开,这样不就ok了吗。

接下来,我门就看看这个dp的代码是怎样的吧,pr链接为:https://github.com/sgl-project/sglang/pull/1970。

1、数据拆分与结果合并

结果合并的代码比较好找,在python/sglang/srt/models/deepseek_v2.py 文件中

    if self.enable_dp_attention:
        hidden_states, start_idx, end_idx = all_gather(
            hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
        )
        hidden_states = self.mlp(hidden_states)
        hidden_states = hidden_states[start_idx:end_idx]

这里调用 all_gather 来将各个 dp rank 做 mla 之后的结果 allgather过来来做后边的 moe ,同时会记住本rank负责的请求的 start_idx 和 end_idx。

在完成 moe 的 mlp 层之后,会通过 hidden_states[start_idx:end_idx]​ 来拿到本rank的 请求,执行下一层的 mla。

all_gather 的实现如下,

def all_gather(
    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
):
    if world_size == 1:
        return input_tensor

    all_lens = forward_batch.global_num_tokens
    max_len = max(forward_batch.global_num_tokens)

    padded_tensor = torch.nn.functional.pad(
        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
    )

    torch.distributed.all_gather_into_tensor(
        forward_batch.gathered_buffer, padded_tensor, group=group
    )

    gathered_tensors = torch.concat(
        [
            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
            for i in range(world_size)
        ]
    )

    start_index = 0 if rank == 0 else sum(all_lens[:rank])
    end_index = start_index + all_lens[rank]

    return gathered_tensors, start_index, end_index

这里有一些注意点:

  • global_num_tokens 是一个 list,收集了每个 rank 的 token 数量。然后会根据最大token数对数据进行 pad,
  • all_gather要求所有进程的tensor大小必须相同, 所以需要将每个进程的tensor都pad到最大长度。

input_tensor 的 shape 为 [num_tokens, hidden_dim]

pad = (0, 0,        # 最后一维(hidden_dim)不做padding
       0, max_len - input_tensor.shape[0])  # 第一维(seq_len)只在末尾padding
  • padding后的shape会变成:[max_len, hidden_dim]

  • gather 之后要先通过[i * max_len : i * max_len + all_lens[i]]​ 切片来切出有效的数据部分,扔掉padding的部分。取出有效部分之后再做 concat。

  • 返回的时候会把 本rank token 开始和结束的行号也返回,方便昨晚moe MLP 之后还原来做后边层的MLA。

2、数据是怎么切分的

在 python/sglang/srt/managers/data_parallel_controller.py 文件中,每次拿到数据都会通过 round_robin 的方式轮循到下一个worker,然后通过zmq 发给它

    def round_robin_scheduler(self, req):
        self.workers[self.round_robin_counter].send_pyobj(req)
        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)

    def event_loop(self):
        while True:
            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                if isinstance(
                    recv_req,
                    (
                        TokenizedGenerateReqInput,
                        TokenizedEmbeddingReqInput,
                    ),
                ):
                    self.dispatching(recv_req)
                else:
                    # Send other control messages to all workers
                    for worker in self.workers:
                        worker.send_pyobj(recv_req)

在 python/sglang/srt/managers/scheduler.py 文件中,会拿到分发来的数据,然后进行调度

		if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )

需要注意的是 在准备下一个batch的时候,会调用prepare_dp_attn_batch

        if self.server_args.enable_dp_attention:
            ret = self.prepare_dp_attn_batch(ret)

        return ret

而第一部分说的 allgather 中的 global_num_tokens 就是从这里来的。

    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
        else:
            num_tokens = local_batch.extend_num_tokens

        local_num_tokens = torch.tensor(
            num_tokens, dtype=torch.int64, device=self.device
        )
        global_num_tokens = torch.empty(
            self.tp_size, dtype=torch.int64, device=self.device
        )
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            group=self.tp_worker.get_tp_device_group(),
        )
		....

从宏观来看,在整个模型进行一次forward 之前,会获得每个rank上拿到的 token数,然后来进程N层的mla 和 moe 时会反复复用这个 token 数,所以会把对token数对 allgather 放在这里,跟实际数据的 allgather 不在一块。