Related earlier posts:

A walkthrough of the vLLM prefill and decode kernel code

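Two days ago, on July 27, SGLang released v0.2.5 and also published a blog post: https://lmsys.org/blog/2024-07-25-sglang-llama3/ . Let's walk through the code, the same way we did for vLLM.

The entry point is launch_server in python/sglang/srt/server.py: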

def launch_server(
    server_args: ServerArgs,
    model_overide_args: Optional[dict] = None,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    server_args.check_server_args()
    if server_args.dp_size == 1:
        start_process = start_controller_process_single
    else:
        start_process = start_controller_process_multi
    proc_controller = mp.Process(
        target=start_process,
        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
    )

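Depending on dp_size, the server starts either start_controller_process_single or start_controller_process_multi in a separate process. To keep things simple we follow the single version, which constructs a ControllerSingle and then enters its loop.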

python/sglang/srt/managers/controller_single.py

    try:
        controller = ControllerSingle(
            server_args,
            port_args,
            model_overide_args,
            gpu_ids,
            is_data_parallel_worker,
            dp_worker_id,
            queue,
        )
    except Exception:
        pipe_writer.send(get_exception_traceback())
        raise

    pipe_writer.send("init ok")

    try:
        controller.loop_for_forward()

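The controller's loop_for_forward is a simple control loop: it receives requests, hands them to tp_server for execution, and sends the outputs to the detokenizer.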

    def loop_for_forward(self):
        while True:
            if not self.is_dp_worker:
                recv_reqs = self.recv_requests_from_zmq()
            else:
                recv_reqs = self.recv_requests_from_mp_queue()

            if self.tp_size > 1:
                broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)

            out_pyobjs = self.tp_server.exposed_step(recv_reqs)

            for obj in out_pyobjs:
                self.send_to_detokenizer.send_pyobj(obj)

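Requests are read either from ZMQ or from a multiprocessing.Queue, depending on whether the controller is running as a data-parallel worker.

tp_server is the component that actually runs the forward pass: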
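The receive/step/send shape of this loop can be sketched with the standard library alone. This is a minimal illustration, not SGLang's implementation: `recv_all`, `run_loop`, and the `step`/`send` callbacks are hypothetical names, a `queue.Queue` stands in for the ZMQ socket or multiprocessing queue, and the loop is bounded by `max_iters` (instead of `while True`) so that it terminates. It also assumes the receive step drains whatever is pending without blocking.

```python
import queue


def recv_all(q):
    """Drain every request currently in the queue without blocking."""
    reqs = []
    while True:
        try:
            reqs.append(q.get_nowait())
        except queue.Empty:
            return reqs


def run_loop(q, step, send, max_iters):
    """Bounded stand-in for the controller loop: receive, step, forward outputs."""
    for _ in range(max_iters):
        reqs = recv_all(q)       # may be empty; step() is still called
        for out in step(reqs):   # plays the role of tp_server.exposed_step
            send(out)            # plays the role of send_to_detokenizer


inbox = queue.Queue()
for text in ["hello", "world"]:
    inbox.put(text)

sent = []
run_loop(inbox, step=lambda reqs: [r.upper() for r in reqs],
         send=sent.append, max_iters=3)
print(sent)
```

Note that the step function is invoked on every iteration even when no new request arrived, which is what lets a running batch keep decoding between arrivals.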

python/sglang/srt/managers/tp_worker.py

    def exposed_step(self, recv_reqs):
        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                elif isinstance(recv_req, FlushCacheReq):
                    self.flush_cache()
                elif isinstance(recv_req, AbortReq):
                    self.abort_request(recv_req)
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()

handle_generate_request puts each incoming request into the waiting queue, and forward_step then processes the requests in that queue.
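As a rough sketch of this dispatch-then-queue pattern (the classes below are simplified, hypothetical stand-ins for the real request types, and the abort handling is an assumption made only for illustration):

```python
from dataclasses import dataclass, field


@dataclass
class GenerateReq:
    rid: str
    input_ids: list


@dataclass
class AbortReq:
    rid: str


@dataclass
class ToyWorker:
    waiting: list = field(default_factory=list)

    def handle(self, recv_reqs):
        """Dispatch on request type; generate requests only get queued here."""
        for req in recv_reqs:
            if isinstance(req, GenerateReq):
                self.waiting.append(req)
            elif isinstance(req, AbortReq):
                # Assumption: aborting simply drops the request from the queue.
                self.waiting = [r for r in self.waiting if r.rid != req.rid]
            else:
                raise ValueError(f"Invalid request: {req}")


worker = ToyWorker()
worker.handle([GenerateReq("a", [1, 2]), GenerateReq("b", [3]), AbortReq("a")])
print([r.rid for r in worker.waiting])  # ['b']
```

The point of the split is that receiving a request never runs the model; it only changes the queue that the forward step looks at afterwards.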

python/sglang/srt/managers/tp_worker.py

    @torch.inference_mode()
    def forward_step(self):
        new_batch = self.get_new_prefill_batch()

        if new_batch is not None:
            # Run a new prefill batch
            self.forward_prefill_batch(new_batch)
        else:
            # Run a decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(global_config.num_continue_decode_steps):
                    self.num_generated_tokens += len(self.running_batch.reqs)
                    self.forward_decode_batch(self.running_batch)

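forward_step runs either a prefill or a decode, depending on the state of the batches: if get_new_prefill_batch returns a batch, prefill takes priority; otherwise the running batch is decoded for a few consecutive steps.

For prefill, it calls the batch's prepare_for_extend and then calls model_runner's forward in extend mode.

For decode, it calls the batch's prepare_for_decode and then calls model_runner's forward in decode mode.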
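The prefill-first policy can be illustrated with a toy scheduler. Everything here is a simplified assumption rather than the real tp_worker logic: `ToyScheduler`, its `log`, and the `remaining` counter are hypothetical, prefill is modeled as producing no token, and a request finishes once its counter reaches zero.

```python
class ToyScheduler:
    def __init__(self, num_continue_decode_steps=2):
        self.waiting = []   # requests that have not been prefilled yet
        self.running = []   # requests currently being decoded
        self.num_continue_decode_steps = num_continue_decode_steps
        self.log = []       # records what each forward_step did

    def add(self, rid, max_new_tokens):
        self.waiting.append({"rid": rid, "remaining": max_new_tokens})

    def forward_step(self):
        if self.waiting:
            # New requests exist: run a prefill batch and merge it into running.
            new_batch, self.waiting = self.waiting, []
            self.log.append(("prefill", [r["rid"] for r in new_batch]))
            self.running.extend(new_batch)
        elif self.running:
            # Otherwise decode the running batch for a few steps in a row.
            for _ in range(self.num_continue_decode_steps):
                if not self.running:
                    break
                self.log.append(("decode", [r["rid"] for r in self.running]))
                for r in self.running:
                    r["remaining"] -= 1
                self.running = [r for r in self.running if r["remaining"] > 0]


sched = ToyScheduler()
sched.add("a", max_new_tokens=2)
sched.forward_step()              # prefill a
sched.add("b", max_new_tokens=1)
sched.forward_step()              # prefill b, even though a is already running
sched.forward_step()              # two decode steps; b finishes after the first
print(sched.log)
```

Running several decode steps per call amortizes the cost of checking for new requests, at the price of slightly delaying the prefill of anything that arrives in between.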

    def prepare_for_decode(self, input_ids=None):
        if input_ids is None:
            input_ids = [
                r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
            ]

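prepare_for_decode takes the last token of each sequence (the most recent output token, or the last prompt token if nothing has been generated yet), because a decode step feeds only that last token through the model to produce the next one.

Only in model_runner is the forward function of the concrete model actually called: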
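The token selection can be tried in isolation. `Req` below is a hypothetical two-field stand-in for the real request object, and `next_decode_inputs` is an illustrative name; only the list comprehension mirrors the excerpt above.

```python
from dataclasses import dataclass, field


@dataclass
class Req:
    input_ids: list
    output_ids: list = field(default_factory=list)


def next_decode_inputs(reqs):
    """Pick the last generated token, falling back to the last prompt token."""
    return [r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in reqs]


reqs = [
    Req(input_ids=[1, 2, 3]),                   # nothing generated yet
    Req(input_ids=[4, 5], output_ids=[6, 7]),   # two tokens generated so far
]
print(next_decode_inputs(reqs))  # [3, 7]
```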

    def forward_extend(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

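Inside the model, the attention layer's forward is called. For both decode and extend, the K/V of the current step is first stored into the KV cache, and then the respective kernel is invoked.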

python/sglang/srt/layers/radix_attention.py

    def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        o = torch.empty_like(q)
        self.store_kv_cache(k, v, input_metadata)
        extend_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.triton_start_loc,
            input_metadata.seq_lens,
            input_metadata.triton_prefix_lens,
            input_metadata.extend_start_loc,
            input_metadata.extend_seq_lens,
            input_metadata.triton_max_seq_len,
            input_metadata.triton_max_extend_len,
            sm_scale=self.scaling,
            logit_cap=self.logit_cap,
        )

        return o

    def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        o = torch.empty_like(q)
        self.store_kv_cache(k, v, input_metadata)

        token_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
            o.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.triton_start_loc,
            input_metadata.seq_lens,
            input_metadata.triton_max_seq_len,
            input_metadata.total_num_tokens,
            sm_scale=self.scaling,
            logit_cap=self.logit_cap,
        )

input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)

returns the location of the key cache for the given layer.
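To make the "store first, then attend over the cache" order concrete, here is a single-head toy version in plain Python. It is a sketch under stated assumptions, not the Triton kernel: `ToyKVPool` and `decode_attention` are hypothetical, the cache is a list of per-layer slot arrays, and `token_slots` plays the role of one row of a request-to-token mapping that lists which cache slots belong to a request.

```python
import math


class ToyKVPool:
    """Per-layer key/value buffers indexed by cache slot."""

    def __init__(self, num_layers, num_slots, head_dim):
        self.k = [[[0.0] * head_dim for _ in range(num_slots)]
                  for _ in range(num_layers)]
        self.v = [[[0.0] * head_dim for _ in range(num_slots)]
                  for _ in range(num_layers)]

    def get_key_buffer(self, layer_id):
        return self.k[layer_id]

    def get_value_buffer(self, layer_id):
        return self.v[layer_id]


def decode_attention(q, k_new, v_new, pool, layer_id, token_slots, out_slot):
    k_buf = pool.get_key_buffer(layer_id)
    v_buf = pool.get_value_buffer(layer_id)
    # 1. Store this step's K/V into the cache slot reserved for the new token.
    k_buf[out_slot] = list(k_new)
    v_buf[out_slot] = list(v_new)
    # 2. Attend over every cached token of the request, including the new one.
    slots = list(token_slots) + [out_slot]
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k_buf[s])) for s in slots]
    m = max(scores)                              # subtract max for stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v_buf[s][d] for w, s in zip(weights, slots)) / total
            for d in range(len(q))]


pool = ToyKVPool(num_layers=1, num_slots=4, head_dim=2)
# One earlier token of this request lives in slot 2 (slots need not be contiguous).
pool.get_key_buffer(0)[2] = [0.0, 0.0]
pool.get_value_buffer(0)[2] = [1.0, 0.0]

out = decode_attention(q=[1.0, 1.0], k_new=[0.0, 0.0], v_new=[0.0, 1.0],
                       pool=pool, layer_id=0, token_slots=[2], out_slot=0)
print(out)
```

Because the new token's K/V is written before the attention runs, the kernel only ever needs to read from the cache buffers, which is why the decode call in the excerpt passes the buffers rather than k and v themselves.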