https://github.com/vllm-project/vllm/pull/3462

On March 25, vLLM refactored its attention code. The roughly once-a-month release is probably coming soon, so it is worth reading the code first to understand how the new attention path works.

We already know that inference is split into a prefill phase and a decode phase, and that this is controlled by the vLLM scheduler.

In the prefill phase, the sequences are first marked as prefill and the model is run once over them; this prefill pass fills the kv_cache.
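
A minimal sketch of these two phases (toy code only, with a random-logit stub in place of the real model; this is not vLLM's API):

    # Toy sketch of prefill vs. decode (stub model, not vLLM code).
    import random

    def fake_model(tokens, kv_cache):
        """Pretend forward pass: append one K/V entry per input token, return toy logits."""
        kv_cache.extend(("k", "v") for _ in tokens)   # prefill appends many entries, decode one
        return [random.random() for _ in range(32000)]

    def sample(logits):
        return max(range(len(logits)), key=logits.__getitem__)  # greedy sampling

    prompt = [1, 15043, 3186]
    kv_cache = []

    # Prefill: the whole prompt goes through the model once and fills the kv cache.
    token = sample(fake_model(prompt, kv_cache))

    # Decode: one new token per step, reusing the cached K/V of everything before it.
    for _ in range(4):
        token = sample(fake_model([token], kv_cache))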

0. Overall execution

In vLLM's model_runner.py, execute_model calls the model's forward to produce logits and then performs sampling.

The code lives in vllm/worker/model_runner.py:

    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        # ......
        # 0. Prepare the input tensors.
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)
        model_executable = self.model
        # 1. Run the model forward pass (execute_model_kwargs is assembled from
        #    the prepared tensors above; elided here).
        hidden_states = model_executable(**execute_model_kwargs)
        # 2. Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not sampling_metadata.perform_sampling:
            return None

        # 3. Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

1. Preparing the input

Here is the first key point: in prefill, the entire query (the prompt) is fed to the model; in decode, only the current token is.

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Set[int], LoRAMapping, torch.Tensor]:
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                # 1. Prefill: call _prepare_prompt.
                (input_tokens, input_positions, attn_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
                 lora_requests, multi_modal_input
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                # 2. Decode: call _prepare_decode.
                (input_tokens, input_positions, attn_metadata,
                 lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                prompt_lens = []
                subquery_lens = None

_prepare_prompt fetches the prompt tokens for prefill:

def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], List[int], List[int], Set[LoRARequest],
               torch.Tensor]:
        # ... (inside the per-sequence loop)
            prefill_end = min(seq_data.get_len(),
                              computed_len + token_chunk_size)
            # TODO(sang): Rename it after chunked prefill is introduced.
            prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]

You can see that during prefill, all of the sequence's prompt tokens (from computed_len up to prefill_end) are collected.

def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], Set[LoRARequest]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            # ... (seq_ids and other bookkeeping elided)
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

You can see that during decode, only the last token is taken (get_last_token_id).
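
To make the contrast concrete, here is a tiny illustration with made-up token ids (not vLLM's data structures):

    # Hypothetical token ids, just to show what lands in input_tokens in each phase.
    prompt_token_ids = [1, 15043, 3186]     # the prompt
    generated_token_ids = [29991, 1128]     # tokens generated so far

    # Prefill (_prepare_prompt): the not-yet-computed slice of the prompt is batched in.
    computed_len, prefill_end = 0, len(prompt_token_ids)
    prefill_input_tokens = prompt_token_ids[computed_len:prefill_end]   # [1, 15043, 3186]

    # Decode (_prepare_decode): only the last token of the sequence is fed in.
    all_token_ids = prompt_token_ids + generated_token_ids
    decode_input_tokens = [all_token_ids[-1]]                           # [1128]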

2. Attention preparation

During the model forward pass, each layer's forward is called in turn, and the attention layer enters the attention forward. Taking llama as an example, in vllm/model_executor/models/llama.py the LlamaAttention module computes q, k, and v through the qkv linear projection. In the prefill phase, these q, k, v are computed for the full seqlen-long query; once k and v are available, they can be written into the kv cache.

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
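
For reference, q_size and kv_size in the snippet above come from the head layout. A quick sketch with illustrative grouped-query-attention numbers (assumptions, not taken from any particular checkpoint):

    # Illustrative GQA head layout (numbers are assumptions, per tensor-parallel rank).
    num_heads, num_kv_heads, head_dim = 32, 8, 128

    q_size = num_heads * head_dim        # 4096: one slice per query head
    kv_size = num_kv_heads * head_dim    # 1024: fewer KV heads under grouped-query attention

    # qkv_proj outputs a tensor of width q_size + 2 * kv_size, and
    # .split([q_size, kv_size, kv_size], dim=-1) carves it into q, k, v.
    print(q_size, kv_size, q_size + 2 * kv_size)   # 4096 1024 6144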

In vllm/attention/layer.py, this calls into the actual attention implementation:

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        return self.impl.forward(query, key, value, kv_cache, attn_metadata)

3. Writing to the kv cache

Taking flash attention as an example, in vllm/attention/backends/flash_attn.py the forward function first writes the incoming key and value into the kv cache. In prefill this writes the keys and values of the whole prompt; in decode it writes only those of the current token.

class FlashAttentionImpl(AttentionImpl):
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
    ) -> torch.Tensor:
      
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                attn_metadata.kv_cache_dtype)

write_to_paged_cache is what stores the key and value into the paged kv cache.
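
The layout behind this can be pictured with a small sketch: attn_metadata.slot_mapping assigns each incoming token a flat slot index, which splits into a (block index, offset within block) pair. This is a simplified model that ignores the head_size/x packing the real key cache uses:

    import torch

    # Simplified paged-cache write (the real key cache is
    # [num_blocks, num_kv_heads, head_size/x, block_size, x]).
    num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

    # Three tokens whose scheduler-assigned slots are 32, 33, 34 (block 2, offsets 0..2).
    key = torch.randn(3, num_kv_heads, head_size)
    slot_mapping = torch.tensor([32, 33, 34])

    block_idx = slot_mapping // block_size
    block_off = slot_mapping % block_size
    key_cache[block_idx, block_off] = key    # what write_to_paged_cache does, conceptually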

4. The prefill and decode kernels

Prefill and decode then each call their own kernel, passing in the key cache and value cache to load from:

        if attn_metadata.is_prompt:
            # Prompt run.
            if kv_cache is None or attn_metadata.block_tables.numel() == 0:
                # normal attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=attn_metadata.seq_start_loc,
                    cu_seqlens_k=attn_metadata.seq_start_loc,
                    max_seqlen_q=attn_metadata.max_prompt_len,
                    max_seqlen_k=attn_metadata.max_prompt_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                # prefix-enabled attention
                output = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.block_tables,
                    attn_metadata.subquery_start_loc,
                    attn_metadata.prompt_lens_tensor,
                    attn_metadata.context_lens,
                    attn_metadata.max_subquery_len,
                    self.alibi_slopes,
                )
        else:
            # Decoding run.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                attn_metadata.block_tables,
                attn_metadata.context_lens,
                attn_metadata.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

Prefill calls the kernel behind flash_attn_varlen_func.
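
flash_attn_varlen_func works on a flattened batch: all prompts are concatenated along the token dimension, and cu_seqlens_q / cu_seqlens_k (seq_start_loc here) are the prefix sums marking where each sequence starts. A small sketch of that bookkeeping (illustrative lengths only):

    import torch

    # Three prompts of different lengths, flattened into a single token dimension.
    prompt_lens = [5, 3, 7]

    # seq_start_loc / cu_seqlens: exclusive prefix sums of the sequence lengths.
    seq_start_loc = torch.zeros(len(prompt_lens) + 1, dtype=torch.int32)
    seq_start_loc[1:] = torch.cumsum(torch.tensor(prompt_lens), dim=0)
    print(seq_start_loc)               # tensor([ 0,  5,  8, 15], dtype=torch.int32)

    max_prompt_len = max(prompt_lens)  # what gets passed as max_seqlen_q / max_seqlen_k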

Taking forward_decode as an example, the kernel that actually runs lives in csrc/attention/attention_kernels.cu:

__device__ void paged_attention_kernel(
  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
  const int num_kv_heads,                 // [num_heads]

// Vectorized access setup
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;

To improve memory access efficiency, the code uses vectorized loads, so that each access fetches 16 bytes of data.
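
A worked instance of the VEC_SIZE formula above (the dtype/thread-group combinations are just examples):

    # VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1), evaluated for a few cases.
    def vec_size(thread_group_size, dtype_bytes):
        return max(16 // (thread_group_size * dtype_bytes), 1)

    print(vec_size(2, 2))   # fp16, 2-thread group -> 4 elements per load (2 threads * 4 * 2B = 16B)
    print(vec_size(4, 2))   # fp16, 4-thread group -> 2 elements per load
    print(vec_size(2, 4))   # fp32, 2-thread group -> 2 elements per load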

  • Loading the query data
// Load the query vectors into shared memory
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];

#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
  const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
  q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
__syncthreads();

Each thread group cooperatively loads a complete query vector into shared memory.

  • Computing the query-key dot products
// Shared memory layout
extern __shared__ char shared_mem[];
float* logits = reinterpret_cast<float*>(shared_mem);
__shared__ float red_smem[2 * NUM_WARPS];

// Initialize the running maximum logit
float qk_max = -FLT_MAX;

// Iterate over the key blocks
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
  const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
  
  // Load keys and compute dot products
  for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
    const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    K_vec k_vecs[NUM_VECS_PER_THREAD];
  
    // Load the key vectors
    #pragma unroll
    for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
      // Compute the key-cache pointer
      const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                     + kv_head_idx * kv_head_stride
                                     + physical_block_offset * x;
      // Load the key vector (with FP8 quantization support)
      // ...
    }
  
    // Compute the dot product
    float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
  
    // Add the ALiBi positional bias
    qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
  
    // Store the result in shared memory
    if (thread_group_offset == 0) {
      const bool mask = token_idx >= context_len;
      logits[token_idx - start_token_idx] = mask ? 0.f : qk;
      qk_max = mask ? qk_max : fmaxf(qk_max, qk);
    }
  }
}

This is the core computation: each thread group loads keys and computes their dot products with the query, applying the ALiBi positional bias at the same time.
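
As a plain reference for a single logit, the kernel's math boils down to the following (a torch sketch with illustrative sizes, mirroring the variable names above):

    import torch

    # One attention logit, the way the kernel computes it (illustrative sizes).
    head_size, context_len, token_idx = 128, 10, 3
    scale = head_size ** -0.5
    alibi_slope = 0.0625                 # 0 when ALiBi is not used

    q = torch.randn(head_size)
    k = torch.randn(head_size)           # the key of token `token_idx`, read from k_cache

    qk = scale * torch.dot(q, k)
    qk = qk + (alibi_slope * (token_idx - context_len + 1) if alibi_slope != 0 else 0)

    # Tokens at or beyond context_len are masked: their logit is zeroed and skipped for qk_max.
    mask = token_idx >= context_len
    logit = 0.0 if mask else qk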

Here the v_cache is read and combined with the logits:

    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                   + kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec;
        if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          // Vector conversion from V_quant_vec to V_vec.
          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
#else
          assert(false);
#endif
        } else {
          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        }
        if (block_idx == num_context_blocks - 1) {
          // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
          // we should explicitly zero out the values since they may contain NaNs.
          // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }
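
Stripped of the block table, partitioning, and vectorization, the decode kernel per (sequence, head) reduces to the following dense reference (a torch sketch, not the kernel itself):

    import torch

    # Dense reference for one (sequence, head) of the paged decode attention.
    context_len, head_size = 10, 128
    scale = head_size ** -0.5

    q = torch.randn(head_size)                   # the single decode-step query
    k_cache = torch.randn(context_len, head_size)
    v_cache = torch.randn(context_len, head_size)

    logits = scale * (k_cache @ q)               # one logit per cached token (the qk loop above)
    probs = torch.softmax(logits, dim=0)         # the kernel tracks qk_max for numerical stability
    out = probs @ v_cache                        # [head_size]: the accs accumulation over v_vec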