Skip to content

vllm.compilation.passes.fusion.mla_attn_quant_fusion

MLAAttnFp8StaticQuantPattern

Bases: VllmPatternReplacement[..., Tensor]

Fusion for MLA Attention+Fp8StaticQuant.

Matches the pattern: MLA attention -> static FP8 quant, and replaces it with MLA attention(output_scale=scale, output=fp8_buffer).

Source code in vllm/compilation/passes/fusion/mla_attn_quant_fusion.py
class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
    """
    Fusion for MLA Attention+Fp8StaticQuant.

    Matches the pattern: MLA attention -> static FP8 quant, and replaces
    it with MLA attention(output_scale=scale, output=fp8_buffer).
    """

    def __init__(self, layer: MLAAttention, dtype: torch.dtype) -> None:
        """Capture the layer geometry needed to trace and rewrite graphs.

        Args:
            layer: The MLA attention layer this pattern is built for; only
                its name and dimension attributes are retained.
            dtype: Dtype of the unquantized attention inputs/outputs used
                when building example tracing inputs.
        """
        self._layer_name = layer.layer_name
        self._num_heads = layer.num_heads
        self._v_head_dim = layer.v_head_dim
        self._kv_lora_rank = layer.kv_lora_rank
        self._qk_rope_head_dim = layer.qk_rope_head_dim
        # Full per-head Q/K dim is the non-rope part plus the rope part.
        self._qk_head_dim = layer.qk_nope_head_dim + layer.qk_rope_head_dim
        # Width of the flattened 2D attention output: num_heads * v_head_dim.
        self._output_dim = layer.num_heads * layer.v_head_dim
        self._dtype = dtype
        # Matches the static per-tensor symmetric FP8 quant op sequence.
        self._quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    @property
    def pattern(self) -> Callable[..., torch.Tensor]:
        """Callable traced by the pattern matcher: MLA attention -> FP8 quant."""
        _ln = _encode_layer_name(self._layer_name)

        if _USE_LAYERNAME:

            # Variant where the layer name is itself a pattern input
            # (wildcard), so one registration can match every MLA layer.
            def _pattern_with_ln(  # type: ignore[misc]
                q,
                kv_c_normed,
                k_pe,
                output_attn,
                scale,
                kv_cache_dummy_dep,
                layer_name,
            ):
                # Unfused form: attention runs without an output scale...
                at1 = auto_functionalized(
                    MLA_ATTN_OP,
                    q=q,
                    kv_c_normed=kv_c_normed,
                    k_pe=k_pe,
                    output=output_attn,
                    layer_name=layer_name,
                    output_scale=None,
                    output_block_scale=None,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
                # MLA output is already 2D (T, N*V), no reshape needed
                # ...followed by a separate static FP8 quant of its output.
                return self._quant_matcher(at1[1], scale)[0]

            return _pattern_with_ln

        # Variant with the layer name baked in as a constant; requires one
        # pattern registration per layer.
        def _pattern(q, kv_c_normed, k_pe, output_attn, scale, kv_cache_dummy_dep):
            at1 = auto_functionalized(
                MLA_ATTN_OP,
                q=q,
                kv_c_normed=kv_c_normed,
                k_pe=k_pe,
                output=output_attn,
                layer_name=_ln,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            # MLA output is already 2D (T, N*V), no reshape needed
            return self._quant_matcher(at1[1], scale)[0]

        return _pattern

    @property
    def replacement(self) -> Callable[..., torch.Tensor]:
        """Fused form: attention writes FP8 directly via output_scale."""
        _ln = _encode_layer_name(self._layer_name)

        if _USE_LAYERNAME:

            # Wildcard-layer-name counterpart of _pattern_with_ln above.
            def _replacement_with_ln(  # type: ignore[misc]
                q,
                kv_c_normed,
                k_pe,
                output_attn,
                scale,
                kv_cache_dummy_dep,
                layer_name,
            ):
                # MLA output in quant_dtype
                # Fresh FP8 buffer replaces the original (wider-dtype)
                # output buffer; the separate quant op disappears.
                output_attn = torch.empty(
                    [q.shape[0], self._output_dim],
                    dtype=FP8_DTYPE,
                    device=q.device,
                )
                at1 = auto_functionalized(
                    MLA_ATTN_OP,
                    q=q,
                    kv_c_normed=kv_c_normed,
                    k_pe=k_pe,
                    output=output_attn,
                    layer_name=layer_name,
                    output_scale=scale,
                    output_block_scale=None,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
                return at1[1]

            return _replacement_with_ln

        def _replacement(q, kv_c_normed, k_pe, output_attn, scale, kv_cache_dummy_dep):
            # MLA output in quant_dtype
            output_attn = torch.empty(
                [q.shape[0], self._output_dim],
                dtype=FP8_DTYPE,
                device=q.device,
            )
            at1 = auto_functionalized(
                MLA_ATTN_OP,
                q=q,
                kv_c_normed=kv_c_normed,
                k_pe=k_pe,
                output=output_attn,
                layer_name=_ln,
                output_scale=scale,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            return at1[1]

        return _replacement

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs used to trace pattern/replacement.

        Order matches the pattern callables' parameters: q, kv_c_normed,
        k_pe, output_attn, scale, kv_cache_dummy_dep[, layer_name]. The
        leading dim of 5 is an arbitrary token count for tracing.
        """
        inputs: list = [
            self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype),
            self.empty(5, self._kv_lora_rank, dtype=self._dtype),
            self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype),
            self.empty(5, self._output_dim, dtype=self._dtype),
            # Per-tensor static quant scale (fp32 scalar-shaped).
            self.empty_fp32(1, 1),
            # Empty tensor standing in for the kv-cache dependency edge.
            self.empty(0, dtype=self._dtype),
        ]
        if _USE_LAYERNAME:
            # Layer name passed as a trace input so it becomes a wildcard.
            inputs.append(_encode_layer_name(self._layer_name))
        return inputs

MLAAttnNvfp4QuantPattern

Bases: VllmPatternReplacement[..., tuple[Tensor, Tensor]]

Fusion for MLA Attention+Nvfp4Quant.

Matches the pattern: MLA attention -> NVFP4 quant, and replaces it with MLA attention(output_scale=scale, output_block_scale=block_scale, output=fp4_buffer).

Source code in vllm/compilation/passes/fusion/mla_attn_quant_fusion.py
class MLAAttnNvfp4QuantPattern(
    VllmPatternReplacement[..., tuple[torch.Tensor, torch.Tensor]]
):
    """
    Fusion for MLA Attention+Nvfp4Quant.

    Matches the pattern: MLA attention -> NVFP4 quant, and replaces
    it with MLA attention(output_scale=scale, output_block_scale=block_scale,
    output=fp4_buffer).
    """

    def __init__(self, layer: MLAAttention, dtype: torch.dtype) -> None:
        """Capture the layer geometry needed to trace and rewrite graphs.

        Args:
            layer: The MLA attention layer this pattern is built for; only
                its name and dimension attributes are retained.
            dtype: Dtype of the unquantized attention inputs/outputs used
                when building example tracing inputs.
        """
        self._layer_name = layer.layer_name
        self._num_heads = layer.num_heads
        self._v_head_dim = layer.v_head_dim
        self._kv_lora_rank = layer.kv_lora_rank
        self._qk_rope_head_dim = layer.qk_rope_head_dim
        # Full per-head Q/K dim is the non-rope part plus the rope part.
        self._qk_head_dim = layer.qk_nope_head_dim + layer.qk_rope_head_dim
        # Width of the flattened 2D attention output: num_heads * v_head_dim.
        self._output_dim = layer.num_heads * layer.v_head_dim
        self._dtype = dtype
        # Dynamic NVFP4 quant op looked up once at construction time.
        self._QUANT_OP = QUANT_OPS[kNvfp4Dynamic]

    @property
    def pattern(
        self,
    ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
        """Callable traced by the matcher: MLA attention -> NVFP4 quant."""
        _ln = _encode_layer_name(self._layer_name)

        if _USE_LAYERNAME:

            # Variant where the layer name is itself a pattern input
            # (wildcard), so one registration can match every MLA layer.
            def _pattern_with_ln(  # type: ignore[misc]
                q,
                kv_c_normed,
                k_pe,
                output_attn,
                input_scale,
                kv_cache_dummy_dep,
                layer_name,
            ):
                # Unfused form: attention runs without any output scale...
                at1 = auto_functionalized(
                    MLA_ATTN_OP,
                    q=q,
                    kv_c_normed=kv_c_normed,
                    k_pe=k_pe,
                    output=output_attn,
                    layer_name=layer_name,
                    output_scale=None,
                    output_block_scale=None,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
                # ...then allocate FP4 output + block-scale buffers and run
                # the quant op's .out variant on the attention result.
                output_quant, output_scale = create_fp4_output_tensors(
                    at1[1].shape[0], at1[1].shape[1], at1[1].device, True
                )
                at2 = auto_functionalized(
                    self._QUANT_OP,
                    input=at1[1],
                    input_scale=input_scale,
                    is_sf_swizzled_layout=True,
                    output=output_quant,
                    output_scale=output_scale,
                )
                # Block scales are reinterpreted as FP8 (bitwise view only).
                output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
                return at2[1], output_scale_view

            return _pattern_with_ln

        # Variant with the layer name baked in as a constant; requires one
        # pattern registration per layer.
        def _pattern(
            q, kv_c_normed, k_pe, output_attn, input_scale, kv_cache_dummy_dep
        ):
            at1 = auto_functionalized(
                MLA_ATTN_OP,
                q=q,
                kv_c_normed=kv_c_normed,
                k_pe=k_pe,
                output=output_attn,
                layer_name=_ln,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            # Replicate what scaled_fp4_quant() does: allocate output
            # tensors inline then call the .out variant.
            output_quant, output_scale = create_fp4_output_tensors(
                at1[1].shape[0], at1[1].shape[1], at1[1].device, True
            )
            at2 = auto_functionalized(
                self._QUANT_OP,
                input=at1[1],
                input_scale=input_scale,
                is_sf_swizzled_layout=True,
                output=output_quant,
                output_scale=output_scale,
            )
            output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
            return at2[1], output_scale_view

        return _pattern

    @property
    def replacement(
        self,
    ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
        """Fused form: attention writes FP4 output + block scales itself."""
        _ln = _encode_layer_name(self._layer_name)

        if _USE_LAYERNAME:

            # Wildcard-layer-name counterpart of _pattern_with_ln above.
            def _replacement_with_ln(  # type: ignore[misc]
                q,
                kv_c_normed,
                k_pe,
                output_attn,
                input_scale,
                kv_cache_dummy_dep,
                layer_name,
            ):
                # MLA output in quant_dtype (FP4 packed as uint8)
                # Two FP4 values per byte, hence output_dim // 2.
                output_attn = torch.empty(
                    [q.shape[0], self._output_dim // 2],
                    dtype=FP4_DTYPE,
                    device=q.device,
                )
                # Block-scale buffer, same allocation as the unfused path.
                output_scale = create_fp4_output_tensors(
                    q.shape[0], self._output_dim, q.device, True
                )[1]
                output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
                at2 = auto_functionalized(
                    MLA_ATTN_OP,
                    q=q,
                    kv_c_normed=kv_c_normed,
                    k_pe=k_pe,
                    output=output_attn,
                    layer_name=layer_name,
                    output_scale=input_scale,
                    output_block_scale=output_scale_view,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
                # Return (packed FP4 output, block scales) from the fused op.
                return at2[1], at2[2]

            return _replacement_with_ln

        def _replacement(
            q, kv_c_normed, k_pe, output_attn, input_scale, kv_cache_dummy_dep
        ):
            # MLA output in quant_dtype (FP4 packed as uint8)
            output_attn = torch.empty(
                [q.shape[0], self._output_dim // 2],
                dtype=FP4_DTYPE,
                device=q.device,
            )
            # attention output block scale
            output_scale = create_fp4_output_tensors(
                q.shape[0], self._output_dim, q.device, True
            )[1]
            output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
            at2 = auto_functionalized(
                MLA_ATTN_OP,
                q=q,
                kv_c_normed=kv_c_normed,
                k_pe=k_pe,
                output=output_attn,
                layer_name=_ln,
                output_scale=input_scale,
                output_block_scale=output_scale_view,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            return at2[1], at2[2]

        return _replacement

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs used to trace pattern/replacement.

        Order matches the pattern callables' parameters: q, kv_c_normed,
        k_pe, output_attn, input_scale, kv_cache_dummy_dep[, layer_name].
        The leading dim of 5 is an arbitrary token count for tracing.
        """
        inputs: list = [
            self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype),
            self.empty(5, self._kv_lora_rank, dtype=self._dtype),
            self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype),
            self.empty(5, self._output_dim, dtype=self._dtype),
            # Global input scale for dynamic NVFP4 quant (fp32).
            self.empty_fp32(1, 1),
            # Empty tensor standing in for the kv-cache dependency edge.
            self.empty(0, dtype=self._dtype),
        ]
        if _USE_LAYERNAME:
            # Layer name passed as a trace input so it becomes a wildcard.
            inputs.append(_encode_layer_name(self._layer_name))
        return inputs

MLAAttnQuantFusionPass

Bases: VllmFusionPatternMatcherPass

This pass fuses post-attention quantization onto MLA attention if supported.

It uses the pattern matcher and matches each MLA layer manually, as strings cannot be wildcarded. This also lets us check support on attention layers upon registration instead of during pattern matching.

Source code in vllm/compilation/passes/fusion/mla_attn_quant_fusion.py
class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass):
    """
    This pass fuses post-attention quantization onto MLA attention if supported.

    It uses the pattern matcher and matches each MLA layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.
    """

    def __init__(self, config: VllmConfig) -> None:
        """Register fusion patterns for every MLA layer that supports them.

        Args:
            config: The vLLM config providing model dtype and the MLA
                attention layers discovered in the forward context.
        """
        super().__init__(config, "mla_attn_quant_fusion")

        dtype = config.model_config.dtype
        mla_layers = list(get_layers_from_vllm_config(config, MLAAttention).values())

        if not mla_layers:
            logger.warning(
                "MLA attention + quant fusion is enabled, but no MLA "
                "attention layers were found in "
                "CompilationConfig.static_forward_context "
                "so no fusion patterns were registered."
            )

        # FP8 static-quant fusion for any layer whose impl supports it.
        self._register_for_layers(
            mla_layers,
            kFp8StaticTensorSym,
            lambda layer: MLAAttnFp8StaticQuantPattern(layer, dtype),
        )

        # NVFP4 fusion additionally requires CUDA and the fp4 quant kernel.
        if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
            self._register_for_layers(
                mla_layers,
                kNvfp4Dynamic,
                lambda layer: MLAAttnNvfp4QuantPattern(layer, dtype),
            )

        self.dump_patterns(config, self.pm_pass)

    def _register_for_layers(self, layers, quant_key, make_pattern) -> None:
        """Register make_pattern(layer) for each supporting layer.

        With _USE_LAYERNAME the layer name is a wildcard input, so every
        layer yields the same pattern — a single registration suffices and
        we stop after the first one.
        """
        for layer in layers:
            if not layer.impl.fused_output_quant_supported(quant_key):
                continue
            self.register(make_pattern(layer))
            if _USE_LAYERNAME:
                break