Skip to content

vllm.v1.attention.backends.mla.indexer

DeepseekV32IndexerMetadataBuilder

Bases: AttentionMetadataBuilder

Source code in vllm/v1/attention/backends/mla/indexer.py
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
    """Builds per-step attention metadata for the DeepSeek V3.2 sparse indexer.

    ``__init__`` preallocates all device-side scratch buffers (sized by
    scheduler limits) so that ``build`` can fill them in place on the hot
    path without allocating.  ``build`` splits the batch into a decode
    portion and a prefill portion; prefills are chunked to respect the
    flattened-KV workspace and logits-memory budgets, and multi-token
    (speculative) decodes are either handled natively by the kernel or
    flattened into single-token entries when ``next_n`` is unsupported.
    """

    reorder_batch_threshold: int = 1
    # next_n values the decode kernel handles without flattening.
    natively_supported_next_n: list[int] = [1, 2]
    # TODO (matt): integrate kernel with next_n = 4 support

    @classmethod
    def get_cudagraph_support(
        cls,
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        """CUDA graphs are supported only for uniform decode batches."""
        return AttentionCGSupport.UNIFORM_BATCH

    def __init__(self, *args, **kwargs):
        """Preallocate reusable device buffers based on scheduler limits.

        Buffer shapes depend on whether the flattening path can ever be
        taken (``use_flattening``): native MTP uses a 2D ``(max_num_seqs,
        next_n)`` seq-lens buffer, while flattening/plain decode uses a 1D
        per-token buffer of size ``max_num_batched_tokens``.
        """
        super().__init__(*args, **kwargs)
        scheduler_config = self.vllm_config.scheduler_config
        # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
        self.max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config)
        self.num_speculative_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config
            else 0
        )
        next_n = self.num_speculative_tokens + 1
        self.reorder_batch_threshold += self.num_speculative_tokens
        # Flatten multi-token decodes when the kernel lacks native next_n support.
        self.use_flattening = next_n not in self.natively_supported_next_n

        sm_count = num_compute_units(self.device.index)
        self.num_sms = sm_count

        # [0, 1, ..., next_n-1]: per-token offsets for native MTP seq_lens.
        self.offsets_buffer = torch.arange(
            next_n, device=self.device, dtype=torch.int32
        )
        self.decode_lens_buffer = torch.zeros(
            (scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=self.device,
        )
        if not self.use_flattening and next_n > 1:
            # Native MTP: 2D buffer for per-token seq_lens.
            # Flattening path is never used, so no expanded_seq_lens_buffer.
            self.decode_seq_lens_buffer = torch.zeros(
                (scheduler_config.max_num_seqs, next_n),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            # Flattening or no MTP: 1D buffer for expanded per-token seq_lens.
            self.decode_seq_lens_buffer = torch.zeros(
                (scheduler_config.max_num_batched_tokens,),
                dtype=torch.int32,
                device=self.device,
            )
        # Consumed as arange[:actual_expanded] on the flattening path.
        self.arange_buffer = torch.arange(
            scheduler_config.max_num_seqs * next_n,
            dtype=torch.int32,
            device=self.device,
        )
        max_num_blocks_per_req = cdiv(
            self.vllm_config.model_config.max_model_len,
            self.kv_cache_spec.block_size * get_total_cp_world_size(),
        )
        # One block-table row per flattened decode token (flattening path).
        self.expanded_block_table_buffer = torch.zeros(
            (
                scheduler_config.max_num_batched_tokens,
                max_num_blocks_per_req,
            ),
            dtype=torch.int32,
            device=self.device,
        )

        # See: DeepGMM/csrc/apis/attention.hpp
        self.scheduler_metadata_buffer = torch.empty(
            (self.num_sms + 1, 2), dtype=torch.int32, device=self.device
        )

    def build_one_prefill_chunk(
        self,
        req_slice: slice,
        query_slice: slice,
        query_start_loc_cpu: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        block_table: torch.Tensor,
        skip_kv_gather: bool = False,
    ) -> DeepseekV32IndexerPrefillChunkMetadata:
        """Build metadata for one prefill chunk.

        Args:
            req_slice: request indices of this chunk within the batch.
            query_slice: query-token sub-range within the chunk, used when
                a request-level chunk is further split on the query dim.
            query_start_loc_cpu: [num_reqs+1] cumulative query lengths (CPU).
            seq_lens_cpu: [num_reqs] full sequence lengths (CPU).
            block_table: [num_reqs, max_blocks] block table (device).
            skip_kv_gather: True for query sub-chunks after the first, which
                reuse the KV already gathered for the same request slice.
        """
        # Query start locations rebased to the start of this chunk.
        prefill_query_start_loc = (
            query_start_loc_cpu[req_slice.start : req_slice.stop + 1]
            - query_start_loc_cpu[req_slice.start]
        )
        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
            prefill_query_start_loc, seq_lens_cpu[req_slice], self.device
        )
        token_start = query_start_loc_cpu[req_slice.start].item()
        total_seq_lens = seq_lens_cpu[req_slice].sum()
        num_reqs = req_slice.stop - req_slice.start
        # Map each flattened KV token back to its request within the chunk.
        seq_idx = torch.arange(0, num_reqs, dtype=torch.int32)
        token_to_seq = torch.repeat_interleave(seq_idx, seq_lens_cpu[req_slice]).to(
            self.device
        )
        # The chunker must have respected the flattened-KV workspace budget.
        assert total_seq_lens <= self.max_prefill_buffer_size
        cu_seq_lens = (
            torch.cat(
                [
                    torch.zeros(1, dtype=torch.int32),
                    seq_lens_cpu[req_slice].cumsum(dim=0),
                ]
            )
            .to(torch.int32)
            .to(self.device)
        )

        return DeepseekV32IndexerPrefillChunkMetadata(
            cu_seqlen_ks=cu_seqlen_ks[query_slice],
            cu_seqlen_ke=cu_seqlen_ke[query_slice],
            cu_seq_lens=cu_seq_lens,
            token_to_seq=token_to_seq,
            total_seq_lens=total_seq_lens,
            block_table=block_table[req_slice],
            token_start=token_start + query_slice.start,
            token_end=token_start + query_slice.stop,
            num_reqs=num_reqs,
            skip_kv_gather=skip_kv_gather,
        )

    def _prepare_decode_tensors(
        self,
        seq_lens: torch.Tensor,
        block_table: torch.Tensor,
        decode_lens: torch.Tensor,
        decode_lens_cpu: torch.Tensor,
        query_start_loc: torch.Tensor,
        num_decodes: int,
        num_decode_tokens: int,
        use_native: bool,
        next_n: int,
        max_decode_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, bool]:
        """Expand seq_lens/block_table/decode_lens for the decode kernels.

        Flatten path (not use_native, max_decode_len > 1):
          Each multi-token decode request is expanded into individual
          single-token entries so the kernel always sees next_n=1.

        Native path (use_native or max_decode_len == 1):
          Plain decode or spec-decode with 2D per-token context lengths.

        Returns (seq_lens, block_table, decode_lens, batch_size, requires_padding).
        seq_lens is 1D (batch_size,) for flatten/plain, 2D (B, next_n) for native MTP.
        """
        if not use_native and max_decode_len > 1:
            assert self.decode_seq_lens_buffer.dim() == 1
            # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
            # padding) and decode_lens [3, 1, 4, 0] in the below example comments.
            # The context lengths are therefore
            # [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].

            # 3 + 1 + 4 + 0 = 8
            actual_expanded = int(decode_lens_cpu.sum().item())

            # Fuse expanded_base and expanded_starts into a single repeat_interleave:
            # seq_len_i = (context_start[b] - query_start_loc[b]) + arange[i] + 1
            # where context_start[b] = seq_lens[b] - decode_lens[b].
            # Example: offsets = [7-0, 6-3, 8-4, 0-8] = [7, 3, 4, -8]
            # expanded_offsets  = [7, 7, 7, 3, 4, 4, 4, 4]
            # result            = [8, 9, 10, 7, 9, 10, 11, 12]
            expanded_offsets = torch.repeat_interleave(
                seq_lens - decode_lens - query_start_loc,
                decode_lens,
                output_size=actual_expanded,
            )

            # [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
            self.decode_seq_lens_buffer[:actual_expanded] = (
                expanded_offsets + self.arange_buffer[:actual_expanded] + 1
            )
            # NOTE(review): zeroes the entire buffer tail, not just
            # [actual_expanded:num_decode_tokens] — presumably cheap enough;
            # confirm this is intentional.
            self.decode_seq_lens_buffer[actual_expanded:] = 0
            seq_lens = self.decode_seq_lens_buffer[:num_decode_tokens]

            # Give each of the flattened entries the same block table row as the
            # original request.
            self.expanded_block_table_buffer[:actual_expanded] = (
                torch.repeat_interleave(
                    block_table, decode_lens, dim=0, output_size=actual_expanded
                )
            )
            if actual_expanded < num_decode_tokens:
                self.expanded_block_table_buffer[
                    actual_expanded:num_decode_tokens, 0
                ] = 0
            block_table = self.expanded_block_table_buffer[:num_decode_tokens]

            # All reqs now have decode_len=1
            self.decode_lens_buffer[:num_decode_tokens] = 1
            decode_lens = self.decode_lens_buffer[:num_decode_tokens]
            return seq_lens, block_table, decode_lens, num_decode_tokens, False
        else:
            # Native path: plain decode (next_n==1) or spec decode
            # with 2D per-token context lengths (next_n > 1).
            #
            # When decode_lens are not truly uniform (e.g. some requests have
            # decode_len < next_n due to padding or short prefills), the simple
            # reshape in sparse_attn_indexer won't work. Use pack_seq_triton
            # (requires_padding) instead.
            min_decode_len = int(decode_lens_cpu.min().item())
            requires_padding = min_decode_len != max_decode_len
            if use_native and next_n > 1:
                assert self.decode_seq_lens_buffer.dim() == 2
                # (B, next_n): token j attends to L - next_n + j + 1 KV tokens
                self.decode_seq_lens_buffer[:num_decodes] = (
                    seq_lens.unsqueeze(1) - next_n + 1 + self.offsets_buffer
                )
                seq_lens = self.decode_seq_lens_buffer[:num_decodes]
            return seq_lens, block_table, decode_lens, num_decodes, requires_padding

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> DeepseekV32IndexerMetadata:
        """Build indexer metadata for one scheduler step.

        Splits the (reordered) batch into decodes followed by prefills,
        builds chunked prefill metadata and (flattened or native) decode
        metadata, and assembles the combined ``DeepseekV32IndexerMetadata``.

        Args:
            common_prefix_len: unused here; part of the builder interface.
            common_attn_metadata: per-step batch layout (query_start_loc,
                seq_lens, block table, slot mapping, ...).
            fast_build: unused here; part of the builder interface.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=not self.use_flattening,
            )
        )

        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        prefill_metadata = None
        if num_prefills > 0:
            prefill_query_lens_cpu = torch.diff(
                query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
            )
            max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
            chunk_specs = split_indexer_prefill_chunks(
                common_attn_metadata.seq_lens_cpu[num_decodes:],
                prefill_query_lens_cpu,
                self.max_prefill_buffer_size,
                max_logits_bytes,
                request_offset=num_decodes,
            )
            chunks = [
                self.build_one_prefill_chunk(
                    req_slice,
                    query_slice,
                    query_start_loc_cpu,
                    common_attn_metadata.seq_lens_cpu,
                    common_attn_metadata.block_table_tensor,
                    # Query sub-chunks after the first reuse the KV gathered
                    # for the same request slice.
                    skip_kv_gather=query_slice.start > 0,
                )
                for req_slice, query_slice in chunk_specs
            ]
            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                chunks=chunks,
            )

        decode_metadata = None
        if num_decodes > 0:
            # Per-request decode lengths, written directly into the
            # persistent buffer (CUDA-graph friendly).
            torch.diff(
                common_attn_metadata.query_start_loc[: num_decodes + 1],
                out=self.decode_lens_buffer[:num_decodes],
            )
            decode_lens = self.decode_lens_buffer[:num_decodes]
            decode_lens_cpu = torch.diff(
                common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
            )

            seq_lens = common_attn_metadata.seq_lens[:num_decodes]
            block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]

            max_decode_len = int(decode_lens_cpu.max().item())
            next_n = 1 + self.num_speculative_tokens
            # Native only when flattening is off and the batch is at full next_n.
            use_native = not self.use_flattening and max_decode_len == next_n

            seq_lens, block_table, decode_lens, batch_size, requires_padding = (
                self._prepare_decode_tensors(
                    seq_lens=seq_lens,
                    block_table=block_table,
                    decode_lens=decode_lens,
                    decode_lens_cpu=decode_lens_cpu,
                    query_start_loc=common_attn_metadata.query_start_loc[:num_decodes],
                    num_decodes=num_decodes,
                    num_decode_tokens=num_decode_tokens,
                    use_native=use_native,
                    next_n=next_n,
                    max_decode_len=max_decode_len,
                )
            )

            # DeepGEMM is required for the paged MQA logits on CUDA devices
            if current_platform.is_cuda() and has_deep_gemm():
                self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                    seq_lens,
                    self.kv_cache_spec.block_size,
                    self.num_sms,
                )

            decode_metadata = DeepSeekV32IndexerDecodeMetadata(
                block_table=block_table,
                seq_lens=seq_lens,
                decode_lens=decode_lens,
                requires_padding=requires_padding,
                schedule_metadata=self.scheduler_metadata_buffer,
            )

        attn_metadata = DeepseekV32IndexerMetadata(
            seq_lens=common_attn_metadata.seq_lens,
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            head_dim=128,  # hard-coded indexer head dim
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

        return attn_metadata

_prepare_decode_tensors

_prepare_decode_tensors(
    seq_lens: Tensor,
    block_table: Tensor,
    decode_lens: Tensor,
    decode_lens_cpu: Tensor,
    query_start_loc: Tensor,
    num_decodes: int,
    num_decode_tokens: int,
    use_native: bool,
    next_n: int,
    max_decode_len: int,
) -> tuple[Tensor, Tensor, Tensor, int, bool]

Expand seq_lens/block_table/decode_lens for the decode kernels.

Flatten path (not use_native, max_decode_len > 1): Each multi-token decode request is expanded into individual single-token entries so the kernel always sees next_n=1.

Native path (use_native or max_decode_len == 1): Plain decode or spec-decode with 2D per-token context lengths.

Returns (seq_lens, block_table, decode_lens, batch_size, requires_padding). seq_lens is 1D (batch_size,) for flatten/plain, 2D (B, next_n) for native MTP.

Source code in vllm/v1/attention/backends/mla/indexer.py
def _prepare_decode_tensors(
    self,
    seq_lens: torch.Tensor,
    block_table: torch.Tensor,
    decode_lens: torch.Tensor,
    decode_lens_cpu: torch.Tensor,
    query_start_loc: torch.Tensor,
    num_decodes: int,
    num_decode_tokens: int,
    use_native: bool,
    next_n: int,
    max_decode_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, bool]:
    """Expand seq_lens/block_table/decode_lens for the decode kernels.

    Flatten path (not use_native, max_decode_len > 1):
      Each multi-token decode request is expanded into individual
      single-token entries so the kernel always sees next_n=1.

    Native path (use_native or max_decode_len == 1):
      Plain decode or spec-decode with 2D per-token context lengths.

    Returns (seq_lens, block_table, decode_lens, batch_size, requires_padding).
    seq_lens is 1D (batch_size,) for flatten/plain, 2D (B, next_n) for native MTP.
    """
    if not use_native and max_decode_len > 1:
        assert self.decode_seq_lens_buffer.dim() == 1
        # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
        # padding) and decode_lens [3, 1, 4, 0] in the below example comments.
        # The context lengths are therefore
        # [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].

        # 3 + 1 + 4 + 0 = 8
        actual_expanded = int(decode_lens_cpu.sum().item())

        # Fuse expanded_base and expanded_starts into a single repeat_interleave:
        # seq_len_i = (context_start[b] - query_start_loc[b]) + arange[i] + 1
        # where context_start[b] = seq_lens[b] - decode_lens[b].
        # Example: offsets = [7-0, 6-3, 8-4, 0-8] = [7, 3, 4, -8]
        # expanded_offsets  = [7, 7, 7, 3, 4, 4, 4, 4]
        # result            = [8, 9, 10, 7, 9, 10, 11, 12]
        expanded_offsets = torch.repeat_interleave(
            seq_lens - decode_lens - query_start_loc,
            decode_lens,
            output_size=actual_expanded,
        )

        # [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
        self.decode_seq_lens_buffer[:actual_expanded] = (
            expanded_offsets + self.arange_buffer[:actual_expanded] + 1
        )
        # NOTE(review): zeroes the whole buffer tail rather than just
        # [actual_expanded:num_decode_tokens] — presumably cheap; confirm.
        self.decode_seq_lens_buffer[actual_expanded:] = 0
        seq_lens = self.decode_seq_lens_buffer[:num_decode_tokens]

        # Give each of the flattened entries the same block table row as the
        # original request.
        self.expanded_block_table_buffer[:actual_expanded] = (
            torch.repeat_interleave(
                block_table, decode_lens, dim=0, output_size=actual_expanded
            )
        )
        # Padding entries point their first block at 0 (seq_len 0 above).
        if actual_expanded < num_decode_tokens:
            self.expanded_block_table_buffer[
                actual_expanded:num_decode_tokens, 0
            ] = 0
        block_table = self.expanded_block_table_buffer[:num_decode_tokens]

        # All reqs now have decode_len=1
        self.decode_lens_buffer[:num_decode_tokens] = 1
        decode_lens = self.decode_lens_buffer[:num_decode_tokens]
        return seq_lens, block_table, decode_lens, num_decode_tokens, False
    else:
        # Native path: plain decode (next_n==1) or spec decode
        # with 2D per-token context lengths (next_n > 1).
        #
        # When decode_lens are not truly uniform (e.g. some requests have
        # decode_len < next_n due to padding or short prefills), the simple
        # reshape in sparse_attn_indexer won't work. Use pack_seq_triton
        # (requires_padding) instead.
        min_decode_len = int(decode_lens_cpu.min().item())
        requires_padding = min_decode_len != max_decode_len
        if use_native and next_n > 1:
            assert self.decode_seq_lens_buffer.dim() == 2
            # (B, next_n): token j attends to L - next_n + j + 1 KV tokens
            self.decode_seq_lens_buffer[:num_decodes] = (
                seq_lens.unsqueeze(1) - next_n + 1 + self.offsets_buffer
            )
            seq_lens = self.decode_seq_lens_buffer[:num_decodes]
        return seq_lens, block_table, decode_lens, num_decodes, requires_padding

kv_spans_from_batches

kv_spans_from_batches(
    start_seq_loc: Tensor,
    seq_len_per_batch: Tensor,
    device: device,
) -> tuple[Tensor, Tensor]

Parameters:

Name Type Description Default
start_seq_loc Tensor

1D long tensor [B+1], cumulative counts of selected tokens per batch. Example: [0, 2, 4, 7] -> batch sizes (selected) [2, 2, 3], N=7 tokens total.

required
seq_len_per_batch Tensor

1D long tensor [B], full sequence length (KV length) of each batch. Example: [5, 9, 4].

required

Returns:

Name Type Description
start_tensor Tensor

1D long tensor [N], start offset in the concatenated KV cache for each token's batch.

end_location Tensor

1D long tensor [N], exclusive end = start + token's local position. (So the attended KV slice is kv[start:end].)

Assumes each batch contributes its full seq_len_per_batch[i] keys to the KV cache, and the selected tokens within a batch are the last counts[i] positions of that sequence.

Source code in vllm/v1/attention/backends/mla/indexer.py
def kv_spans_from_batches(
    start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
      start_seq_loc: 1D long tensor [B+1], cumulative counts of
                     selected tokens per batch.
            Example: [0, 2, 4, 7] ->
                     batch sizes (selected) [2, 2, 3], N=7 tokens total.
      seq_len_per_batch: 1D long tensor [B],
                         full sequence length (KV length) of each batch.
                         Example: [5, 9, 4].

    Returns:
      start_tensor: 1D long tensor [N], start offset in the
                    concatenated KV cache for each token's batch.
      end_location: 1D long tensor [N],
                    **exclusive** end = start + token's local position.
                    (So the attended KV slice is kv[start:end].)

    Assumes each batch contributes its full `seq_len_per_batch[i]`
    keys to the KV cache, and the selected tokens within a batch
    are the **last** `counts[i]` positions of that sequence.
    """
    cum = start_seq_loc.to(dtype=torch.long)
    kv_lens = seq_len_per_batch.to(dtype=torch.long)
    assert cum.dim() == 1 and kv_lens.dim() == 1
    assert cum.numel() == kv_lens.numel() + 1, "start_seq_loc must have length B+1"

    # Selected tokens per batch, and total token count.
    counts = torch.diff(cum)  # [B]
    total = int(cum[-1].item())
    num_batches = kv_lens.numel()

    if total == 0:
        return (
            torch.empty(0, dtype=torch.long, device=device),
            torch.empty(0, dtype=torch.long, device=device),
        )

    # Exclusive prefix sum: where each batch begins in the concatenated KV.
    batch_kv_start = torch.cumsum(kv_lens, dim=0) - kv_lens  # [B]

    # Batch index that owns each selected token.
    owner = torch.repeat_interleave(torch.arange(num_batches), counts)  # [N]

    # Broadcast each batch's KV start to its tokens.
    start_tensor = batch_kv_start[owner]  # [N]

    # 1-based position of each token within its batch's selected block
    # (1..counts[b]).
    pos_in_block = (
        torch.arange(total, dtype=torch.long)
        - torch.repeat_interleave(cum[:-1], counts)
        + 1
    )

    # End-align: the selected tokens occupy the last counts[b] slots, so
    # local_pos = (kv_lens[b] - counts[b]) + pos_in_block.
    local_pos = torch.repeat_interleave(kv_lens - counts, counts) + pos_in_block
    end_location = start_tensor + local_pos  # exclusive end

    return start_tensor.int().to(device), end_location.int().to(device)

split_indexer_prefill_chunks

split_indexer_prefill_chunks(
    seq_lens_cpu: Tensor,
    query_lens_cpu: Tensor,
    workspace_size: int,
    max_logits_bytes: int,
    request_offset: int = 0,
) -> list[tuple[slice, slice]]

Split prefill requests into chunks for the sparse indexer, respecting: - N constraint: total_seq_lens <= workspace_size (existing O(N) workspace) - Logits constraint: M * N * 4 <= max_logits_bytes

When a single request-level chunk still exceeds the logits budget, sub-chunks on the query dimension (M) to bound peak memory.

Returns list of (req_slice, query_slice) tuples.

Source code in vllm/v1/attention/backends/mla/indexer.py
def split_indexer_prefill_chunks(
    seq_lens_cpu: torch.Tensor,
    query_lens_cpu: torch.Tensor,
    workspace_size: int,
    max_logits_bytes: int,
    request_offset: int = 0,
) -> list[tuple[slice, slice]]:
    """
    Split prefill requests into chunks for the sparse indexer, respecting:
    - N constraint: total_seq_lens <= workspace_size (existing O(N) workspace)
    - Logits constraint: M * N * 4 <= max_logits_bytes

    When a single request-level chunk still exceeds the logits budget,
    sub-chunks on the query dimension (M) to bound peak memory.

    Returns list of (req_slice, query_slice) tuples.
    """
    out: list[tuple[slice, slice]] = []
    num_reqs = len(seq_lens_cpu)
    logits_budget = max_logits_bytes // 4  # fp32 logits elements
    cursor = 0

    while cursor < num_reqs:
        first = cursor
        total_q = 0
        total_kv = 0

        # Greedily absorb whole requests while both budgets hold.
        while cursor < num_reqs:
            q_len = query_lens_cpu[cursor].item()
            kv_len = seq_lens_cpu[cursor].item()
            cand_q = total_q + q_len
            cand_kv = total_kv + kv_len
            if cand_kv > workspace_size or cand_q * cand_kv > logits_budget:
                break
            total_q, total_kv = cand_q, cand_kv
            cursor += 1

        # A lone request can exceed the budget on its own; take it anyway
        # and rely on query-dimension sub-chunking below.
        if cursor == first:
            total_q = query_lens_cpu[cursor].item()
            total_kv = seq_lens_cpu[cursor].item()
            cursor += 1

        req_slice = slice(first + request_offset, cursor + request_offset)
        q_step = max(1, logits_budget // total_kv) if total_kv > 0 else total_q
        for q_start in range(0, total_q, q_step):
            q_stop = min(q_start + q_step, total_q)
            out.append((req_slice, slice(q_start, q_stop)))

    return out