Skip to content

vllm.lora.layers.base_linear

BaseLinearLayerWithLoRA

Bases: BaseLayerWithLoRA

Source code in vllm/lora/layers/base_linear.py
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """Base class for LoRA adapters wrapped around a ``LinearBase`` layer.

    Holds stacked LoRA A/B weight buffers for up to ``max_loras`` adapter
    slots and applies them on top of the base layer's output, either
    synchronously or — when ``VLLM_LORA_ENABLE_DUAL_STREAM`` is set — on an
    auxiliary CUDA stream overlapped with the base-layer computation.
    """

    def __init__(self, base_layer: LinearBase):
        super().__init__()

        # When enabled, the LoRA matmuls run on a separate CUDA stream so
        # they can overlap with the base layer's GEMM (see apply()).
        self._enable_aux_cuda_stream = envs.VLLM_LORA_ENABLE_DUAL_STREAM
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        # Ensure tp_size and tp_rank consistency with the base_layer.
        self.tp_size = self.base_layer.tp_size
        self.tp_rank = self.base_layer.tp_rank
        self.device = _get_lora_device(self.base_layer)
        self._init_lora_stream_context()
        # Declared here for typing only; populated by subclasses and/or
        # create_lora_weights().
        self.output_slices: tuple[int, ...]
        self.output_size: int
        self.n_slices: int

    def _init_lora_stream_context(self) -> None:
        """Set up the auxiliary LoRA CUDA stream and register this layer.

        No-op unless dual-stream mode is enabled. Registers the layer in the
        compilation config's static forward context under a unique name —
        presumably so the ``lora_linear_async`` custom op can resolve the
        layer by name at dispatch time (verify against the op's
        implementation).

        Raises:
            ValueError: if another layer already registered the same name.
        """
        if not self._enable_aux_cuda_stream:
            return
        vllm_config = get_current_vllm_config()
        self._lora_stream = _get_lora_aux_cuda_stream()
        # Dual-stream mode is only supported on CUDA-like platforms.
        assert current_platform.is_cuda_alike()
        # Event pair used by maybe_execute_in_parallel() to synchronize the
        # default stream with the LoRA stream (see _apply_async_impl()).
        self._events = [torch.cuda.Event(), torch.cuda.Event()]
        # lora_linear avoids prefix conflicts with the base layer
        self.layer_name = self.base_layer.prefix + ".lora_linear_async"
        compilation_config = vllm_config.compilation_config
        if self.layer_name in compilation_config.static_forward_context:
            raise ValueError("Duplicate layer name: {}".format(self.layer_name))
        compilation_config.static_forward_context[self.layer_name] = self

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Allocate zeroed stacked LoRA A/B buffers for ``max_loras`` slots.

        Buffer shapes depend on the base layer's parallelism style: with
        ``fully_sharded_loras``, the LoRA rank dimension of A is sharded
        across TP ranks for column-parallel layers, and the output dimension
        of B is sharded for row-parallel layers.

        Args:
            max_loras: number of adapter slots to allocate.
            lora_config: supplies max rank, dtype, and sharding flags.
            model_config: unused here; accepted for interface compatibility.

        Raises:
            NotImplementedError: if the base layer is not a supported
                linear variant.
        """
        self.lora_config = lora_config
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, ColumnParallelLinear):
            lora_a_out_size = (
                lora_config.max_lora_rank
                if not lora_config.fully_sharded_loras
                else divide(lora_config.max_lora_rank, self.tp_size)
            )
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, RowParallelLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (
                self.output_size
                if not lora_config.fully_sharded_loras
                else divide(self.output_size, self.tp_size)
            )
        else:
            raise NotImplementedError

        # One buffer per slice, shape (max_loras, 1, lora_a_out_size, input_size).
        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        # One buffer per slice, shape (max_loras, 1, lora_b_out_size, max_lora_rank).
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        self.output_slices = (self.lora_b_stacked[0].shape[2],)

    def reset_lora(self, index: int):
        """Zero the LoRA A/B weights stored in slot ``index`` of every slice."""
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Copy one adapter's A/B weights into slot ``index``.

        The slot is zeroed first so any region not covered by the (possibly
        smaller-rank) incoming weights stays zero. Under tensor parallelism
        the weights are first sliced to this rank's shard.
        """
        # Except for QKVParallelLinearWithLoRA and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        assert (
            len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
        )

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        # Copy into the top-left corner of the slot; the remainder stays zero.
        self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
            lora_a, non_blocking=True
        )
        self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
            lora_b, non_blocking=True
        )

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        """Run base layer plus LoRA, choosing the sync or async path.

        The async (dual-stream) path dispatches through the
        ``lora_linear_async`` custom op and requires a forward context.
        """
        # is_forward_context_available for tower modules
        if self._enable_aux_cuda_stream and is_forward_context_available():
            output_size = sum(self.output_slices)
            return torch.ops.vllm.lora_linear_async(
                self.layer_name, output_size, x, bias
            )
        else:
            return self._apply_sync(x, bias)

    def _apply_sync(
        self, x: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Base layer followed by an in-place LoRA update on one stream."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # Remember a 3-D shape so it can be restored after flattening below.
        original_shape = output.shape if output.ndim == 3 else None

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
        )
        # Platforms that cannot update `output` in place return the combined
        # result from add_lora_linear instead; use it in that case.
        if not current_platform.can_update_inplace():
            output = lora_output

        # Reshape the flattened output back to its original shape,
        # as some MM encoders cannot handle flattened inputs.
        if original_shape is not None:
            output = output.reshape(original_shape)

        return output

    def _apply_async_impl(
        self, x: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Forward pass with base linear and LoRA on separate CUDA streams
        for overlap, using maybe_execute_in_parallel.
        Base layer runs on default stream; LoRA runs on aux stream.
        """
        assert envs.VLLM_LORA_ENABLE_DUAL_STREAM
        assert x.ndim in (2, 3)
        # NOTE(review): for 3-D input this takes size(1) as the token count,
        # which assumes the leading batch dim is 1 (transformers backend
        # shape (1, seq_len, hidden)) — confirm no caller passes batch > 1.
        num_tokens = x.size(0) if x.ndim == 2 else x.size(1)
        output_size = sum(self.output_slices)

        def base_fn() -> torch.Tensor:
            # Base projection; runs on the default stream.
            return self.base_layer.quant_method.apply(self.base_layer, x, bias)

        def lora_fn() -> torch.Tensor:
            # Must be zeros, not empty: _lora_expand_kernel exits early (without
            # writing) when lora_id == -1 (no active LoRA). If uninitialized,
            # output.add_(lora_result) below would corrupt the base output.
            lora_output = torch.zeros(
                (num_tokens, output_size),
                device=self.device,
                dtype=x.dtype,
            )

            # Flatten the batch dimension for the transformers backend
            # (which uses shape (1, seq_len, hidden)), matching _apply_sync.
            x_2d = x.flatten(0, 1) if x.ndim == 3 else x
            self.punica_wrapper.add_lora_linear(
                lora_output,
                x_2d,
                self.lora_a_stacked,
                self.lora_b_stacked,
                1.0,
                self.output_slices,
                add_inputs=False,
            )
            return lora_output

        # Runs base_fn on the current stream and lora_fn on the aux stream,
        # using the two events for cross-stream ordering.
        output, lora_result = maybe_execute_in_parallel(
            base_fn,
            lora_fn,
            self._events[0],
            self._events[1],
            self._lora_stream,
        )

        # Remember a 3-D shape so it can be restored after flattening below.
        original_shape = output.shape if output.ndim == 3 else None

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        # Combine the LoRA delta with the base output.
        output.add_(lora_result)

        # Reshape the flattened output back to its original shape,
        # as some MM encoders cannot handle flattened inputs.
        if original_shape is not None:
            output = output.reshape(original_shape)

        return output

    @property
    def weight(self) -> torch.Tensor:
        """Return the base layer's weight tensor, whichever attribute the
        quantization backend exposes it under."""
        # UnquantizedLinear
        if hasattr(self.base_layer, "weight"):
            return self.base_layer.weight
        # Compressed Tensor
        elif hasattr(self.base_layer, "weight_packed"):
            return self.base_layer.weight_packed
        # GPTQ/AWQ
        elif hasattr(self.base_layer, "qweight"):
            return self.base_layer.qweight
        # marlin
        elif hasattr(self.base_layer, "B"):
            return self.base_layer.B
        else:
            raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> torch.Tensor | None:
        """Return the base layer's bias, or None if it has no bias attribute."""
        if hasattr(self.base_layer, "bias"):
            return self.base_layer.bias
        else:
            return None

_apply_async_impl

_apply_async_impl(
    x: Tensor, bias: Tensor | None = None
) -> Tensor

Forward pass with base linear and LoRA on separate CUDA streams for overlap, using maybe_execute_in_parallel. Base layer runs on default stream; LoRA runs on aux stream.

Source code in vllm/lora/layers/base_linear.py
def _apply_async_impl(
    self, x: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    """
    Forward pass with base linear and LoRA on separate CUDA streams
    for overlap, using maybe_execute_in_parallel.
    Base layer runs on default stream; LoRA runs on aux stream.
    """
    assert envs.VLLM_LORA_ENABLE_DUAL_STREAM
    assert x.ndim in (2, 3)
    # NOTE(review): for 3-D input this takes size(1) as the token count,
    # which assumes the leading batch dim is 1 (transformers backend
    # shape (1, seq_len, hidden)) — confirm no caller passes batch > 1.
    num_tokens = x.size(0) if x.ndim == 2 else x.size(1)
    output_size = sum(self.output_slices)

    def base_fn() -> torch.Tensor:
        # Base projection; runs on the default stream.
        return self.base_layer.quant_method.apply(self.base_layer, x, bias)

    def lora_fn() -> torch.Tensor:
        # Must be zeros, not empty: _lora_expand_kernel exits early (without
        # writing) when lora_id == -1 (no active LoRA). If uninitialized,
        # output.add_(lora_result) below would corrupt the base output.
        lora_output = torch.zeros(
            (num_tokens, output_size),
            device=self.device,
            dtype=x.dtype,
        )

        # Flatten the batch dimension for the transformers backend
        # (which uses shape (1, seq_len, hidden)), matching _apply_sync.
        x_2d = x.flatten(0, 1) if x.ndim == 3 else x
        self.punica_wrapper.add_lora_linear(
            lora_output,
            x_2d,
            self.lora_a_stacked,
            self.lora_b_stacked,
            1.0,
            self.output_slices,
            add_inputs=False,
        )
        return lora_output

    # Runs base_fn on the current stream and lora_fn on the aux stream,
    # using the two events for cross-stream ordering.
    output, lora_result = maybe_execute_in_parallel(
        base_fn,
        lora_fn,
        self._events[0],
        self._events[1],
        self._lora_stream,
    )

    # Remember a 3-D shape so it can be restored after flattening below.
    original_shape = output.shape if output.ndim == 3 else None

    # In transformers backend, x and output have extra batch dimension like
    # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
    # therefore we need to flatten the batch dimensions.
    if x.ndim == 3 and output.ndim == 3:
        output = output.flatten(0, 1)
        x = x.flatten(0, 1)

    # Combine the LoRA delta with the base output.
    output.add_(lora_result)

    # Reshape the flattened output back to its original shape,
    # as some MM encoders cannot handle flattened inputs.
    if original_shape is not None:
        output = output.reshape(original_shape)

    return output