Skip to content

vllm.model_executor.kernels.linear.scaled_mm.BlockScaledMMLinearKernel

Fp8BlockScaledDynamicMMLinearKernel

Bases: Fp8BlockScaledMMLinearKernel, ABC

Dynamic FP8 block-scaled kernel that dispatches at runtime.

Extends Fp8BlockScaledMMLinearKernel to inherit apply_weights and overrides apply_block_scaled_mm to dispatch between two sub-kernels using torch.cond.

Subclasses must define:

- base_type: The primary kernel class.
- fallback_type: The fallback kernel class.

Source code in vllm/model_executor/kernels/linear/scaled_mm/BlockScaledMMLinearKernel.py
class Fp8BlockScaledDynamicMMLinearKernel(Fp8BlockScaledMMLinearKernel, ABC):
    """Dynamic FP8 block-scaled kernel that dispatches at runtime.

    Extends Fp8BlockScaledMMLinearKernel to inherit apply_weights and overrides
    apply_block_scaled_mm to dispatch between two sub-kernels using torch.cond.

    Subclasses must define:
        base_type:     The primary kernel class.
        fallback_type: The fallback kernel class.
    """

    base_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]]
    fallback_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]]

    def __init__(self, config: "FP8ScaledMMLinearLayerConfig") -> None:
        """Initialise the dispatcher and both of its sub-kernels.

        Args:
            config: Layer configuration. It is passed unchanged to the parent
                initializer and to the constructors of ``base_type`` and
                ``fallback_type``.
        """
        super().__init__(config)
        # The primary and fallback kernels are built from the same config;
        # the primary one is constructed first.
        self.base, self.fallback = (
            self.base_type(config),
            self.fallback_type(config),
        )

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Check that both the base and the fallback kernels are supported.

        Args:
            compute_capability: Passed unchanged to ``is_supported`` on both
                ``base_type`` and ``fallback_type``.

        Returns:
            ``(True, None)`` when both sub-kernels report support. Otherwise
            ``(False, message)``, where ``message`` names each sub-kernel that
            is unsupported together with the reason it reported. When both
            fail, the two reasons are separated by ``"; "``.
        """
        base_ok, base_reason = cls.base_type.is_supported(compute_capability)
        fallback_ok, fallback_reason = cls.fallback_type.is_supported(
            compute_capability
        )

        # Collect one message per failing sub-kernel, base first, so the
        # combined reason lists every failure in a fixed order.
        failures: list[str] = []
        if not base_ok:
            failures.append(f"base is not supported due to {base_reason}")
        if not fallback_ok:
            failures.append(
                f"fallback is not supported due to {fallback_reason}"
            )

        if not failures:
            return True, None
        return False, "; ".join(failures)

    @classmethod
    def can_implement(
        cls, config: "FP8ScaledMMLinearLayerConfig"
    ) -> tuple[bool, str | None]:
        """Check that both the base and the fallback kernels can handle ``config``.

        Args:
            config: Layer configuration. It is passed unchanged to
                ``can_implement`` on both ``base_type`` and ``fallback_type``.

        Returns:
            ``(True, None)`` when both sub-kernels accept ``config``. Otherwise
            ``(False, message)``, where ``message`` names each sub-kernel that
            cannot implement it together with the reason it reported. When both
            fail, the two reasons are separated by ``"; "``.
        """
        base_ok, base_reason = cls.base_type.can_implement(config)
        fallback_ok, fallback_reason = cls.fallback_type.can_implement(config)

        # Collect one message per failing sub-kernel, base first, so the
        # combined reason lists every failure in a fixed order.
        failures: list[str] = []
        if not base_ok:
            failures.append(f"base cannot implement due to {base_reason}")
        if not fallback_ok:
            failures.append(
                f"fallback cannot implement due to {fallback_reason}"
            )

        if not failures:
            return True, None
        return False, "; ".join(failures)