Skip to content

vllm.model_executor.models.fireredlid

FireRedLID – Language Identification model adapted for vLLM.

Architecture: ConformerEncoder + TransformerDecoder (6-layer cross-attn).
Vocabulary: 120 LID tokens (dict.txt).
Output: Up to 2 tokens (e.g. "en", "zh mandarin").

This implementation follows the Whisper-style encoder-decoder pattern

• Encoder processes audio features (Fbank + CMVN via FeatureExtractor)
• Decoder performs single-step autoregressive forward
• vLLM's generation loop handles beam search / sampling

FireRedLIDAttention

Bases: Module

Base attention with shared QKV/FC projections for the LID decoder.

Source code in vllm/model_executor/models/fireredlid.py
class FireRedLIDAttention(nn.Module):
    """Shared Q/K/V and output projections for the LID decoder attention.

    Subclasses provide the actual attention backend via ``_init_attn``.
    Per the original checkpoint, the query and value projections carry a
    bias while the key projection does not.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        tp = get_tensor_model_parallel_world_size()
        assert n_head % tp == 0
        self.total_num_heads = n_head
        self.num_heads = n_head // tp
        self.num_kv_heads = max(1, n_head // tp)
        self.head_dim = d_model // n_head
        self.scaling = self.head_dim**-0.5

        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config

        # Q/K/V projections share shape (d_model -> d_model); only the
        # bias flag differs between them.
        for proj_name, with_bias in (
            ("w_qs", True),
            ("w_ks", False),
            ("w_vs", True),
        ):
            setattr(
                self,
                proj_name,
                ColumnParallelLinear(
                    d_model,
                    d_model,
                    bias=with_bias,
                    quant_config=quant_config,
                    prefix=f"{prefix}.{proj_name}",
                ),
            )
        self.fc = RowParallelLinear(
            d_model,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc",
        )
        self._init_attn(cache_config, quant_config, prefix)

    def _init_attn(self, cache_config, quant_config, prefix: str) -> None:
        """Hook for subclasses to construct the attention operator."""
        raise NotImplementedError

FireRedLIDAudioInputs

Bases: TensorSchema

Dimensions
  • b: Batch size
  • t: Time frames (variable across utterances)
  • nmb: Number of mel bins (80)
Source code in vllm/model_executor/models/fireredlid.py
class FireRedLIDAudioInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - t: Time frames  (variable across utterances)
        - nmb: Number of mel bins (80)
    """

    # Audio features per item, shaped (t, nmb) with `t` varying per
    # utterance (declared dynamic in the TensorShape below).
    input_features: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b", "t", "nmb", dynamic_dims={"t"}),
    ]
    # True (unpadded) number of feature frames for each batch item.
    speech_lengths: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b"),
    ]
    # NOTE(review): semantics not visible in this file — presumably the
    # number of placeholder tokens reserved per audio item; confirm
    # against the multimodal processor that produces this field.
    fake_token_lengths: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b"),
    ]

FireRedLIDDecoderLayer

Bases: Module

vLLM-native decoder layer while preserving FireRedLID parameter names.

Source code in vllm/model_executor/models/fireredlid.py
class FireRedLIDDecoderLayer(nn.Module):
    """vLLM-native decoder layer while preserving FireRedLID parameter names.

    Pre-norm transformer layer: self-attention, cross-attention over the
    encoder states, then a feed-forward block, each wrapped in a residual
    connection.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        # Self-attention sub-block with its pre-norm.
        self.self_attn_norm = nn.LayerNorm(d_model)
        self.self_attn = FireRedLIDSelfAttention(
            d_model,
            n_head,
            vllm_config=vllm_config,
            prefix=f"{prefix}.self_attn",
        )

        # Cross-attention sub-block attending to encoder hidden states.
        self.cross_attn_norm = nn.LayerNorm(d_model)
        self.cross_attn = FireRedLIDCrossAttention(
            d_model,
            n_head,
            vllm_config=vllm_config,
            prefix=f"{prefix}.cross_attn",
        )

        # Feed-forward block with 4x expansion.
        self.mlp_norm = nn.LayerNorm(d_model)
        self.mlp = FireRedLIDFFN(d_model, d_model * 4)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ) -> torch.Tensor:
        """Apply the three pre-norm residual sub-blocks in sequence."""
        x = hidden_states
        x = x + self.self_attn(self.self_attn_norm(x))
        x = x + self.cross_attn(self.cross_attn_norm(x), encoder_hidden_states)
        x = x + self.mlp(self.mlp_norm(x))
        return x

FireRedLIDForConditionalGeneration

Bases: Module, SupportsTranscription, SupportsMultiModal

Source code in vllm/model_executor/models/fireredlid.py
@MULTIMODAL_REGISTRY.register_processor(
    FireRedLIDMultiModalProcessor,
    info=FireRedLIDProcessingInfo,
    dummy_inputs=FireRedLIDDummyInputsBuilder,
)
class FireRedLIDForConditionalGeneration(
    nn.Module, SupportsTranscription, SupportsMultiModal
):
    """FireRedLID language-identification model (encoder-decoder).

    The encoder consumes audio features through the multimodal pipeline
    (`embed_multimodal`); the decoder runs single-step autoregressive
    forward passes driven by vLLM's generation loop. Output logits are
    produced by a projection tied to the decoder's target embedding.
    """

    # -- SupportsTranscription protocol attributes --
    supports_transcription_only = True
    supported_languages = _FIREREDLID_SUPPORTED_LANGUAGES

    # Rename HF-checkpoint weight keys into this module hierarchy.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "encoder.": "model.encoder.",
            "lid_decoder.": "model.decoder.",
            # Encoder FFN: nn.Sequential indices → named children
            "net.0": "pre_layer_norm",
            "net.1": "linear_expand",
            "net.4": "linear_project",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.dtype = vllm_config.model_config.dtype

        # Register the decoder as the language component and the encoder
        # as the audio tower for vLLM's composite-model bookkeeping.
        with self._mark_composite_model(
            vllm_config,
            language_targets=FireRedLIDDecoder,
            tower_targets={"audio": FireRedLIDEncoder},
        ):
            self.model = FireRedLIDModel(
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "model"),
            )

        # Output head over the LID vocabulary (defaults: 120 tokens,
        # d_model 1280), tied to the decoder's target word embedding below.
        self.proj_out = ParallelLMHead(
            getattr(config, "vocab_size", 120),
            getattr(config, "d_model", 1280),
            quant_config=vllm_config.quant_config,
            prefix=maybe_prefix(prefix, "proj_out"),
        )
        self.proj_out = self.proj_out.tie_weights(self.model.decoder.tgt_word_emb)

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            getattr(config, "vocab_size", 120),
            scale=logit_scale,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run one decoder step over precomputed encoder outputs.

        `encoder_outputs` is the per-item list produced by
        `embed_multimodal`; `None` is normalised to an empty list so the
        inner model sees a consistent type.
        """
        if encoder_outputs is None:
            encoder_outputs = []
        decoder_outputs = self.model(
            input_ids=input_ids,
            positions=positions,
            encoder_outputs=encoder_outputs,
        )
        return decoder_outputs

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Run encoder on audio features and return per-item embeddings."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)

        speech = audio_input["input_features"]
        speech_lengths = audio_input["speech_lengths"]
        # No audio in this request — nothing to embed.
        if speech is None or speech_lengths is None:
            return []

        # When audio items have different time lengths, vLLM's
        # MultiModalBatchedField._reduce_data returns a plain
        # list[Tensor] instead of a stacked Tensor.  The encoder
        # expects a padded [B, Tmax, feat_dim] Tensor, so we
        # normalise both speech and speech_lengths here.
        if isinstance(speech, (list, tuple)):
            # Each element: [Ti, feat_dim]  (or [1, Ti, feat_dim])
            tensors = [
                s.squeeze(0) if s.dim() == 3 and s.size(0) == 1 else s for s in speech
            ]
            device = tensors[0].device
            dtype = tensors[0].dtype
            feat_dim = tensors[0].shape[-1]
            lengths = torch.tensor(
                [t.size(0) for t in tensors],
                device=device,
                dtype=torch.int32,
            )
            t_max = int(lengths.max().item())
            # Pre-allocate zero-padded batch tensor
            speech = torch.zeros(
                (len(tensors), t_max, feat_dim),
                device=device,
                dtype=dtype,
            )
            for i, t in enumerate(tensors):
                speech[i, : t.size(0)] = t
            speech_lengths = lengths
        else:
            # Already a batched Tensor [B, T, feat_dim]
            if speech.dim() == 2:
                speech = speech.unsqueeze(0)

        # Ensure int32 lengths on the same device as the padded features.
        speech_lengths = torch.as_tensor(
            speech_lengths, dtype=torch.int32, device=speech.device
        )

        enc_output, enc_lengths = self.model.get_encoder_outputs(
            speech=speech,
            speech_lengths=speech_lengths,
        )

        # vLLM expects one 2D tensor per multimodal item. Slice each batch entry
        # by the true encoder length so cross-attention never sees padded frames.
        return tuple(
            enc_output[i, : max(0, int(enc_lengths[i].item()))]
            for i in range(enc_output.size(0))
        )

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed decoder token ids.

        Multimodal embeddings are not merged into the token stream here;
        they reach the decoder separately as cross-attention states (see
        `forward`).
        """
        return self.model.decoder.embed_input_ids(input_ids)

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> FireRedLIDAudioInputs:
        """Pop the audio fields from multimodal kwargs into the schema."""
        input_features = kwargs.pop("input_features", None)
        speech_lengths = kwargs.pop("speech_lengths", None)
        fake_token_lengths = kwargs.pop("fake_token_lengths", None)
        return FireRedLIDAudioInputs(
            input_features=input_features,
            speech_lengths=speech_lengths,
            fake_token_lengths=fake_token_lengths,
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project decoder hidden states to LID-vocabulary logits."""
        logits = self.logits_processor(self.proj_out, hidden_states)
        return logits

    @classmethod
    def validate_language(cls, language: str | None) -> str | None:
        # FireRedLID is a language *identification* model – the caller does
        # not need to specify a language up-front.  Accept None silently.
        if language is None:
            return None
        return super().validate_language(language)

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Build the prompt for the FireRedLID encoder-decoder model.

        The decoder receives a single <sos> token; the encoder processes
        the raw audio waveform via the multimodal pipeline.
        """
        prompt: PromptType = {
            "encoder_prompt": {
                "prompt": "",
                "multi_modal_data": {
                    "audio": (audio, int(stt_config.sample_rate)),
                },
            },
            "decoder_prompt": {
                "prompt": "<sos>",
            },
        }
        return prompt

    @classmethod
    def get_speech_to_text_config(
        cls,
        model_config: ModelConfig,
        task_type: Literal["transcribe", "translate"],
    ) -> SpeechToTextConfig:
        """Derive the STT config from the model's audio feature extractor."""
        processor = cached_processor_from_config(model_config)
        return SpeechToTextConfig(
            max_audio_clip_s=processor.feature_extractor.chunk_length,
            sample_rate=processor.feature_extractor.sampling_rate,
            # LID output is at most 2 tokens – no chunking needed.
            min_energy_split_window_size=None,
        )

    @classmethod
    def post_process_output(cls, text: str) -> str:
        # Strip any leading/trailing whitespace from the raw LID output.
        return text.strip()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping rebuilt buffers and tied heads."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=[
                # Position encoding buffers are rebuilt at init
                "model.encoder.positional_encoding.pe",
                "model.decoder.positional_encoding.pe",
                # Tied output projection (shared with embedding)
                "model.decoder.tgt_word_prj.weight",
                "proj_out.",
            ],
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

embed_multimodal

embed_multimodal(**kwargs: object) -> MultiModalEmbeddings

Run encoder on audio features and return per-item embeddings.

Source code in vllm/model_executor/models/fireredlid.py
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
    """Run encoder on audio features and return per-item embeddings."""
    audio_input = self._parse_and_validate_audio_input(**kwargs)

    speech = audio_input["input_features"]
    speech_lengths = audio_input["speech_lengths"]
    # No audio in this request — nothing to embed.
    if speech is None or speech_lengths is None:
        return []

    # When audio items have different time lengths, vLLM's
    # MultiModalBatchedField._reduce_data returns a plain
    # list[Tensor] instead of a stacked Tensor.  The encoder
    # expects a padded [B, Tmax, feat_dim] Tensor, so we
    # normalise both speech and speech_lengths here.
    if isinstance(speech, (list, tuple)):
        # Each element: [Ti, feat_dim]  (or [1, Ti, feat_dim])
        tensors = [
            s.squeeze(0) if s.dim() == 3 and s.size(0) == 1 else s for s in speech
        ]
        device = tensors[0].device
        dtype = tensors[0].dtype
        feat_dim = tensors[0].shape[-1]
        lengths = torch.tensor(
            [t.size(0) for t in tensors],
            device=device,
            dtype=torch.int32,
        )
        t_max = int(lengths.max().item())
        # Pre-allocate zero-padded batch tensor
        speech = torch.zeros(
            (len(tensors), t_max, feat_dim),
            device=device,
            dtype=dtype,
        )
        for i, t in enumerate(tensors):
            speech[i, : t.size(0)] = t
        speech_lengths = lengths
    else:
        # Already a batched Tensor [B, T, feat_dim]
        if speech.dim() == 2:
            speech = speech.unsqueeze(0)

    # Ensure int32 lengths on the same device as the padded features.
    speech_lengths = torch.as_tensor(
        speech_lengths, dtype=torch.int32, device=speech.device
    )

    enc_output, enc_lengths = self.model.get_encoder_outputs(
        speech=speech,
        speech_lengths=speech_lengths,
    )

    # vLLM expects one 2D tensor per multimodal item. Slice each batch entry
    # by the true encoder length so cross-attention never sees padded frames.
    return tuple(
        enc_output[i, : max(0, int(enc_lengths[i].item()))]
        for i in range(enc_output.size(0))
    )

get_generation_prompt classmethod

get_generation_prompt(
    audio: ndarray,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
    language: str | None,
    task_type: Literal["transcribe", "translate"],
    request_prompt: str,
    to_language: str | None,
) -> PromptType

Build the prompt for the FireRedLID encoder-decoder model.

The decoder receives a single `<sos>` token; the encoder processes the raw audio waveform via the multimodal pipeline.

Source code in vllm/model_executor/models/fireredlid.py
@classmethod
def get_generation_prompt(
    cls,
    audio: np.ndarray,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
    language: str | None,
    task_type: Literal["transcribe", "translate"],
    request_prompt: str,
    to_language: str | None,
) -> PromptType:
    """Build the encoder/decoder prompt pair for FireRedLID.

    The raw audio waveform is routed through the multimodal encoder
    path, while the decoder is seeded with the single <sos> token.
    """
    encoder_prompt = {
        "prompt": "",
        "multi_modal_data": {
            "audio": (audio, int(stt_config.sample_rate)),
        },
    }
    decoder_prompt = {"prompt": "<sos>"}
    return {
        "encoder_prompt": encoder_prompt,
        "decoder_prompt": decoder_prompt,
    }

FireRedLIDModel

Bases: Module

Source code in vllm/model_executor/models/fireredlid.py
class FireRedLIDModel(nn.Module):
    """Container pairing the FireRedLID audio encoder with the LID decoder."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        cfg = vllm_config.model_config.hf_config

        def opt(name: str, default):
            # Fall back to the FireRedLID release defaults when the
            # hf_config does not define the attribute.
            return getattr(cfg, name, default)

        self.encoder = FireRedLIDEncoder(
            idim=opt("idim", 80),
            n_layers_enc=opt("n_layers_enc", 16),
            n_head=opt("n_head", 20),
            d_model=opt("d_model", 1280),
            kernel_size=opt("kernel_size", 33),
            pe_maxlen=opt("pe_maxlen", 5000),
        )

        self.decoder = FireRedLIDDecoder(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "decoder"),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Run one decoder step over the concatenated encoder states."""
        if encoder_outputs:
            enc_states = torch.cat(encoder_outputs, dim=0)
        else:
            # No encoder states available (e.g. decode-only step).
            enc_states = None
        return self.decoder(
            input_ids=input_ids,
            positions=positions,
            encoder_hidden_states=enc_states,
        )

    def get_encoder_outputs(
        self,
        speech: torch.Tensor | list[torch.Tensor],
        speech_lengths: torch.Tensor | list[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder and return padded outputs plus true sequence lengths."""
        outputs, lengths, _mask = self.encoder(speech, speech_lengths)
        return outputs, lengths

get_encoder_outputs

get_encoder_outputs(
    speech: Tensor | list[Tensor],
    speech_lengths: Tensor | list[Tensor],
) -> tuple[Tensor, Tensor]

Run the encoder and return padded outputs plus true sequence lengths.

Source code in vllm/model_executor/models/fireredlid.py
def get_encoder_outputs(
    self,
    speech: torch.Tensor | list[torch.Tensor],
    speech_lengths: torch.Tensor | list[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the encoder and return padded outputs plus true sequence lengths."""
    enc_output, enc_lengths, _ = self.encoder(speech, speech_lengths)
    return enc_output, enc_lengths

FireRedLIDPositionalEmbedding

Bases: Module

Absolute sinusoidal positional embedding indexed by positions.

Source code in vllm/model_executor/models/fireredlid.py
class FireRedLIDPositionalEmbedding(nn.Module):
    """Absolute sinusoidal positional embedding indexed by `positions`.

    Precomputes a (max_len, d_model) table where even channels hold
    sin(pos * 10000^(-2i/d_model)) and odd channels the matching cosine.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        assert d_model % 2 == 0
        # Inverse frequencies: 10000^(-k/d_model) for k = 0, 2, 4, ...
        exponent = torch.arange(0, d_model, 2).float() / d_model
        inv_freq = torch.pow(torch.tensor(10000.0), -exponent)
        # Outer product of positions and inverse frequencies.
        angles = torch.arange(0, max_len).float().unsqueeze(1) * inv_freq
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Non-persistent: rebuilt at init, never loaded from checkpoints.
        self.register_buffer("pe", table, persistent=False)

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        # Plain table lookup; output shape follows `position_ids`.
        return self.pe[position_ids]