class EmulationMxfp8LinearKernel(Mxfp8LinearKernel):
    """Software emulation fallback for MXFP8.

    Dequantizes the MXFP8 weight to BF16 and runs the matmul with
    torch.nn.functional.linear, so it is supported on any device at the
    cost of materializing a BF16 copy of the weight on every forward pass.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        # Pure-software emulation has no hardware requirements.
        return True, None

    @classmethod
    def can_implement(cls, c: Mxfp8LinearLayerConfig) -> tuple[bool, str | None]:
        # Any MXFP8 linear configuration can be served by emulation.
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        weight = layer.weight.data  # [N, K]
        N, K = weight.shape
        # One shared scale per MXFP8_BLOCK_SIZE-element block along K; trim the
        # loaded scale tensor (which may be padded) to its logical shape.
        scale_k = K // MXFP8_BLOCK_SIZE
        weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
        layer.weight = Parameter(weight.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        weight_scale = layer.weight_scale
        if weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"Emulation backend requires {MXFP8_SCALE_DTYPE} "
                f"weight_scale dtype, got {weight_scale.dtype}."
            )
        if weight_scale.ndim != 2:
            raise ValueError(
                f"Emulation backend requires 2D weight_scale, "
                f"got {weight_scale.ndim}D. "
                f"Ensure process_weights_after_loading was called."
            )
        # Dequantize the weight to BF16, cast the activation to match so the
        # GEMM runs in one dtype, then cast the result back to the activation
        # dtype.
        weight_bf16 = dequant_mxfp8_to_bf16(layer.weight, weight_scale)
        output = torch.nn.functional.linear(
            x.to(weight_bf16.dtype), weight_bf16, bias
        )
        return output.to(x.dtype)
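

# The sketch below is illustrative, not part of the kernel: it shows what
# dequant_mxfp8_to_bf16 is expected to compute under the usual OCP MX
# conventions (FP8 elements sharing one power-of-two E8M0 scale per
# MXFP8_BLOCK_SIZE elements along K). The helper name, the float8_e4m3fn
# element dtype, and the uint8 view of the scale exponents are assumptions
# made for illustration; the actual MXFP8_SCALE_DTYPE may differ.
def _reference_dequant_mxfp8_to_bf16(
    weight_fp8: torch.Tensor,  # [N, K], assumed torch.float8_e4m3fn elements
    weight_scale: torch.Tensor,  # [N, K // MXFP8_BLOCK_SIZE], E8M0 exponents
) -> torch.Tensor:
    # Recover the power-of-two scale from the biased (bias 127) E8M0 exponent.
    exponents = weight_scale.view(torch.uint8).to(torch.int32) - 127
    scales = torch.exp2(exponents.to(torch.float32))
    # Broadcast each block scale across its MXFP8_BLOCK_SIZE elements along K.
    scales = scales.repeat_interleave(MXFP8_BLOCK_SIZE, dim=-1)  # [N, K]
    return (weight_fp8.to(torch.float32) * scales).to(torch.bfloat16)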