class FlashInferCutlassNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 linear kernel that runs its GEMM through FlashInfer's CUTLASS path."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether this kernel can be used on the current platform.

        Three checks are made, in order, and the first failing one
        short-circuits the rest: ``cutlass_fp4_supported()``,
        ``current_platform.has_device_capability(100)`` and
        ``has_flashinfer()``.

        Args:
            compute_capability: Accepted for interface compatibility; this
                implementation does not read it and queries
                ``current_platform`` instead.

        Returns:
            ``(True, None)`` when every check passes, otherwise
            ``(False, reason)`` with a human-readable reason string.
        """
        # Imported lazily inside the method, as in the original code.
        from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
            cutlass_fp4_supported,
        )

        unsupported = (False, "FlashInfer + >=sm_100 required")
        if not cutlass_fp4_supported():
            return unsupported
        if not current_platform.has_device_capability(100):
            return unsupported
        if not has_flashinfer():
            return unsupported
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every layer config: always returns ``(True, None)``."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Rewrite the layer's loaded weight tensors for the CUTLASS GEMM.

        Replaces ``layer.weight_scale`` with the result of
        ``swizzle_blockscale`` and ``layer.weight`` with the result of
        ``pad_nvfp4_weight_for_cutlass`` (both as non-trainable parameters),
        and records the padding amount on ``layer.weights_padding_cols`` so
        that ``apply_weights`` can pad activations to match.
        """
        swizzled_scale = swizzle_blockscale(layer.weight_scale.data)
        layer.weight_scale = torch.nn.Parameter(swizzled_scale, requires_grad=False)

        weight_padded, pad_cols = pad_nvfp4_weight_for_cutlass(layer.weight.data)
        layer.weight = torch.nn.Parameter(weight_padded, requires_grad=False)
        layer.weights_padding_cols = pad_cols

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to FP4 and multiply it with the layer's weight.

        Args:
            layer: Module providing ``weight``, ``weight_scale``, ``alpha``,
                ``input_global_scale_inv`` and ``output_size_per_partition``
                (and optionally ``weights_padding_cols``).
            x: Input activations; the result keeps ``x``'s dtype and all of
                its leading dimensions.
            bias: Optional bias added to the GEMM result.

        Returns:
            Tensor viewed as ``(*x.shape[:-1], layer.output_size_per_partition)``.
        """
        n_out = layer.output_size_per_partition
        result_dtype = x.dtype
        leading_dims = x.shape[:-1]

        act_fp4, act_scale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="flashinfer-cutlass",
        )

        # Falls back to 0 when process_weights_after_loading has not set
        # the attribute on this layer.
        pad_cols = getattr(layer, "weights_padding_cols", 0)
        act_fp4 = pad_nvfp4_activation_for_cutlass(act_fp4, pad_cols)

        gemm_out = flashinfer_scaled_fp4_mm(
            act_fp4,
            layer.weight,
            act_scale,
            layer.weight_scale,
            layer.alpha,
            result_dtype,
            backend="cutlass",
        )
        result = slice_nvfp4_output(gemm_out, n_out)

        if bias is not None:
            # Out-of-place add, matching the original `out = out + bias`.
            result = result + bias
        return result.view(*leading_dims, n_out)