往期相关文章:

sglang的模型执行

deepseek V2 MLA 的理解

vllm prefill 和 decode 的kernel代码解读

在前面文章中分析了一下 MLA 的 transformers 实现:deepseek V2 MLA 的理解

在8月5日,sglang 合入了MLA的实现, upport MLA for DeepSeek-V2 with Triton - step 1#905,虽然已经了解了矩阵吸收的概念,但是还是难以想象具体该怎么实现,这个 sglang 的实现也看上去似懂非懂的,索性看了章明星老师的知乎文章:DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子,了解了其中的关键(话说章老师的标题总是这么震惊吗 orz)。

公式部分

最重要的就两点:

1、key 的矩阵的吸收

(cqTWUQT)(WUKckv)=(cqTWUQTWUK)ckv(c_q^T\cdot W_{UQ}^T)\cdot (W_{UK}\cdot c_{kv}) = (c_q^T \cdot W_{UQ}^T \cdot W_{UK})\cdot c_{kv}

这样算 query 矩阵乘的时候直接把解析 key的矩阵一起算了,这样就可以一直使用 ckvc_{kv}

2、value 的矩阵的吸收

(score(ckvWUV))WO=(scoreckv)(WUVWO)(score\cdot (c_{kv}\cdot W_{UV}))\cdot W_{O}=(score\cdot c_{kv})\cdot (W_{UV}\cdot W_{O})

然后算出分数之后与value相成时,也可以不用解析出 value, 而是先算矩阵乘,然后一直使用 ckvc_{kv}

然后我们就看到 key、value根本没有发生像 transformers 那样的broadcast,而且用的是低秩压缩的 kv,确实省了显存,sglang确实牛的。

代码部分

那我们来把 sglang 中的实现跟这个公式对应一下

跟 transformers的实现一样,sglang 也是有对于kv的低秩压缩矩阵和恢复矩阵的linear 层

主要看这个 deepseek_v2.py

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
        )

不一样的地方在于:

        kv_b_proj = self.kv_b_proj
        w_kc, w_vc = kv_b_proj.weight.unflatten(
            0, (-1, qk_nope_head_dim + v_head_dim)
        ).split([qk_nope_head_dim, v_head_dim], dim=1)
        self.w_kc = w_kc
        self.w_vc = w_vc

这里把恢复维度的 linear 层给 split 了,split 成恢复key 的 w_kc 和恢复value 的 w_vc,这是为什么呢,就得继续往下看了:

q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_nope_out = q_input[..., : self.kv_lora_rank]
torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))

第一样行跟 transformers 的实现差不多,拆出做 NoPE 的1/3部分,和不做NoPE 的 2/3部分。

然后后面就不一样了 。这里第三行就是公式部分中的

cqTWUQTWUKc_q^T \cdot W_{UQ}^T \cdot W_{UK}

所以就知道为什么要把 linear 层拆开了。

另外 torch.bmm​ 是 PyTorch 中的批量矩阵乘法(Batch Matrix Multiplication)操作。它用于同时计算多个矩阵乘法。

具体来说:

  • bmm 接受两个三维张量作为输入:(batch_size, n, m) 和 (batch_size, m, p)
  • 输出也是一个三维张量:(batch_size, n, p)
  • 对每个 batch 中的矩阵执行矩阵乘法运算

	attn_output = self.attn(q_input, k_input, v_input, input_metadata)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
        attn_bmm_output = attn_output.new_empty(
            q_len, self.num_local_heads, self.v_head_dim
        )
        torch.bmm(
            attn_output.transpose(0, 1),
            self.w_vc.transpose(1, 2).contiguous(),
            out=attn_bmm_output.transpose(0, 1),
        )

        attn_output = attn_bmm_output.flatten(1, 2)
        output, _ = self.o_proj(attn_output)

这里就对应第二个吸收公式了,把 v_input (其实是 CkvC_{kv})传进去,当成普通的attention 去算score,然后外面的 output 的linear 权重跟解析 value 的矩阵去做矩阵乘。一直用比较低维的压缩的 kv,节省了显存。但模型大做TP的话,就像MQA那样,每张卡上应该都得存一份。