上个月deepseekv2 发布了, https://arxiv.org/pdf/2405.04434,学习了一下论文,感觉MLA还挺重要的,不过只看到了 transformers 的实现,感觉deepseek内部应该有更快的实现,期待vllm的版本。

MLA 中 表示低秩压缩的KV, 是多query共享的 k, 是key做了RoPE的部分,
的shape 为 (4*d_h)=512, 的维度为 d_h/2 = 64, d_h=128
"hidden_size": 5120,
"q_lora_rank": 1536,
"num_attention_heads": 128,
"kv_lora_rank": 512,
"v_head_dim": 128,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
query 的 linear 层
代码来源: https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
self.q_a_proj = nn.Linear(
self.hidden_size, config.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
self.q_b_proj = nn.Linear(
config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
)
q_a_proj: (5120, 1536)
q_b_proj: (1536, 128*192=24576)
kv 低秩压缩层
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
config.kv_lora_rank + config.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
self.kv_b_proj = nn.Linear(
config.kv_lora_rank,
self.num_heads
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
kv_a_proj_with_mqa: (5120, 576)
kv_b_proj(512, 128*(192-64+128) )
# 7168->576
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
# compressed_kv:512, k_pe:64
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
# 升维:[bsz, 1, q_len,64]
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
# 【bsz, 128, q_len, 128+128】
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)
# 然后拆分出value 部分和不做 rope 的 key部分
k_nope, value_states = torch.split(
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
整个过程就是先把 hidden_states 经过一个低秩压缩(kv_lora_rank),然后从中取出要做rope的那部分key,剩下的部分升维到多head 的 key 和 value 。(128, 128)
然后拆分出value 部分和不做 rope 的 key部分。
不过 这个版本并看出来MLA kvcache 哪里省显存,下边 key_states 还是分配了 num_heads
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
query_states 和 key_states 都是 nope 部分和做了 rope的部分拼接而成的
query_states 的 shape 为 (bsz, num_heads, q_len, q_head_dim)
key_states 的 shape 为 (bsz, num_heads, kv_seq_len, q_head_dim)
得到的 q,k包括两部分拼接而成:一部分是做了低秩压缩得到的 q,k
向量,一部分是增加了RoPE位置编码的 q,k 向量。
query 与 key 的转置进行矩阵乘,后面的操作就跟普通 MHA差不多了
attn_weight shape 为 (bsz, num_heads, q_len, kv_seq_len)
value_states 的 shape 为(bsz, num_heads, kv_seq_len, v_head_dim)
最后 attn_weight 与 value_states 矩阵乘得到 bsz, num_heads, q_len, v_head_dim)
可以看到 q,k包括两部分拼接而成:一部分是做了低秩压缩得到的 q,k
向量,一部分是增加了RoPE位置编码的 q,k 向量。v 是只有一份的,key 的最后一维比value 多了 rope部分。
单独计算了两个带着位置编码的q_pe, k_pe, 维度为 单Attention Head维度的一半:128/2=64, K_pe 也是一层一份的,多query的头共享同一个,注意它的shape, num_head 的那一维是1
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
然后在复制给 key_states 时进行了broadcast
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
https://spaces.ac.cn/archives/10091 2024-05-13
原理上苏剑林大佬已经解释的非常清晰了,“RoPE与低秩KV不兼容,没法做矩阵吸收计
算”, 2/3 的key不需要做位置相关的 ROPE 所以可以通过矩阵吸收在算Q的时候先算好对key的运算,可以缓存低秩压缩低部分,做ROPE的key部分是多head 共享的,所以也可以省kc cache 的显存。这也是MLA的压缩KV Cache的核心原理。
从效果上看,虽然MLA缓存的Latent KV比较短(相当于2.25个MQA的缓存量),但MLA有恢复全 的能力,特征表达能力显著比GQA、MQA要强。
这里 transformers 的实现明显并不是实际的实现,后续有时间再看看别的仓库的实现吧。