往期相关文章:
vllm prefill 和 decode 的kernel代码解读
在前面文章中分析了一下 MLA 的 transformers 实现:deepseek V2 MLA 的理解
在8月5日,sglang 合入了MLA的实现, upport MLA for DeepSeek-V2 with Triton - step 1#905,虽然已经了解了矩阵吸收的概念,但是还是难以想象具体该怎么实现,这个 sglang 的实现也看上去似懂非懂的,索性看了章明星老师的知乎文章:DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子,了解了其中的关键(话说章老师的标题总是这么震惊吗 orz)。
公式部分
最重要的就两点:
1、key 的矩阵的吸收
这样算 query 矩阵乘的时候直接把解析 key的矩阵一起算了,这样就可以一直使用 。
2、value 的矩阵的吸收
然后算出分数之后与value相成时,也可以不用解析出 value, 而是先算矩阵乘,然后一直使用 。
然后我们就看到 key、value根本没有发生像 transformers 那样的broadcast,而且用的是低秩压缩的 kv,确实省了显存,sglang确实牛的。
代码部分
那我们来把 sglang 中的实现跟这个公式对应一下
跟 transformers的实现一样,sglang 也是有对于kv的低秩压缩矩阵和恢复矩阵的linear 层
主要看这个 deepseek_v2.py
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
)
不一样的地方在于:
kv_b_proj = self.kv_b_proj
w_kc, w_vc = kv_b_proj.weight.unflatten(
0, (-1, qk_nope_head_dim + v_head_dim)
).split([qk_nope_head_dim, v_head_dim], dim=1)
self.w_kc = w_kc
self.w_vc = w_vc
这里把恢复维度的 linear 层给 split 了,split 成恢复key 的 w_kc 和恢复value 的 w_vc,这是为什么呢,就得继续往下看了:
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_nope_out = q_input[..., : self.kv_lora_rank]
torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
第一样行跟 transformers 的实现差不多,拆出做 NoPE 的1/3部分,和不做NoPE 的 2/3部分。
然后后面就不一样了 。这里第三行就是公式部分中的
所以就知道为什么要把 linear 层拆开了。
另外 torch.bmm
是 PyTorch 中的批量矩阵乘法(Batch Matrix Multiplication)操作。它用于同时计算多个矩阵乘法。
具体来说:
- bmm 接受两个三维张量作为输入:(batch_size, n, m) 和 (batch_size, m, p)
- 输出也是一个三维张量:(batch_size, n, p)
- 对每个 batch 中的矩阵执行矩阵乘法运算
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
attn_bmm_output = attn_output.new_empty(
q_len, self.num_local_heads, self.v_head_dim
)
torch.bmm(
attn_output.transpose(0, 1),
self.w_vc.transpose(1, 2).contiguous(),
out=attn_bmm_output.transpose(0, 1),
)
attn_output = attn_bmm_output.flatten(1, 2)
output, _ = self.o_proj(attn_output)
这里就对应第二个吸收公式了,把 v_input (其实是 )传进去,当成普通的attention 去算score,然后外面的 output 的linear 权重跟解析 value 的矩阵去做矩阵乘。一直用比较低维的压缩的 kv,节省了显存。但模型大做TP的话,就像MQA那样,每张卡上应该都得存一份。