
vllm.model_executor.kernels.linear.scaled_mm.flashinfer

FlashInferFp8DeepGEMMDynamicBlockScaledKernel

Bases: Fp8BlockScaledDynamicMMLinearKernel

Conditional FlashInfer / DeepGEMM FP8 block-scaled GEMM.

Dispatches between two kernels based on input batch size:

- Small batches (M < 32): FlashInfer's swapAB trick for better utilisation.
- Large batches (M >= 32): DeepGEMM for peak throughput.

apply_input_quant is False because FlashInfer accepts BF16 input and handles FP8 conversion internally. The DeepGEMM branch therefore quantises BF16→FP8 inside apply_mm via a closure before dispatching to the DeepGEMM kernel — keeping both branches compatible with the single BF16 tensor operand list passed by torch.cond.
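The dispatch rule described above can be sketched in plain Python. This is a simplified, hypothetical `select_kernel` helper, not the actual vLLM API: the real code routes through `torch.cond`, and the batch-invariant check corresponds to `envs.VLLM_BATCH_INVARIANT` in the implementation below.

```python
def select_kernel(m: int, batch_invariant: bool = False) -> str:
    """Illustrative dispatch rule only; not part of vLLM's public API."""
    if batch_invariant:
        # Batch-invariant mode always takes the DeepGEMM path.
        return "deepgemm"
    # Small batches benefit from FlashInfer's swapAB trick; larger batches
    # get peak throughput from DeepGEMM.
    return "flashinfer_swapab" if m < 32 else "deepgemm"
```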

Source code in vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py
class FlashInferFp8DeepGEMMDynamicBlockScaledKernel(
    Fp8BlockScaledDynamicMMLinearKernel
):
    """
    Conditional FlashInfer / DeepGEMM FP8 block-scaled GEMM.

    Dispatches between two kernels based on input batch size:
    - Small batches (M < 32): FlashInfer's swapAB trick for better utilisation.
    - Large batches (M >= 32): DeepGEMM for peak throughput.

    apply_input_quant is False because FlashInfer accepts BF16 input and
    handles FP8 conversion internally.  The DeepGEMM branch therefore
    quantises BF16→FP8 inside apply_mm via a closure before dispatching to
    the DeepGEMM kernel — keeping both branches compatible with the single
    BF16 tensor operand list passed by torch.cond.
    """

    base_type: ClassVar[type[FlashInferFp8BlockScaledMMKernel]] = (
        FlashInferFp8BlockScaledMMKernel
    )
    fallback_type: ClassVar[type[DeepGemmFp8BlockScaledMMKernel]] = (
        DeepGemmFp8BlockScaledMMKernel
    )
    apply_input_quant: ClassVar[bool] = False

    def __init__(self, config: FP8ScaledMMLinearLayerConfig):
        super().__init__(config)
        self.base: FlashInferFp8BlockScaledMMKernel
        self.fallback: DeepGemmFp8BlockScaledMMKernel

    def process_weights_after_loading(self, layer: torch.nn.Module):
        # DeepGEMM needs post-processing; both kernels share the same
        # parameter tensor layout, so processing once is sufficient.
        self.fallback.process_weights_after_loading(layer)

    def apply_block_scaled_mm(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
    ) -> torch.Tensor:
        group_size = self.weight_group_shape.col
        use_deep_gemm_e8m0 = self.fallback.use_deep_gemm_e8m0

        return torch.ops.vllm.dynamic_flashinfer_deepgemm_blockscale_gemm(
            A, B, Bs, group_size, use_deep_gemm_e8m0
        )

_dynamic_flashinfer_deepgemm_blockscale_gemm_fake

_dynamic_flashinfer_deepgemm_blockscale_gemm_fake(
    input: Tensor,
    weight: Tensor,
    weight_scale: Tensor,
    group_size: int,
    use_deep_gemm_e8m0: bool,
) -> Tensor

Required fake/meta implementation for torch.compile graph tracing.

Source code in vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py
def _dynamic_flashinfer_deepgemm_blockscale_gemm_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    group_size: int,
    use_deep_gemm_e8m0: bool,
) -> torch.Tensor:
    """
    Required fake/meta implementation for torch.compile graph tracing.
    """
    return torch.empty(
        input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
    )
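A fake/meta implementation only needs to report the output's shape and dtype, without touching real data. As a sketch, the shape rule for an (M, K) input against an (N, K) weight (weights are stored as `(output_dim, input_dim)`) is:

```python
def gemm_out_shape(input_shape: tuple, weight_shape: tuple) -> tuple:
    # (M, K) x (N, K)^T -> (M, N): batch rows come from the input,
    # output columns from the weight's leading (output_dim) dimension.
    m, k = input_shape
    n, k_w = weight_shape
    assert k == k_w, "inner dimensions must match"
    return (m, n)
```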

_dynamic_flashinfer_deepgemm_blockscale_gemm_impl

_dynamic_flashinfer_deepgemm_blockscale_gemm_impl(
    input: Tensor,
    weight: Tensor,
    weight_scale: Tensor,
    group_size: int,
    use_deep_gemm_e8m0: bool,
) -> Tensor

Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection.

This function switches between two optimized kernels based on the input batch size:

- For small batches (M < 32): Uses FlashInfer's DeepGEMM swapAB optimization.
- For larger batches (M >= 32): Uses the official DeepGEMM kernel.

The conditional logic must use torch.cond() instead of a simple if-else statement to maintain compatibility with torch.compile graph compilation.

This batch-size-dependent selection is essential for maintaining model accuracy. GSM8K benchmarks of DeepSeek-V3.1 show a significant accuracy gap (88% vs 95%) when FlashInfer's DeepGEMM is used for M >= 32; restricting it to M < 32 restores the expected accuracy.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input | Tensor | Input tensor of shape (batch_size, input_dim) in BF16 format | required |
| weight | Tensor | Weight tensor of shape (output_dim, input_dim) in FP8 format | required |
| weight_scale | Tensor | Scale factors for weight quantization (per-group) | required |
| group_size | int | Quantization group size for the weight tensor | required |
| use_deep_gemm_e8m0 | bool | Whether to use the E8M0 format in DeepGEMM quantization | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor of shape (batch_size, output_dim) in bfloat16 format |

Source code in vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py
def _dynamic_flashinfer_deepgemm_blockscale_gemm_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    group_size: int,
    use_deep_gemm_e8m0: bool,
) -> torch.Tensor:
    """
    Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection.

    This function switches between two optimized kernels based on the input batch size:
    - For small batches (M < 32): Uses FlashInfer's DeepGEMM swapAB optimization.
    - For larger batches (M >= 32): Uses the official DeepGEMM kernel.

    The conditional logic must use torch.cond() instead of a simple if-else statement
    to maintain compatibility with torch.compile graph compilation.

    This batch-size-dependent selection is essential for maintaining model
    accuracy. GSM8K benchmarks of DeepSeek-V3.1 show a significant accuracy
    gap (88% vs 95%) when FlashInfer's DeepGEMM is used for M >= 32;
    restricting it to M < 32 restores the expected accuracy.

    Args:
        input: Input tensor of shape (batch_size, input_dim) in BF16 format
        weight: Weight tensor of shape (output_dim, input_dim) in FP8 format
        weight_scale: Scale factors for weight quantization (per-group)
        group_size: Quantization group size for the weight tensor
        use_deep_gemm_e8m0: Whether to use the E8M0 format in DeepGEMM quantization

    Returns:
        Output tensor of shape (batch_size, output_dim) in bfloat16 format
    """

    def run_flashinfer_deepgemm_swapAB(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
    ) -> torch.Tensor:
        return flashinfer_fp8_blockscale_gemm(
            input=input,
            weight=weight,
            weight_scale=weight_scale,
            out_dtype=torch.bfloat16,
        )

    def run_deepgemm(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
    ) -> torch.Tensor:
        q_input, input_scale = per_token_group_quant_fp8(
            input,
            group_size=group_size,
            column_major_scales=True,
            use_ue8m0=use_deep_gemm_e8m0,
        )
        output = torch.empty(
            (q_input.shape[0], weight.shape[0]),
            dtype=torch.bfloat16,
            device=q_input.device,
        )
        fp8_gemm_nt(
            (q_input, input_scale),
            (weight, weight_scale),
            output,
            is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
        )
        return output

    if envs.VLLM_BATCH_INVARIANT:
        return run_deepgemm(input, weight, weight_scale)

    condition = input.shape[0] < 32

    # torch.compile cannot trace input-dependent control flow written as a
    # plain Python if/else. torch.cond() explicitly registers both branches
    # in the computation graph, so the M < 32 dispatch remains capturable
    # by torch.compile.
    return torch.cond(
        condition,
        run_flashinfer_deepgemm_swapAB,
        run_deepgemm,
        (input, weight, weight_scale),
    )
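As a rough illustration of the activation-scale shapes involved in the DeepGEMM branch (a sketch assuming K divides evenly into groups; the real `per_token_group_quant_fp8` additionally handles scale storage layout such as the column-major option):

```python
def quant_scale_shape(m: int, k: int, group_size: int) -> tuple:
    # Per-token-group quantization emits one scale per (token, group),
    # so an (M, K) activation with group size g yields (M, K // g) scales.
    assert k % group_size == 0, "K must be divisible by the group size"
    return (m, k // group_size)
```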

_flashinfer_fp8_blockscale_gemm_fake

_flashinfer_fp8_blockscale_gemm_fake(
    input: Tensor, weight: Tensor, weight_scale: Tensor
) -> Tensor

Required fake/meta implementation for torch.compile graph tracing.

Source code in vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py
def _flashinfer_fp8_blockscale_gemm_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> torch.Tensor:
    """
    Required fake/meta implementation for torch.compile graph tracing.
    """
    return torch.empty(
        input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
    )