Skip to content

vllm.v1.worker.gpu.spec_decode.eagle.cudagraph

EagleCudaGraphManager

Bases: CudaGraphManager

CudaGraphManager for Eagle speculative decoding.

Source code in vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
class EagleCudaGraphManager(CudaGraphManager):
    """CudaGraphManager specialized for Eagle speculative decoding."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        cudagraph_mode: CUDAGraphMode,
        decode_query_len: int,
    ):
        """Initialize the manager and, when capture is enabled, switch to a
        private CUDA graph memory pool.

        Eagle gets its own pool instead of the shared global one set up by the
        base class: Eagle-internal temporaries (e.g. gumbel_sample scratch
        allocations) could otherwise collide with the main model's allocations
        when both capture into the same pool.
        """
        super().__init__(vllm_config, device, cudagraph_mode, decode_query_len)

        if cudagraph_mode:
            self.pool = torch.cuda.graph_pool_handle()

    def capture(
        self,
        forward_fn: Callable,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture CUDA graphs for Eagle.

        Builds a factory that, given a batch descriptor, prepares capture-time
        inputs and returns a zero-setup runner closure; the base class drives
        the actual capture loop with that factory.
        """

        def _build_runner(
            desc: BatchExecutionDescriptor,
        ) -> Callable[[CUDAGraphMode], None]:
            token_count = desc.num_tokens
            # Fall back to one request per token, clamped to the batch limit,
            # when the descriptor does not pin a request count.
            req_count = desc.num_reqs or min(token_count, self.max_num_reqs)

            # Under data parallelism every rank captures with the same token
            # count, so the per-rank tensor is uniform; otherwise it is unused.
            if self.dp_size > 1:
                tokens_per_dp_rank = torch.full(
                    (self.dp_size,), token_count, dtype=torch.int32, device="cpu"
                )
            else:
                tokens_per_dp_rank = None

            metadata, slots = prepare_inputs_to_capture(
                req_count,
                token_count,
                model_state,
                input_buffers,
                block_tables,
                attn_groups,
                kv_cache_config,
            )

            def _run(cg_mode: CUDAGraphMode):
                # Replay the forward pass with the pre-built capture inputs.
                return forward_fn(
                    req_count,
                    token_count,
                    metadata,
                    slots,
                    tokens_per_dp_rank,
                    cg_mode,
                )

            return _run

        super().capture(_build_runner, progress_bar_desc)

capture

capture(
    forward_fn: Callable,
    model_state: ModelState,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    kv_cache_config: KVCacheConfig,
    progress_bar_desc: str = "Capturing CUDA graphs",
) -> None

Capture CUDA graphs for Eagle.

Source code in vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
def capture(
    self,
    forward_fn: Callable,
    model_state: ModelState,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    kv_cache_config: KVCacheConfig,
    progress_bar_desc: str = "Capturing CUDA graphs",
) -> None:
    """Capture CUDA graphs for Eagle."""

    # Factory handed to the base-class capture loop: for each batch
    # descriptor it prepares capture-time inputs once, then returns a
    # closure that runs the forward pass with those fixed inputs.
    def create_forward_fn(
        desc: BatchExecutionDescriptor,
    ) -> Callable[[CUDAGraphMode], None]:
        num_tokens = desc.num_tokens
        # If the descriptor does not specify a request count, assume one
        # request per token, capped at the maximum batch size.
        num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
        # Uniform per-rank token counts for data-parallel capture; None
        # when there is only a single DP rank.
        num_tokens_across_dp = (
            torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
            if self.dp_size > 1
            else None
        )
        # Build attention metadata and KV-cache slot mappings sized for
        # this (num_reqs, num_tokens) capture shape.
        attn_metadata, slot_mappings = prepare_inputs_to_capture(
            num_reqs,
            num_tokens,
            model_state,
            input_buffers,
            block_tables,
            attn_groups,
            kv_cache_config,
        )

        # The returned closure takes only the graph mode; all other forward
        # arguments are already bound to the prepared capture inputs.
        return lambda cg_mode: forward_fn(
            num_reqs,
            num_tokens,
            attn_metadata,
            slot_mappings,
            num_tokens_across_dp,
            cg_mode,
        )

    super().capture(create_forward_fn, progress_bar_desc)