往期相关文章:
vllm prefill 和 decode 的kernel代码解读
前天(7月27日),SGLang 发布了 v0.2.5,同时发布了一篇博客:https://lmsys.org/blog/2024-07-25-sglang-llama3/ 。下面我们像之前分析 vLLM 代码一样,来分析一下 SGLang 的代码。
代码的入口是 python/sglang/srt/server.py 中的 launch_server 函数:
# Excerpt of launch_server (python/sglang/srt/server.py). Parts of the function
# are omitted here: port_args and pipe_controller_writer are used below but
# their definitions are not shown in this excerpt.
def launch_server(
server_args: ServerArgs,
model_overide_args: Optional[dict] = None,
pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
server_args.check_server_args()
# Choose the controller entry point by data-parallel size:
# the single-controller variant when dp_size == 1, the multi variant otherwise.
if server_args.dp_size == 1:
start_process = start_controller_process_single
else:
start_process = start_controller_process_multi
# The chosen controller entry point is run as a separate process.
# NOTE(review): port_args / pipe_controller_writer come from omitted code — confirm against the full source.
proc_controller = mp.Process(
target=start_process,
args=(server_args, port_args, pipe_controller_writer, model_overide_args),
)
launch_server 会根据 dp_size 选择调用 start_controller_process_single 或 start_controller_process_multi,并在一个新的子进程(mp.Process)中运行它。简单起见,我们只看 single 版本:它会构造一个 ControllerSingle,然后进入 loop_for_forward 循环。
python/sglang/srt/managers/controller_single.py
# Excerpt of the single-controller startup code (per the article, this runs
# inside start_controller_process_single in controller_single.py).
try:
controller = ControllerSingle(
server_args,
port_args,
model_overide_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,
queue,
)
except Exception:
# Construction failed: send the traceback text through the pipe, then re-raise.
pipe_writer.send(get_exception_traceback())
raise
# Construction succeeded: report "init ok" through the pipe
# (presumably read by the launching process — confirm).
pipe_writer.send("init ok")
try:
# Enter the controller's receive/forward loop.
# NOTE(review): the handler for this try is not shown in the excerpt.
controller.loop_for_forward()
ControllerSingle 的 loop_for_forward 是一个简单的控制循环:不断接收请求,交给 tp_server 执行,再把输出逐个发送给 detokenizer。
# Controller main loop (runs forever): receive a batch of requests, run one
# step on the TP server, and send every output object to the detokenizer.
def loop_for_forward(self):
while True:
# Non-DP-worker controllers receive from ZMQ; DP workers read from an mp queue.
if not self.is_dp_worker:
recv_reqs = self.recv_requests_from_zmq()
else:
recv_reqs = self.recv_requests_from_mp_queue()
# With tensor parallelism (tp_size > 1), broadcast the received requests over
# the TP CPU group, with source 0 (presumably rank 0).
# NOTE(review): indentation was lost in this excerpt; presumably only the
# broadcast call belongs to this `if` — confirm against the full source.
if self.tp_size > 1:
broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
out_pyobjs = self.tp_server.exposed_step(recv_reqs)
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
请求有两种来源:非 DP worker 从 zmq 接收,DP worker 则从 multiprocessing 队列中读取。
tp_server 负责真正执行模型的 forward:
python/sglang/srt/managers/tp_worker.py
# One scheduling step: dispatch each received request by its type, then run a
# forward step. NOTE(review): the try's except clause and the return of output
# objects (consumed by loop_for_forward) are not shown in this excerpt.
def exposed_step(self, recv_reqs):
try:
# Dispatch incoming requests by type; unknown types raise ValueError.
for recv_req in recv_reqs:
if isinstance(recv_req, TokenizedGenerateReqInput):
self.handle_generate_request(recv_req)
elif isinstance(recv_req, FlushCacheReq):
self.flush_cache()
elif isinstance(recv_req, AbortReq):
self.abort_request(recv_req)
else:
raise ValueError(f"Invalid request: {recv_req}")
# Run one prefill or decode step over the queued requests.
self.forward_step()
exposed_step 先按类型分发请求:生成请求通过 handle_generate_request 放入 waiting 队列,随后由 forward_step 处理队列中的请求。
python/sglang/srt/managers/tp_worker.py
# Run either one new prefill batch or, if there is none, a few decode steps on
# the running batch. Runs under torch.inference_mode().
# NOTE(review): excerpt — code following the forward calls in each branch is not shown.
@torch.inference_mode()
def forward_step(self):
new_batch = self.get_new_prefill_batch()
if new_batch is not None:
# A new prefill batch takes priority over decoding.
self.forward_prefill_batch(new_batch)
else:
# No new prefill batch: decode the running batch, if any.
if self.running_batch is not None:
# Run several decode steps back to back to reduce per-step overhead.
for _ in range(global_config.num_continue_decode_steps):
# Each decode step produces one token per running request.
self.num_generated_tokens += len(self.running_batch.reqs)
self.forward_decode_batch(self.running_batch)
forward_step 会根据是否有新的 prefill batch 来决定执行 prefill 还是 decode:
- 如果是 prefill,会调用 batch 的 prepare_for_extend,然后以 extend 模式调用 model_runner 的 forward;
- 如果是 decode,会调用 batch 的 prepare_for_decode,然后以 decode 模式调用 model_runner 的 forward(并且会连续执行 num_continue_decode_steps 步,以减少开销)。
# Excerpt: build the input ids for one decode step (the rest of the method is not shown).
def prepare_for_decode(self, input_ids=None):
if input_ids is None:
# One token per request: its last generated token if it has any output,
# otherwise the last token of its prompt.
input_ids = [
r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
]
prepare_for_decode 会取每个请求的最后一个 token(已有输出时取 output_ids 的最后一个,否则取 input_ids 的最后一个),因为 decode 阶段就是以上一个 token 为输入来生成下一个 token。
真正调用具体模型 forward 函数的地方在 model_runner 中。以 extend 模式为例:
# Run the model in EXTEND mode (used for prefill) on a batch: build an
# InputMetadata from the batch's fields, then call the model's forward with the
# batch's input ids, the positions computed by InputMetadata, and the metadata.
def forward_extend(self, batch: Batch):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.EXTEND,
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
prefix_lens=batch.prefix_lens,
position_ids_offsets=batch.position_ids_offsets,
out_cache_loc=batch.out_cache_loc,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
)
在模型内部会调用对应 attention 层的 forward。无论 decode 还是 extend,都会先通过 store_kv_cache 把本步的 k、v 写入 KV cache,然后再调用各自的 Triton kernel:
python/sglang/srt/layers/radix_attention.py
# Extend (prefill) attention with the Triton kernel: store this step's k/v in
# the KV cache, run extend_attention_fwd, and return the output tensor o.
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
# Output buffer with q's shape/dtype. The kernel receives o.view(...), which
# shares o's storage, so o presumably holds the result after the call.
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)
extend_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
k.contiguous(),
v.contiguous(),
o.view(-1, self.tp_q_head_num, self.head_dim),
# This layer's cached key/value buffers.
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.triton_start_loc,
input_metadata.seq_lens,
input_metadata.triton_prefix_lens,
input_metadata.extend_start_loc,
input_metadata.extend_seq_lens,
input_metadata.triton_max_seq_len,
input_metadata.triton_max_extend_len,
sm_scale=self.scaling,
logit_cap=self.logit_cap,
)
return o
# Decode attention with the Triton kernel: store this step's k/v in the KV
# cache, then run token_attention_fwd over this layer's cached keys/values.
# NOTE(review): the excerpt ends after the kernel call; presumably the method
# returns o, as extend_forward_triton does — confirm against the full source.
def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
# Output buffer with q's shape/dtype, passed to the kernel as a view of o.
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)
# Unlike extend, k/v are not passed directly: attention reads them from the cache buffers.
token_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
o.view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.triton_start_loc,
input_metadata.seq_lens,
input_metadata.triton_max_seq_len,
input_metadata.total_num_tokens,
sm_scale=self.scaling,
logit_cap=self.logit_cap,
)
# Key-cache buffer for this layer (the value cache comes from get_value_buffer).
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
可以拿到当前层(self.layer_id)的 key cache 缓冲区;同理,get_value_buffer 拿到的是 value cache 缓冲区。