class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a linear base layer.

    Holds zero-initialized, stacked LoRA A/B weight buffers and adds the LoRA
    contribution on top of the base layer's output. When
    ``envs.VLLM_LORA_ENABLE_DUAL_STREAM`` is set, the LoRA computation can be
    issued on an auxiliary CUDA stream so that it overlaps with the base
    linear computation.
    """

    def __init__(self, base_layer: LinearBase) -> None:
        """Wrap ``base_layer`` and mirror its sizing / tensor-parallel info.

        Args:
            base_layer: The linear layer whose output LoRA is added to.
        """
        super().__init__()
        self._enable_aux_cuda_stream = envs.VLLM_LORA_ENABLE_DUAL_STREAM
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        # Ensure tp_size and tp_rank consistency with the base_layer.
        self.tp_size = self.base_layer.tp_size
        self.tp_rank = self.base_layer.tp_rank
        self.device = _get_lora_device(self.base_layer)
        self._init_lora_stream_context()
        # Annotation-only declarations: no value is assigned here.
        # ``output_slices`` is set by ``create_lora_weights``; ``output_size``
        # and ``n_slices`` are not assigned anywhere in this class and are
        # presumably provided by subclasses -- verify against the subclasses.
        self.output_slices: tuple[int, ...]
        self.output_size: int
        self.n_slices: int

    def _init_lora_stream_context(self) -> None:
        """Set up the auxiliary-stream state used by the async apply path.

        Does nothing when dual-stream LoRA is disabled. Otherwise acquires the
        auxiliary CUDA stream, creates the two events handed to
        ``maybe_execute_in_parallel``, and registers this layer in the
        compilation config's ``static_forward_context`` under a unique name.

        Raises:
            ValueError: If the derived layer name is already registered.
        """
        if not self._enable_aux_cuda_stream:
            return
        # Check the platform before touching any CUDA stream/event API so an
        # unsupported platform fails here rather than inside stream creation.
        assert current_platform.is_cuda_alike(), (
            "Dual-stream LoRA requires a CUDA-alike platform"
        )
        vllm_config = get_current_vllm_config()
        self._lora_stream = _get_lora_aux_cuda_stream()
        self._events = [torch.cuda.Event(), torch.cuda.Event()]
        # lora_linear avoids prefix conflicts with the base layer
        self.layer_name = self.base_layer.prefix + ".lora_linear_async"
        compilation_config = vllm_config.compilation_config
        if self.layer_name in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {self.layer_name}")
        compilation_config.static_forward_context[self.layer_name] = self

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Allocate the zero-initialized stacked LoRA A and B buffers.

        One buffer pair is created per slice (``self.n_slices``):

        * A: ``(max_loras, 1, lora_a_out_size, input_size)``
        * B: ``(max_loras, 1, lora_b_out_size, max_lora_rank)``

        With ``lora_config.fully_sharded_loras`` the rank dimension of A is
        divided by ``tp_size`` for column-parallel base layers, and the output
        dimension of B is divided by ``tp_size`` for row-parallel base layers.

        Args:
            max_loras: Number of LoRA slots to allocate.
            lora_config: LoRA settings (rank, dtype, sharding mode).
            model_config: Not used by this implementation.

        Raises:
            NotImplementedError: If the base layer is not a replicated,
                column-parallel or row-parallel linear layer.
        """
        self.lora_config = lora_config
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size
        elif isinstance(self.base_layer, ColumnParallelLinear):
            lora_a_out_size = (
                lora_config.max_lora_rank
                if not lora_config.fully_sharded_loras
                else divide(lora_config.max_lora_rank, self.tp_size)
            )
            lora_b_out_size = self.output_size
        elif isinstance(self.base_layer, RowParallelLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (
                self.output_size
                if not lora_config.fully_sharded_loras
                else divide(self.output_size, self.tp_size)
            )
        else:
            raise NotImplementedError

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        # Single output slice whose width is the B buffer's output dimension.
        self.output_slices = (self.lora_b_stacked[0].shape[2],)

    def reset_lora(self, index: int) -> None:
        """Zero out LoRA slot ``index`` in every slice's A and B buffers."""
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ) -> None:
        """Load one adapter's A/B weights into slot ``index``.

        The slot is cleared first, then the weights are copied into the
        top-left corner of the slot (the buffers may be larger than the
        adapter). When ``tp_size > 1`` the weights are first passed through
        ``slice_lora_a`` / ``slice_lora_b`` (defined outside this class body).

        Args:
            index: LoRA slot to write.
            lora_a: LoRA A weight; must be a single tensor here.
            lora_b: LoRA B weight; must be a single tensor here.
        """
        # Except for QKVParallelLinearWithLoRA and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        assert (
            len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
        )

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
            lora_a, non_blocking=True
        )
        self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
            lora_b, non_blocking=True
        )

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        """Run the base layer plus LoRA, choosing the async or sync path.

        The async custom op is used only when dual-stream LoRA is enabled and
        a forward context is available; otherwise the synchronous path runs.

        Args:
            x: Input activations.
            bias: Optional bias forwarded to the base layer.

        Returns:
            The base layer output with the LoRA contribution added.
        """
        # is_forward_context_available for tower modules
        if self._enable_aux_cuda_stream and is_forward_context_available():
            output_size = sum(self.output_slices)
            return torch.ops.vllm.lora_linear_async(
                self.layer_name, output_size, x, bias
            )
        return self._apply_sync(x, bias)

    def _apply_sync(
        self, x: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Compute base output and add LoRA on the current stream.

        Args:
            x: Input activations, 2-D or 3-D.
            bias: Optional bias forwarded to the base layer.

        Returns:
            Output with LoRA applied, in the base layer output's original
            shape.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        original_shape = output.shape if output.ndim == 3 else None

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
        )
        # On platforms that cannot update `output` in place, the result is
        # taken from the wrapper's return value instead.
        if not current_platform.can_update_inplace():
            output = lora_output

        # Reshape the flattened output back to its original shape,
        # as some MM encoders cannot handle flattened inputs.
        if original_shape is not None:
            output = output.reshape(original_shape)

        return output

    def _apply_async_impl(
        self, x: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Forward pass with base linear and LoRA on separate CUDA streams
        for overlap, using maybe_execute_in_parallel.
        Base layer runs on default stream; LoRA runs on aux stream.

        Args:
            x: Input activations, 2-D or 3-D.
            bias: Optional bias forwarded to the base layer.

        Returns:
            Base output with the LoRA result added, in the base layer
            output's original shape.
        """
        assert envs.VLLM_LORA_ENABLE_DUAL_STREAM
        assert x.ndim in (2, 3)
        output_size = sum(self.output_slices)

        def base_fn() -> torch.Tensor:
            return self.base_layer.quant_method.apply(self.base_layer, x, bias)

        def lora_fn() -> torch.Tensor:
            # Flatten the batch dimension for the transformers backend
            # (which uses shape (1, seq_len, hidden)), matching _apply_sync.
            x_2d = x.flatten(0, 1) if x.ndim == 3 else x
            # Size the buffer from the flattened input so its row count always
            # matches both x_2d and the flattened base output, including when
            # the leading batch dimension of a 3-D input is not 1.
            #
            # Must be zeros, not empty: _lora_expand_kernel exits early (without
            # writing) when lora_id == -1 (no active LoRA). If uninitialized,
            # output.add_(lora_result) below would corrupt the base output.
            lora_output = torch.zeros(
                (x_2d.size(0), output_size),
                device=self.device,
                dtype=x.dtype,
            )
            self.punica_wrapper.add_lora_linear(
                lora_output,
                x_2d,
                self.lora_a_stacked,
                self.lora_b_stacked,
                1.0,
                self.output_slices,
                add_inputs=False,
            )
            return lora_output

        output, lora_result = maybe_execute_in_parallel(
            base_fn,
            lora_fn,
            self._events[0],
            self._events[1],
            self._lora_stream,
        )

        original_shape = output.shape if output.ndim == 3 else None

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while the LoRA result is 2-D
        # (num_tokens, output_size), so flatten the base output to match.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)

        output.add_(lora_result)

        # Reshape the flattened output back to its original shape,
        # as some MM encoders cannot handle flattened inputs.
        if original_shape is not None:
            output = output.reshape(original_shape)

        return output

    @property
    def weight(self) -> torch.Tensor:
        """Return the base layer's weight tensor, whatever attribute holds it.

        Raises:
            ValueError: If none of the known weight attributes is present.
        """
        # unquantizedLinear
        if hasattr(self.base_layer, "weight"):
            return self.base_layer.weight
        # Compressed Tensor
        elif hasattr(self.base_layer, "weight_packed"):
            return self.base_layer.weight_packed
        # GPTQ/AWQ
        elif hasattr(self.base_layer, "qweight"):
            return self.base_layer.qweight
        # marlin
        elif hasattr(self.base_layer, "B"):
            return self.base_layer.B
        else:
            raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> torch.Tensor | None:
        """Return the base layer's ``bias`` attribute, or ``None`` if absent."""
        return getattr(self.base_layer, "bias", None)