Bases: Fp8BlockScaledMMLinearKernel, ABC
Dynamic FP8 block-scaled kernel that dispatches at runtime.
Extends Fp8BlockScaledMMLinearKernel to inherit apply_weights and overrides apply_block_scaled_mm to dispatch between two sub-kernels using torch.cond.
Subclasses must define:
base_type: The primary kernel class.
fallback_type: The fallback kernel class.
Source code in vllm/model_executor/kernels/linear/scaled_mm/BlockScaledMMLinearKernel.py
class Fp8BlockScaledDynamicMMLinearKernel(Fp8BlockScaledMMLinearKernel, ABC):
    """Dynamic FP8 block-scaled kernel that dispatches at runtime.

    Extends Fp8BlockScaledMMLinearKernel to inherit apply_weights and overrides
    apply_block_scaled_mm to dispatch between two sub-kernels using torch.cond.

    Subclasses must define:
        base_type: The primary kernel class.
        fallback_type: The fallback kernel class.

    Both sub-kernels are instantiated in ``__init__``. Accordingly, this kernel
    only reports itself as supported / implementable when *both* sub-kernels
    do.
    """

    base_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]]
    fallback_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]]

    def __init__(self, config: "FP8ScaledMMLinearLayerConfig") -> None:
        super().__init__(config)
        # Both sub-kernels are built from the same layer config so either one
        # can serve the layer at runtime.
        self.base = self.base_type(config)
        self.fallback = self.fallback_type(config)

    @staticmethod
    def _combine_checks(
        base_result: tuple[bool, str | None],
        fallback_result: tuple[bool, str | None],
        failure_phrase: str,
    ) -> tuple[bool, str | None]:
        """Merge the ``(ok, reason)`` results of the base and fallback checks.

        Args:
            base_result: ``(ok, reason)`` returned by the base kernel's check.
            fallback_result: ``(ok, reason)`` returned by the fallback
                kernel's check.
            failure_phrase: Verb phrase placed in each failure message, e.g.
                ``"is not supported"`` or ``"cannot implement"``.

        Returns:
            ``(True, None)`` if both checks passed. Otherwise ``(False, msg)``,
            where ``msg`` names every failing sub-kernel with its reason,
            base first, joined by ``"; "``.
        """
        base_ok, base_reason = base_result
        fallback_ok, fallback_reason = fallback_result

        failures = []
        if not base_ok:
            failures.append(f"base {failure_phrase} due to {base_reason}")
        if not fallback_ok:
            failures.append(f"fallback {failure_phrase} due to {fallback_reason}")

        if not failures:
            return True, None
        return False, "; ".join(failures)

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether both sub-kernels support ``compute_capability``.

        Args:
            compute_capability: Forwarded unchanged to
                ``base_type.is_supported`` and ``fallback_type.is_supported``.

        Returns:
            ``(True, None)`` when both sub-kernels are supported. Otherwise
            ``(False, reason)``, naming each unsupported sub-kernel.
        """
        # Both checks always run (base first) so a combined reason can be
        # reported when neither sub-kernel is usable.
        return cls._combine_checks(
            cls.base_type.is_supported(compute_capability),
            cls.fallback_type.is_supported(compute_capability),
            "is not supported",
        )

    @classmethod
    def can_implement(
        cls, config: "FP8ScaledMMLinearLayerConfig"
    ) -> tuple[bool, str | None]:
        """Report whether both sub-kernels can implement the layer ``config``.

        Args:
            config: Forwarded unchanged to ``base_type.can_implement`` and
                ``fallback_type.can_implement``.

        Returns:
            ``(True, None)`` when both sub-kernels can implement ``config``.
            Otherwise ``(False, reason)``, naming each sub-kernel that
            cannot.
        """
        return cls._combine_checks(
            cls.base_type.can_implement(config),
            cls.fallback_type.can_implement(config),
            "cannot implement",
        )
|