vllm.model_executor.kernels.linear.scaled_mm.flashinfer ¶
FlashInferFp8DeepGEMMDynamicBlockScaledKernel ¶
Bases: Fp8BlockScaledDynamicMMLinearKernel
Conditional FlashInfer / DeepGEMM FP8 block-scaled GEMM.
Dispatches between two kernels based on input batch size:

- Small batches (M < 32): FlashInfer's swapAB trick for better utilisation.
- Large batches (M >= 32): DeepGEMM for peak throughput.
apply_input_quant is False because FlashInfer accepts BF16 input and handles FP8 conversion internally. The DeepGEMM branch therefore quantises BF16→FP8 inside apply_mm via a closure before dispatching to the DeepGEMM kernel, keeping both branches compatible with the single BF16 tensor operand list passed by torch.cond.
Source code in vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py
_dynamic_flashinfer_deepgemm_blockscale_gemm_fake ¶
_dynamic_flashinfer_deepgemm_blockscale_gemm_fake(
input: Tensor,
weight: Tensor,
weight_scale: Tensor,
group_size: int,
use_deep_gemm_e8m0: bool,
) -> Tensor
Required fake/meta implementation for torch.compile graph tracing.
Source code in vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py
_dynamic_flashinfer_deepgemm_blockscale_gemm_impl ¶
_dynamic_flashinfer_deepgemm_blockscale_gemm_impl(
input: Tensor,
weight: Tensor,
weight_scale: Tensor,
group_size: int,
use_deep_gemm_e8m0: bool,
) -> Tensor
Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection.
This function switches between two optimized kernels based on the input batch size:

- Small batches (M < 32): FlashInfer's DeepGEMM swapAB optimization.
- Larger batches (M >= 32): the official DeepGEMM kernel.
The conditional logic must use torch.cond() instead of a simple if-else statement to maintain compatibility with torch.compile graph compilation.
This batch-size-dependent selection is essential for maintaining model accuracy. GSM8K benchmarks of DeepSeek-V3.1 show a significant accuracy gap (88% vs 95%) when FlashInfer's DeepGEMM is also used for M >= 32; restricting it to M < 32 and routing larger batches to the official DeepGEMM kernel fixes the accuracy drop.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | Tensor | Input tensor of shape (batch_size, input_dim) in bfloat16 format (quantized to FP8 internally) | required |
| weight | Tensor | Weight tensor of shape (output_dim, input_dim) in FP8 format | required |
| weight_scale | Tensor | Per-group scale factors for weight quantization | required |
| group_size | int | Quantization group size for the weight tensor | required |
| use_deep_gemm_e8m0 | bool | Whether to use the E8M0 format in DeepGEMM quantization | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor of shape (batch_size, output_dim) in bfloat16 format |
Source code in vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py
_flashinfer_fp8_blockscale_gemm_fake ¶
_flashinfer_fp8_blockscale_gemm_fake(
input: Tensor, weight: Tensor, weight_scale: Tensor
) -> Tensor
Required fake/meta implementation for torch.compile graph tracing.