Bases: NvFp4LinearKernel
NVFP4 weight-only GEMM via Marlin (W4A16).
Source code in vllm/model_executor/kernels/linear/nvfp4/marlin.py
| class MarlinNvFp4LinearKernel(NvFp4LinearKernel):
"""NVFP4 weight-only GEMM via Marlin (W4A16)."""
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if is_fp4_marlin_supported():
return True, None
return False, "Marlin FP4 not available"
@classmethod
def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
logger.warning_once(
"Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression "
"will be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
prepare_fp4_layer_for_marlin(layer)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return apply_fp4_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_global_scale=layer.weight_global_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
|