class NixlConnectorWorker:
"""Implementation of Worker side methods"""
    def __init__(
        self,
        vllm_config: "VllmConfig",
        engine_id: str,
        kv_cache_config: "KVCacheConfig",
    ):
        """Initialize per-TP-rank NIXL worker state.

        Sets up: the NIXL agent (with backend/thread configuration), Mamba/HMA
        sizing derived from the KV cache config, device/host transfer buffer
        policy, per-engine bookkeeping dicts shared with the topology helper,
        and the background handshake executor.

        Raises:
            RuntimeError: if NIXL is unavailable or the device/buffer
                combination is unsupported.
            ValueError: if ``kv_transfer_config`` is not set.
        """
        if NixlWrapper is None:
            logger.error("NIXL is not available")
            raise RuntimeError("NIXL is not available")
        logger.info("Initializing NIXL wrapper")
        logger.info("Initializing NIXL worker %s", engine_id)
        # Config.
        self.vllm_config = vllm_config
        # mypy will complain on re-assignment otherwise.
        self.block_size: int = cast(int, vllm_config.cache_config.block_size)
        if vllm_config.kv_transfer_config is None:
            raise ValueError("kv_transfer_config must be set for NixlConnector")
        self.kv_transfer_config = vllm_config.kv_transfer_config
        self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
            "backends", ["UCX"]
        )
        # HMA (hybrid memory allocator) handling is needed when the hybrid KV
        # cache manager is on AND at least one group is not full attention.
        self._is_hma_required = (
            not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
            and any(
                not isinstance(g.kv_cache_spec, FullAttentionSpec)
                for g in kv_cache_config.kv_cache_groups
            )
        )
        self.kv_cache_config = kv_cache_config
        # Flat layer_name -> spec mapping across all KV cache groups.
        self._layer_specs = {
            layer: group.kv_cache_spec
            for group in kv_cache_config.kv_cache_groups
            for layer in group.layer_names
        }
        self.hma_group_size = len(kv_cache_config.kv_cache_tensors)
        # ---- Mamba model state (derived from model config) ----
        self._is_mamba_group = [
            isinstance(group.kv_cache_spec, MambaSpec)
            for group in kv_cache_config.kv_cache_groups
        ]
        # (conv_state_bytes, ssm_state_bytes); stays (0, 0) for non-Mamba.
        mamba_ssm_size = (0, 0)
        self._has_mamba = any(self._is_mamba_group)
        if self._has_mamba:
            assert self._is_hma_required
            mamba_spec = next(
                spec
                for spec in self._layer_specs.values()
                if isinstance(spec, MambaSpec)
            )
            # Element sizes via empty-tensor probe (works for any torch dtype).
            conv_nbytes, ssm_nbytes = (
                torch.tensor([], dtype=mamba_spec.dtypes[0]).element_size(),  # type: ignore[misc]
                torch.tensor([], dtype=mamba_spec.dtypes[1]).element_size(),  # type: ignore[misc]
            )
            conv_shape, ssm_shape = (
                torch.Size(mamba_spec.shapes[0]),
                torch.Size(mamba_spec.shapes[1]),
            )
            mamba_ssm_size = (
                conv_shape.numel() * conv_nbytes,
                ssm_shape.numel() * ssm_nbytes,
            )
        self._mamba_ssm_size = mamba_ssm_size
        # Conv state sub-projection decomposition (None when no Mamba).
        # The 3-read transfer requires DS (dim, state_len) conv layout so
        # that x/B/C sub-projections are contiguous in memory.
        self._conv_decomp: MambaConvSplitInfo | None = None
        if self._has_mamba:
            assert is_conv_state_dim_first(), (
                "3-read Mamba conv transfer requires DS conv state layout. "
                "Set VLLM_SSM_CONV_STATE_LAYOUT=DS"
            )
            local_tp = vllm_config.parallel_config.tensor_parallel_size
            self._conv_decomp = derive_mamba_conv_split(mamba_spec, local_tp)
        # Agent.
        non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
        # Configure NIXL num_threads to avoid UAR exhaustion on Mellanox NICs.
        # Each UCX thread allocates UARs (doorbell pages) via DevX, and
        # excessive NIXL UAR usage can exhaust NIC UAR space. This can cause
        # components like NVSHMEM (used by DeepEP kernels) to fail during RDMA
        # initialization with "mlx5dv_devx_alloc_uar" errors.
        # Ref: https://network.nvidia.com/files/doc-2020/ethernet-adapters-programming-manual.pdf#page=63
        num_threads = vllm_config.kv_transfer_config.get_from_extra_config(
            "num_threads", 4
        )
        if nixl_agent_config is None:
            config = None
        else:
            # Enable telemetry by default for NIXL 0.7.1 and above.
            # num_threads is only capped when running pure-UCX; with other
            # backends the default thread count is used.
            config = (
                nixl_agent_config(backends=self.nixl_backends, capture_telemetry=True)
                if len(non_ucx_backends) > 0
                else nixl_agent_config(num_threads=num_threads, capture_telemetry=True)
            )
        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
        # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
        self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
        # Metadata.
        self.engine_id: EngineId = engine_id
        self.tp_rank = get_tensor_model_parallel_rank()
        self.world_size = get_tensor_model_parallel_world_size()
        self.num_blocks = kv_cache_config.num_blocks
        self.enable_permute_local_kv = False
        self.enable_heterogeneous_attn_post_process = False
        # KV Caches and nixl tracking data.
        self.device_type = current_platform.device_type
        self.kv_buffer_device: str = vllm_config.kv_transfer_config.kv_buffer_device
        if self.device_type not in _NIXL_SUPPORTED_DEVICE:
            raise RuntimeError(f"{self.device_type} is not supported.")
        elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[self.device_type]:
            raise RuntimeError(
                f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
                "is not supported."
            )
        self.device_kv_caches: dict[str, torch.Tensor] = {}
        # cpu kv buffer for xfer
        # used when device memory can not be registered under nixl
        self.host_xfer_buffers: dict[str, torch.Tensor] = {}
        # On a CPU device the KV cache is already in host memory, so no
        # staging buffer is needed even when kv_buffer_device == "cpu".
        if self.device_type == "cpu":
            self.use_host_buffer = False
        else:
            self.use_host_buffer = self.kv_buffer_device == "cpu"
        # reserve different cores for start_load_kv() from model_forward()
        if self.device_type == "cpu":
            numa_core_list = current_platform.discover_numa_topology()
            # setup one last core in each numa for kv transfer.
            rsv_cores_for_kv = [
                max(each_numa_core_list) for each_numa_core_list in numa_core_list
            ]
            if rsv_cores_for_kv:
                if not hasattr(os, "sched_setaffinity"):
                    raise NotImplementedError(
                        "os.sched_setaffinity is not available on this platform"
                    )
                os.sched_setaffinity(0, rsv_cores_for_kv)
        # support for oot platform which can't register nixl memory
        # type based on kv_buffer_device
        nixl_memory_type = current_platform.get_nixl_memory_type()
        if nixl_memory_type is None:
            if self.kv_buffer_device in ["cuda", "xpu"]:
                nixl_memory_type = "VRAM"
            elif self.kv_buffer_device == "cpu":
                nixl_memory_type = "DRAM"
        if nixl_memory_type is None:
            raise RuntimeError(
                f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
                "is not supported."
            )
        self.nixl_memory_type = nixl_memory_type
        # Note: host xfer buffer ops when use_host_buffer is True
        self.copy_blocks: CopyBlocksOp | None = None
        # Map of engine_id -> kv_caches_base_addr. For TP case, each local
        self.device_id: int = 0
        # Current rank may pull from multiple remote TP workers.
        # EngineId, dict[int, list[int]] -> engine_id, tp_rank, base_addr_for_layer
        self.kv_caches_base_addr = defaultdict[EngineId, dict[int, list[int]]](dict)
        # Number of NIXL regions. Currently one region per cache
        # (so 1 per layer for MLA, otherwise 2 per layer)
        self.num_regions = 0
        # nixl_prepped_dlist_handle.
        self.src_xfer_handles_by_block_size: dict[int, int] = {}
        # Populated dynamically during handshake based on remote configuration.
        # Keep track of regions at different tp_ratio values. tp_ratio->handles
        self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {}
        # Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}.
        self.dst_xfer_side_handles = defaultdict[EngineId, dict[int, int]](dict)
        # Map of engine_id -> num_blocks. All ranks in the same deployment will
        # have the same number of blocks.
        self.dst_num_blocks: dict[EngineId, int] = {}
        self._registered_descs: list[Any] = []
        # ---- Mamba-HMA per-engine state (only used when self._has_mamba) ----
        # Per-engine transfer config (source of truth for FA/mamba sizing).
        self._transfer_configs: dict[str, HeteroTPTransferConfig] = {}
        # NOTE (ZhanqiuHu): _mamba_phys_ratio MUST be per-engine.
        # compute_mamba_phys_ratio = ceil((conv_bytes + ssm_bytes) / block_len)
        # where conv/ssm bytes are per-TP-rank (dimension-sharded). With
        # heterogeneous TP the per-rank sizes differ, so the ratio differs:
        # e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261.
        self._mamba_phys_ratio: dict[EngineId, int] = {}
        # In progress transfers.
        # [req_id -> list[handle]]
        self._recving_metadata: dict[ReqId, ReqMeta] = {}
        self._recving_transfers = defaultdict[ReqId, list[TransferHandle]](list)
        # Track the expiration time of requests that are waiting to be sent.
        self._reqs_to_send: dict[ReqId, float] = {}
        # Set of requests that have been part of a batch, regardless of status.
        self._reqs_to_process: set[ReqId] = set()
        # invalid blocks from failed NIXL operations
        self._invalid_block_ids: set[int] = set()
        # requests that skipped transfer (handshake or transfer failures)
        self._failed_recv_reqs: set[ReqId] = set()
        # Handshake metadata of this worker for NIXL transfers.
        self.xfer_handshake_metadata: NixlHandshakePayload | None = None
        # Background thread for initializing new NIXL handshakes.
        self._handshake_initiation_executor = ThreadPoolExecutor(
            # NIXL is not guaranteed to be thread-safe, limit 1 worker.
            max_workers=1,
            thread_name_prefix="vllm-nixl-handshake-initiator",
        )
        self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]()
        self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {}
        # Protects _handshake_futures and _remote_agents.
        self._handshake_lock = threading.RLock()
        # NOTE(review): redundant — self.block_size was already set to this
        # same value above; kept for safety, but a candidate for removal.
        self.block_size = vllm_config.cache_config.block_size
        self.model_config = vllm_config.model_config
        self.use_mla = self.model_config.use_mla
        # Get the attention backend from the first layer
        # NOTE (NickLucche) models with multiple backends are not supported yet
        self.attn_backends = get_current_attn_backends(vllm_config)
        self.backend_name = self.attn_backends[0].get_name()
        self.kv_cache_layout = get_kv_cache_layout()
        self.host_buffer_kv_cache_layout = self.kv_cache_layout
        logger.info("Detected attention backend %s", self.backend_name)
        logger.info("Detected kv cache layout %s", self.kv_cache_layout)
        # lazy initialized in register_kv_caches
        self.compat_hash: str | None = None
        self.kv_topo: TpKVTopology | None = None
        # Shared-state dicts: also referenced by TpKVTopology (see
        # register_kv_caches), so they must be mutated in place, not rebound.
        self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
        self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
        # With heterogeneous TP, P must wait for all assigned D TP workers to
        # finish reading before safely freeing the blocks.
        self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
        self.xfer_stats = NixlKVConnectorStats()
        self._physical_blocks_per_logical_kv_block = 1
        # May shrink self.block_size and scale self.num_blocks when the
        # attention kernels require a specific physical block size.
        self._sync_block_size_with_kernel()
        self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
            "enforce_handshake_compat", True
        )
def _sync_block_size_with_kernel(self) -> None:
backends = get_current_attn_backends(self.vllm_config)
kernel_block_size = select_common_block_size(self.block_size, backends)
# Number of blocks not accounting for kernel block mismatches
self._logical_num_blocks = self.num_blocks
if self.block_size != kernel_block_size:
logger.info_once(
"User-specified logical block size (%s) does not match"
" physical kernel block size (%s). Using the latter.",
self.block_size,
kernel_block_size,
)
assert self.block_size > kernel_block_size
self._physical_blocks_per_logical_kv_block = (
self.block_size // kernel_block_size
)
self.block_size = kernel_block_size
self._block_size[self.engine_id] = kernel_block_size
self.num_blocks *= self._physical_blocks_per_logical_kv_block
    def _nixl_handshake(
        self,
        host: str,
        port: int,
        remote_tp_size: int,
        expected_engine_id: str,
    ) -> dict[int, str]:
        """Do a NIXL handshake with a remote instance.

        Queries each remote TP rank this local rank will read from over a ZMQ
        REQ socket, validates the compatibility hash, decodes the remote agent
        metadata and registers the remote agent with NIXL.

        Returns:
            Mapping of remote TP rank -> registered NIXL agent name.

        Raises:
            RuntimeError: on payload decode failure, compatibility-hash
                mismatch, or remote engine-id mismatch.
        """
        # the first time we connect to a remote agent.
        # be careful, the handshake happens in a background thread.
        # it does not have an active cuda context until any cuda runtime
        # call is made. when UCX fails to find a valid cuda context, it will
        # disable any cuda ipc communication, essentially disabling any NVLink
        # communication.
        # when we are using device buffers, we need to set the device
        # explicitly to make sure the handshake background thread has a valid
        # cuda context.
        if not self.use_host_buffer:
            current_platform.set_device(self.device_id)
        # When target instance TP > local TP, we need to perform multiple
        # handshakes. Do it in a single background job for simplicity.
        # Regardless, only handshake with the remote TP rank(s) that current
        # local rank will read from. Note that With homogeneous TP,
        # this happens to be the same single rank_i.
        assert self.kv_topo is not None
        p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
        remote_rank_to_agent_name = {}
        path = make_zmq_path("tcp", host, port)
        # A single REQ socket is reused sequentially for all remote ranks.
        with zmq_ctx(zmq.REQ, path) as sock:
            for remote_rank in p_remote_ranks:
                logger.debug(
                    "Querying metadata on path: %s at remote tp rank %s",
                    path,
                    remote_rank,
                )
                start_time = time.perf_counter()
                # Send query for the request.
                msg = msgspec.msgpack.encode((GET_META_MSG, remote_rank))
                # Set receive timeout to 5 seconds to avoid hanging on dead server
                sock.setsockopt(zmq.RCVTIMEO, 5000)  # milliseconds
                sock.send(msg)
                handshake_bytes = sock.recv()
                # Decode handshake payload to get compatibility hash
                handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
                try:
                    handshake_payload = handshake_decoder.decode(handshake_bytes)
                except (msgspec.DecodeError, msgspec.ValidationError) as e:
                    raise RuntimeError(
                        f"Failed to decode NixlHandshakePayload. This likely indicates "
                        f"an incompatibility between connector version. Error: {e}"
                    ) from e
                got_metadata_time = time.perf_counter()
                logger.debug(
                    "NIXL handshake: get metadata took: %s",
                    got_metadata_time - start_time,
                )
                # Check compatibility hash BEFORE decoding agent metadata
                assert self.compat_hash is not None
                if (
                    self.enforce_compat_hash
                    and handshake_payload.compatibility_hash != self.compat_hash
                ):
                    raise RuntimeError(
                        f"NIXL compatibility hash mismatch. "
                        f"Local: {self.compat_hash}, "
                        f"Remote: {handshake_payload.compatibility_hash}. "
                        f"Prefill and decode instances have incompatible "
                        f"configurations. This may be due to: different vLLM versions,"
                        f" models, dtypes, KV cache layouts, attention backends, etc. "
                        f"Both instances must use identical configurations."
                        f"Disable this check using "
                        f'--kv-transfer-config \'{{"kv_connector_extra_config": '
                        f'{{"enforce_handshake_compat": false}}}}\''
                    )
                logger.info(
                    "NIXL compatibility check passed (hash: %s)",
                    handshake_payload.compatibility_hash,
                )
                # Decode agent metadata
                metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
                try:
                    metadata = metadata_decoder.decode(
                        handshake_payload.agent_metadata_bytes
                    )
                except (msgspec.DecodeError, msgspec.ValidationError) as e:
                    # This should not happen if hash matched
                    raise RuntimeError(
                        f"Failed to decode NixlAgentMetadata. Error: {e}"
                    ) from e
                # Ensure engine id matches.
                if metadata.engine_id != expected_engine_id:
                    raise RuntimeError(
                        f"Remote NIXL agent engine ID mismatch. "
                        f"Expected {expected_engine_id},"
                        f"received {metadata.engine_id}."
                    )
                # Register Remote agent.
                remote_agent_name = self.add_remote_agent(
                    metadata, remote_rank, remote_tp_size
                )
                setup_agent_time = time.perf_counter()
                logger.debug(
                    "NIXL handshake: add agent took: %s",
                    setup_agent_time - got_metadata_time,
                )
                remote_rank_to_agent_name[remote_rank] = remote_agent_name
        return remote_rank_to_agent_name
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
"""
Initialize transfer buffer in CPU mem for accelerators
NOT directly supported by NIXL (e.g., tpu)
"""
xfer_buffers: dict[str, torch.Tensor] = {}
inv_order = [0, 1, 3, 2, 4]
try:
for layer_name, kv_cache in kv_caches.items():
kv_shape = kv_cache.shape
kv_dtype = kv_cache.dtype
permute_shape = False
if (
self.kv_cache_layout == "NHD"
and self.vllm_config.kv_transfer_config is not None
and self.vllm_config.kv_transfer_config.enable_permute_local_kv
):
logger.info_once(
"'enable_permute_local_kv' flag is enabled while "
"device KV Layout is NHD. Init host buffer with"
" HND to better support Decode/Prefill TP_ratio > 1."
)
# Since NHD will not support Decode/Prefill TP_ratio > 1,
# we can leverage host_buffer for permute
self.host_buffer_kv_cache_layout = "HND"
kv_shape = (
tuple(kv_shape[i] for i in inv_order)
if not self.use_mla
else kv_shape
)
permute_shape = not self.use_mla
xfer_buffers[layer_name] = torch.empty(
kv_shape, dtype=kv_dtype, device="cpu"
)
if permute_shape:
xfer_buffers[layer_name] = xfer_buffers[layer_name].permute(
inv_order
)
except MemoryError as e:
logger.error("NIXLConnectorWorker gets %s.", e)
raise
self.host_xfer_buffers = xfer_buffers
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
"""Assign copy (d2h, h2d) operations when host buffer is used."""
# Set a no-op if the host buffer is not cpu.
if self.kv_buffer_device != "cpu":
return
# Set a no-op if self.device_type is 'cpu'.
if self.device_type == "cpu":
return
assert self.use_host_buffer
self.copy_blocks = copy_operation
def _log_failure(
self,
failure_type: str,
req_id: str | None,
msg: str = "",
error: Exception | None = None,
meta: ReqMeta | None = None,
**extra_context,
):
"""Log transfer failure with structured context for easier debugging."""
context: dict[str, Any] = {
"failure_type": failure_type,
"request_id": req_id,
"engine_id": self.engine_id,
}
if meta is None and req_id is not None:
# Try to get metadata from in progress transfers when not provided
meta = self._recving_metadata.get(req_id)
if meta and meta.remote:
context.update(
{
"remote_engine_id": meta.remote.engine_id,
"remote_request_id": meta.remote.request_id,
"remote_host": meta.remote.host,
"remote_port": meta.remote.port,
"num_local_blocks": sum(
len(group) for group in meta.local_block_ids
),
"num_remote_blocks": sum(
len(group) for group in meta.remote.block_ids
),
"local_block_ids_sample": meta.local_block_ids[0][:10]
if meta.local_block_ids
else [],
}
)
context.update(extra_context)
if msg:
failure_type = f"{failure_type}. {msg}"
logger.error(
"NIXL transfer failure: %s | Context: %s",
failure_type,
context,
exc_info=error is not None,
stacklevel=2,
)
    def _background_nixl_handshake(
        self, req_id: str, remote_engine_id: EngineId, meta: ReqMeta
    ):
        """Kick off (or join) an async NIXL handshake for `remote_engine_id`
        and enqueue `(req_id, meta)` on `_ready_requests` once it succeeds.

        One future per remote engine is shared by all requests targeting it;
        only the creator of the future attaches the agent-registration
        callback, while every request attaches its own readiness callback.
        """
        # Do NIXL handshake in background and add to _ready_requests when done.
        fut = self._handshake_futures.get(remote_engine_id)
        if fut is None:
            assert meta.remote is not None
            fut = self._handshake_initiation_executor.submit(
                self._nixl_handshake,
                meta.remote.host,
                meta.remote.port,
                meta.tp_size,
                remote_engine_id,
            )
            self._handshake_futures[remote_engine_id] = fut

            # Runs once per future: record the remote agents (or log failure)
            # and drop the future from the in-flight map under the lock.
            # `eid` is bound as a default to avoid late-binding surprises.
            def done_callback(f: Future[dict[int, str]], eid=remote_engine_id):
                with self._handshake_lock:
                    del self._handshake_futures[eid]
                    try:
                        self._remote_agents[eid] = f.result()
                    except Exception as e:
                        self._log_failure(
                            failure_type="handshake_setup_failed",
                            req_id=None,
                            error=e,
                            remote_engine_id=eid,
                        )

            fut.add_done_callback(done_callback)

        # check handshake success before proceeding with request
        def request_ready(f: Future[Any], entry=(req_id, meta)):
            try:
                # check if handshake succeeded
                f.result()
                self._ready_requests.put(entry)
            except Exception as e:
                # handshake failed - mark blocks as invalid
                self._log_failure(
                    failure_type="handshake_failed",
                    req_id=req_id,
                    error=e,
                    meta=meta,
                )
                # Only the non-HMA path invalidates blocks directly here;
                # HMA block invalidation is handled elsewhere.
                if (
                    req_meta := self._recving_metadata.get(req_id)
                ) and not self._is_hma_required:
                    self._invalid_block_ids.update(req_meta.local_block_ids[0])
                self._failed_recv_reqs.add(req_id)

        fut.add_done_callback(request_ready)
def register_cross_layers_kv_caches(self, kv_cache: torch.Tensor) -> None:
"""Register a cross-layers KV cache tensor with NIXL.
`use_uniform_kv_cache()` guarantees a single KV cache group whose
layers all share the same `AttentionSpec`, so any layer name from
`_layer_specs` yields the correct per-layer spec for `page_size_bytes`.
"""
first_layer = next(iter(self._layer_specs))
# Forwarding a real layer name rather than a synthetic key
self.register_kv_caches({first_layer: kv_cache})
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register the KV Cache data in nixl.

        Builds the TP/KV topology, computes the compatibility hash, registers
        every distinct cache tensor (or its host staging buffer) as a NIXL
        memory region, creates the local transfer-descriptor handle, and
        prepares the handshake payload advertised to remote peers.
        """
        self.kv_topo = TpKVTopology(
            tp_rank=self.tp_rank,
            engine_id=self.engine_id,
            remote_tp_size=self._tp_size,  # shared state
            remote_block_size=self._block_size,  # shared state
            is_mla=self.use_mla,
            total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
            attn_backends=self.attn_backends,
            # SSM States come in tuples (ssm, conv)
            tensor_shape=next(iter(kv_caches.values())).shape
            if not self._has_mamba
            else None,
            is_mamba=self._has_mamba,
        )
        self.compat_hash = compute_nixl_compatibility_hash(
            self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
        )
        # Choose what to register with NIXL: the device caches directly, or
        # CPU staging buffers when the device cannot be registered.
        if self.use_host_buffer:
            self.initialize_host_xfer_buffer(kv_caches=kv_caches)
            assert len(self.host_xfer_buffers) == len(kv_caches), (
                f"host_buffer: {len(self.host_xfer_buffers)}, "
                f"kv_caches: {len(kv_caches)}"
            )
            xfer_buffers = self.host_xfer_buffers
        else:
            xfer_buffers = kv_caches
            assert not self.host_xfer_buffers, (
                "host_xfer_buffer should not be initialized when "
                f"kv_buffer_device is {self.kv_buffer_device}"
            )
        logger.info(
            "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
            "use_host_buffer: %s",
            self.use_mla,
            self.kv_buffer_device,
            self.use_host_buffer,
        )
        caches_data = []
        # With hybrid allocator, layers can share a kv cache tensor
        seen_base_addresses = []
        # Note(tms): I modified this from the original region setup code.
        # K and V are now in different regions. Advantage is that we can
        # elegantly support MLA and any cases where the K and V tensors
        # are non-contiguous (it's not locally guaranteed that they will be)
        # Disadvantage is that the encoded NixlAgentMetadata is now larger
        # (roughly 8KB vs 5KB).
        # Conversely for FlashInfer, K and V are registered in the same region
        # to better exploit the memory layout (ie num_blocks is the first dim).
        tensor_size_bytes = None
        # Enable different block lengths for different layers *only* when MLA is used.
        # This is not used for SSM layers, which use the counterpart `mamba_ssm_size`.
        self.block_len_per_layer = list[int]()
        for layer_name, cache_or_caches in xfer_buffers.items():
            # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to
            # that of FI, with block laid out as in `get_backend_aware_kv_block_len`.
            # However, physical page_size may differ when kernel requires a specific
            # block size. This leads to SSM and FA layers having different num_blocks.
            # `_physical_blocks_per_logical_kv_block` ratio is used to adjust for this.
            layer_spec = self._layer_specs[layer_name]
            if isinstance(layer_spec, UniformTypeKVCacheSpecs):
                # MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs
                layer_spec = layer_spec.kv_cache_specs[layer_name]
            cache_list = self.kv_topo.get_transfer_cache_regions(
                cache_or_caches, layer_spec
            )
            # `layer_spec.page_size_bytes` only accounts for logical page_size, that is
            # the page_size assuming constant `self._logical_num_blocks`.
            # Mamba layers keep the logical page size; attention layers are
            # rescaled to the kernel's physical block size.
            physical_page_size = (
                layer_spec.page_size_bytes
                if isinstance(layer_spec, MambaSpec)
                else layer_spec.page_size_bytes
                // self._physical_blocks_per_logical_kv_block
            )
            # For when registering multiple tensors eg K/V in separate regions.
            physical_page_size = physical_page_size // len(cache_list)
            if self.kv_topo._cross_layers_blocks:
                # When cross-layers blocks are used, multiply by number of layers
                physical_page_size = physical_page_size * len(
                    self.kv_cache_config.kv_cache_tensors
                )
            num_blocks = (
                self._logical_num_blocks
                if isinstance(layer_spec, MambaSpec)
                else self.num_blocks
            )
            # `page_size` accounts for physical blocks, st KVCache is always
            # [`num_blocks` * `page_size`]
            curr_tensor_size_bytes = num_blocks * physical_page_size
            if tensor_size_bytes is None:
                tensor_size_bytes = curr_tensor_size_bytes
            # TODO (NickLucche) we could eventually unify how we handle FA/FI regions,
            # registering a single tensor for both K/V and splitting logically like FI.
            for cache in cache_list:
                base_addr = cache.data_ptr()
                if base_addr in seen_base_addresses:
                    # NOTE (NickLucche) HMA employs memory pooling to share tensors
                    # across groups. This results in skipping all tensors but the ones
                    # pointed to by group0. Also, generally we will have more blocks
                    # per tensor but fewer regions.
                    logger.debug("Skipping %s because it's already seen", layer_name)
                    continue
                logger.debug(
                    "Registering layer %s with cache shape: %s", layer_name, cache.shape
                )
                seen_base_addresses.append(base_addr)
                # Only record non-Mamba page sizes.
                if isinstance(layer_spec, MambaSpec):
                    self.block_len_per_layer.append(
                        physical_page_size // self._physical_blocks_per_logical_kv_block
                    )
                else:
                    self.block_len_per_layer.append(physical_page_size)
                assert cache.shape[0] == num_blocks, (
                    "All kv cache tensors must have the same number of blocks"
                )
                if not self.use_mla:
                    # Different kv cache shape is not supported by HeteroTP.
                    # This must also hold true for Mamba-like models.
                    assert tensor_size_bytes == curr_tensor_size_bytes, (
                        "All kv cache tensors must have the same size"
                    )
                # Need to make sure the device ID is non-negative for NIXL,
                # Torch uses -1 to indicate CPU tensors.
                self.device_id = max(cache.get_device(), 0)
                caches_data.append(
                    (base_addr, curr_tensor_size_bytes, self.device_id, "")
                )
        logger.debug(
            "Different block lengths collected: %s", set(self.block_len_per_layer)
        )
        assert len(self.block_len_per_layer) == len(seen_base_addresses)
        self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
        self.num_regions = len(caches_data)
        if self.kv_topo.is_kv_layout_blocks_first:
            # NOTE (NickLucche) When FlashInfer is used, memory is registered
            # with joint KV for each block. This minimizes the overhead in
            # registerMem allowing faster descs queries. In order to be able to
            # split on kv_heads dim as required by heterogeneous TP, one must
            # be able to index K/V separately. Hence we double the number
            # of 'virtual' regions here and halve `block_len` below.
            # Similarly for Mamba layers, we register SSM+Conv as a single region and
            # then duplicate it logically to be able to index SSM/Conv separately.
            self.num_regions *= 2
        # Total local FA descriptors (boundary between FA and mamba descs).
        self.num_descs = self.num_regions * self.num_blocks
        descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type)
        logger.debug("Registering descs: %s", caches_data)
        self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends)
        logger.debug("Done registering descs")
        self._registered_descs.append(descs)
        self.device_kv_caches = kv_caches
        self.dst_num_blocks[self.engine_id] = self.num_blocks
        if self._has_mamba:
            self._mamba_phys_ratio[self.engine_id] = (
                self._physical_blocks_per_logical_kv_block
            )
            logger.info(
                "Hybrid SSM registration: num_blocks=%s, "
                "logical_num_blocks=%s, ratio=%s, num_regions=%s, "
                "num_descs=%s, mamba_ssm_size=%s, block_len_per_layer=%s",
                self.num_blocks,
                self._logical_num_blocks,
                self._physical_blocks_per_logical_kv_block,
                self.num_regions,
                self.num_descs,
                self._mamba_ssm_size,
                set(self.block_len_per_layer),
            )
        # Register local/src descr for NIXL xfer.
        self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = (
            self.register_local_xfer_handler(self.block_size)
        )
        # After KV Caches registered, listen for new connections.
        agent_metadata = NixlAgentMetadata(
            engine_id=self.engine_id,
            agent_metadata=self.nixl_wrapper.get_agent_metadata(),
            device_id=self.device_id,
            kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
            num_blocks=self.num_blocks,
            block_lens=self.block_len_per_layer,
            kv_cache_layout=self.kv_cache_layout
            if not self.use_host_buffer
            else self.host_buffer_kv_cache_layout,
            block_size=self.block_size,
            ssm_sizes=self._mamba_ssm_size,
            attn_backend_name=self.backend_name,
        )
        # Wrap metadata in payload with hash for defensive decoding
        assert self.compat_hash is not None
        encoder = msgspec.msgpack.Encoder()
        self.xfer_handshake_metadata = NixlHandshakePayload(
            compatibility_hash=self.compat_hash,
            agent_metadata_bytes=encoder.encode(agent_metadata),
        )
def _build_mamba_local(
self,
base_addresses: list[int],
block_size_ratio: int,
) -> list[tuple[int, int, int]]:
"""Build 4 desc regions (x, B, C, ssm) per layer for local mamba
blocks, enabling the 3-read transfer with DS conv layout."""
assert block_size_ratio == 1, (
"Mamba 3-read transfer with block_size_ratio != 1 is not tested. "
f"Got block_size_ratio={block_size_ratio}."
)
assert self._conv_decomp is not None
conv_offsets = self._conv_decomp.local_conv_offsets
conv_size, ssm_size = self._mamba_ssm_size
num_blocks = self._logical_num_blocks * block_size_ratio
phys_ratio = self._physical_blocks_per_logical_kv_block
result: list[tuple[int, int, int]] = []
for i, base_addr in enumerate(base_addresses):
page_stride = self.block_len_per_layer[i] // block_size_ratio * phys_ratio
for off, sz in conv_offsets:
for blk in range(num_blocks):
result.append(
(base_addr + blk * page_stride + off, sz, self.device_id)
)
# SSM temporal state follows the conv state.
for blk in range(num_blocks):
result.append(
(
base_addr + blk * page_stride + conv_size,
ssm_size,
self.device_id,
)
)
return result
    def _build_fa_remote_for_mamba(
        self,
        nixl_agent_meta: NixlAgentMetadata,
        transfer_cfg: HeteroTPTransferConfig,
        block_size_ratio: int,
        kv_topo: TpKVTopology,
    ) -> list[tuple[int, int, int]]:
        """Build remote FA descriptors for mamba models.

        Uses transfer_cfg for GQA-aware FA divisor and head-based rank offset
        instead of the standard uniform tp_ratio split.

        Returns a flat list of (address, length, device_id) descriptors for
        the remote attention (non-SSM) layers.
        """
        assert block_size_ratio == 1, (
            "Mamba 3-read transfer with block_size_ratio != 1 is not tested. "
            f"Got block_size_ratio={block_size_ratio}."
        )
        # TODO (ZhanqiuHu): unify with register_remote_blocks when Mamba-HMA
        # hetero-TP logic stabilizes.
        tp_ratio = transfer_cfg.tp_ratio
        result: list[tuple[int, int, int]] = []
        for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
            local_block_len = self.get_backend_aware_kv_block_len(
                layer_idx=i, first_split=True, mamba_view=False
            )
            remote_kv_block_len = local_block_len // block_size_ratio
            # NOTE(review): with block_size_ratio asserted == 1 above, both
            # of these branches are currently unreachable/no-ops; they appear
            # to be kept for the untested ratio > 1 case.
            if block_size_ratio > 1:
                local_block_len = remote_kv_block_len
            # Negative tp_ratio encodes P_TP > D_TP: split this rank's read
            # across multiple physical reads (non-MLA only).
            if tp_ratio < 0 and not self.use_mla:
                local_block_len = local_block_len // transfer_cfg.physical_fa_num_reads
            # Byte offset selecting this rank's head slice within the remote page.
            rank_offset = transfer_cfg.fa_rank_offset(remote_kv_block_len)
            num_blocks = nixl_agent_meta.num_blocks
            page_size = nixl_agent_meta.block_lens[i]
            for block_id in range(num_blocks):
                block_offset = block_id * page_size
                addr = base_addr + block_offset + rank_offset
                result.append((addr, local_block_len, nixl_agent_meta.device_id))
            if kv_topo.is_kv_layout_blocks_first:
                # Blocks-first (FlashInfer-style) layout: K and V share a page;
                # emit a second descriptor run for the V half of each page.
                second_split = self.get_backend_aware_kv_block_len(
                    layer_idx=i, first_split=False, mamba_view=False
                )
                if tp_ratio < 0 and not self.use_mla:
                    second_split = second_split // transfer_cfg.physical_fa_num_reads
                for block_id in range(num_blocks):
                    block_offset = block_id * page_size
                    addr = base_addr + block_offset + rank_offset
                    # V starts halfway through the joint KV page.
                    v_addr = addr + nixl_agent_meta.block_lens[i] // 2
                    result.append((v_addr, second_split, nixl_agent_meta.device_id))
        return result
    def _build_mamba_remote(
        self,
        nixl_agent_meta: NixlAgentMetadata,
        tp_ratio: int,
    ) -> list[tuple[int, int, int]]:
        """Build 4 remote desc regions (x, B, C, ssm) per layer for
        the 3-read transfer. For hetero-TP, each D rank reads only its
        sub-projection slice from the P rank.

        `tp_ratio` sign convention: >= 1 means D_TP >= P_TP (remote pages are
        larger, read a slice); negative means P_TP > D_TP (remote pages are
        smaller, scale local offsets down by |tp_ratio|).

        Returns a flat list of (address, length, device_id) descriptors
        ordered to mirror `_build_mamba_local`.
        """
        assert self._conv_decomp is not None
        effective_ratio = max(tp_ratio, 1)
        # Mamba conv state is always TP-sharded, even when attention KV
        # is replicated (num_kv_heads < tp_size).
        local_offset = self.tp_rank % effective_ratio
        conv_size_remote = nixl_agent_meta.ssm_sizes[0]
        if tp_ratio >= 1:
            # D_TP >= P_TP: P page is larger, D reads its slice.
            conv_offsets = self._conv_decomp.remote_conv_offsets(
                local_offset, effective_ratio
            )
            ssm_read_size = self._mamba_ssm_size[1]
        else:
            # NOTE (ZhanqiuHu): tp_ratio < 0 means P_TP > D_TP, so P pages
            # are smaller than D's. self._conv_decomp has D-sized dimensions,
            # but we need P-sized offsets. Scale down by |tp_ratio|.
            abs_ratio = -tp_ratio
            xb_p = self._conv_decomp.x_bytes // abs_ratio
            bb_p = self._conv_decomp.b_bytes // abs_ratio
            # (x, B, C) laid out contiguously; B and C share the same size.
            conv_offsets = [(0, xb_p), (xb_p, bb_p), (xb_p + bb_p, bb_p)]
            ssm_read_size = nixl_agent_meta.ssm_sizes[1]
        # Remote logical blocks = physical blocks / remote's phys ratio.
        remote_ratio = self._mamba_phys_ratio[nixl_agent_meta.engine_id]
        num_blocks = nixl_agent_meta.num_blocks // remote_ratio
        device_id = nixl_agent_meta.device_id
        result: list[tuple[int, int, int]] = []
        # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case
        # block lengths vary across layers (e.g. MLA).
        for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
            page_stride = nixl_agent_meta.block_lens[i] * remote_ratio
            for off, sz in conv_offsets:
                for blk in range(num_blocks):
                    result.append((base_addr + blk * page_stride + off, sz, device_id))
            # SSM temporal state is also TP-sharded on the heads dimension.
            for blk in range(num_blocks):
                ssm_addr = (
                    base_addr
                    + blk * page_stride
                    + conv_size_remote
                    + local_offset * ssm_read_size
                )
                result.append((ssm_addr, ssm_read_size, device_id))
        return result
    def register_local_xfer_handler(
        self,
        block_size: int,
    ) -> tuple[int, list[tuple[int, int, int]]]:
        """
        Register local memory regions with NIXL, sliced by `block_size`,
        and return the prepped transfer handle plus the raw descriptor list.
        When local block_size is same as remote block_size, we use local
        block_size to register local_xfer_handler during init.
        When remote block size is less than local block size, we need to
        register another local_xfer_handler using remote block len to ensure
        data copy correctness: descriptors must be remote-block sized so
        local and remote reads line up one-to-one.

        Args:
            block_size: block size (in tokens) used to slice descriptors;
                either our own or a remote engine's smaller one.

        Returns:
            Tuple of (prepped NIXL dlist handle bound to "NIXL_INIT_AGENT",
            list of (address, length, device_id) descriptors).
        """
        assert self.kv_topo is not None
        kv_topo = self.kv_topo
        # How many remote-sized blocks fit in one local block (>= 1 here).
        block_size_ratio = self.block_size // block_size
        blocks_data: list[tuple[int, int, int]] = []
        local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank]
        def register_blocks(blocks_data: list[tuple[int, int, int]], mamba: bool):
            # One descriptor run per layer/region base address.
            for i, base_addr in enumerate(local_base_addresses):
                # The new block_len is using prefill block_len;
                # and num_blocks is multiple with N
                kv_block_len = (
                    self.get_backend_aware_kv_block_len(
                        layer_idx=i, first_split=True, mamba_view=mamba
                    )
                    // block_size_ratio
                )
                # Jump one page_size, but ssm page_size may be bigger when kernel
                # locks block size to a specific value.
                block_len_per_layer = (
                    self.block_len_per_layer[i]
                    // block_size_ratio
                    * (1 if not mamba else self._physical_blocks_per_logical_kv_block)
                )
                num_blocks = self._logical_num_blocks if mamba else self.num_blocks
                num_blocks = num_blocks * block_size_ratio
                for block_id in range(num_blocks):
                    block_offset = block_id * block_len_per_layer
                    addr = base_addr + block_offset
                    # (addr, len, device id)
                    blocks_data.append((addr, kv_block_len, self.device_id))
                if kv_topo.is_kv_layout_blocks_first:
                    second_split = self.get_backend_aware_kv_block_len(
                        layer_idx=i, first_split=False, mamba_view=mamba
                    )
                    # Separate and interleave K/V regions to maintain the same
                    # descs ordering. This is needed for selecting contiguous heads
                    # when split across TP ranks.
                    for block_id in range(num_blocks):
                        block_offset = block_id * block_len_per_layer
                        addr = base_addr + block_offset
                        # Register addresses for V cache (K registered first).
                        v_addr = addr + kv_block_len
                        blocks_data.append((v_addr, second_split, self.device_id))
        logger.debug(
            "Created %s blocks for src engine %s and rank %s on device id %s",
            len(blocks_data),
            self.engine_id,
            self.tp_rank,
            self.device_id,
        )
        # NOTE (ZhanqiuHu): mamba=True path in register_blocks is not used
        # right now — we use _build_mamba_local instead for the 3-read
        # approach. However, we might still need this as a fallback for homogeneous TP.
        register_blocks(blocks_data, mamba=False)
        if self._has_mamba:
            assert self.num_descs == len(blocks_data)
            # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the 3-read split is
            # unnecessary — a single conv desc per block suffices. Consider
            # adding a fast path that falls back to the standard 2-region
            # registration (register_blocks mamba=True) when no hetero-TP
            # remote has been seen. Currently we always register 4 regions
            # because local descs are created before knowing the remote TP.
            logger.debug("Registering local Mamba descriptors (4 regions/layer)")
            blocks_data.extend(
                self._build_mamba_local(local_base_addresses, block_size_ratio)
            )
        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
        # NIXL_INIT_AGENT to be used for preparations of local descs.
        return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs), blocks_data
def add_remote_agent(
self,
nixl_agent_meta: NixlAgentMetadata,
remote_tp_rank: int = 0,
remote_tp_size: int = 1,
) -> str:
"""
Add the remote NIXL agent and prepare the descriptors for reading cache
blocks from remote.
In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i.
The latter, in the case of D.world_size < P.world_size, requires that a
local (D) TP worker reads from multiple remote (P) TP workers.
Conversely, assuming D.world_size > P.world_size, two or more local TP
workers will read from a single remote TP worker.
Here's an example for the last case described above (non-MLA):
rank_offset p_remote_tp_rank
(kv split no)
--------------------------------
0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ]
/
1 0 Worker1 ---- 2nd half of KV -----/
0 1 Worker2 ---- 1st half of KV ----> Worker1 [ KV Cache ]
/
1 1 Worker3 ---- 2nd half of KV -----/
Decoder TP workers Prefix TP workers
(world_size=4) (world_size=2)
tp_ratio = 4 // 2 = 2
Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim]
then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format.
Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio
first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split
along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0.
Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1.
Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0
so that the whole cache is shared by "tp_ratio" D TP workers.
For Mamba hetero-TP, both tp_ratio > 0 (D_TP > P_TP) and
tp_ratio < 0 (P_TP > D_TP) are supported by the 3-read transfer.
""" # noqa: E501
engine_id = nixl_agent_meta.engine_id
# TODO re-evaluate refreshing for scaling/recovery
if remote_tp_rank in self._remote_agents.get(engine_id, {}):
logger.debug(
"Remote agent with engine_id %s and rank"
"%s already exchanged metadata, skip handshake.",
engine_id,
remote_tp_rank,
)
return self._remote_agents[engine_id][remote_tp_rank]
### Register remote agent metadata
if engine_id not in self._tp_size:
self._tp_size[engine_id] = remote_tp_size
if engine_id not in self._block_size:
self._block_size[engine_id] = nixl_agent_meta.block_size
remote_agent_name = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata
)
# Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id.
# Example:
# block_size_ratio > 1:
# remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
# local origin:| 0| 1| 8| 12|
# local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
assert self.kv_topo is not None
kv_topo = self.kv_topo
block_size_ratio = kv_topo.block_size_ratio_from_engine_id(engine_id)
if engine_id not in self.dst_num_blocks:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
if self._has_mamba:
self._mamba_phys_ratio[engine_id] = compute_mamba_phys_ratio(
nixl_agent_meta.ssm_sizes, nixl_agent_meta.block_lens[0]
)
# Keep track of remote agent kv caches base addresses.
self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
nixl_agent_meta.kv_caches_base_addr
)
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
# This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
# this is the ratio between the two sizes.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)
# Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote = (
not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0
)
# Create transfer config (single source of truth for descriptor sizes).
if self._has_mamba and engine_id not in self._transfer_configs:
self._transfer_configs[engine_id] = HeteroTPTransferConfig(
tp_ratio=tp_ratio,
K=kv_topo.total_num_kv_heads,
d_tp=self.world_size,
p_tp=remote_tp_size,
d_rank=self.tp_rank,
use_mla=self.use_mla,
d_block_len=self.block_len_per_layer[0],
p_block_len=nixl_agent_meta.block_lens[0],
is_blocks_first=kv_topo.is_kv_layout_blocks_first,
)
logger.info("Created %s", self._transfer_configs[engine_id].describe())
logger.debug(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
engine_id,
remote_tp_rank,
tp_ratio,
)
### (Optional) Register local agent memory regions. MLA is not split.
if (
tp_ratio < 0
and not self.use_mla
and tp_ratio not in self.src_xfer_handles_by_tp_ratio
):
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
abs_tp = -tp_ratio
self.src_xfer_handles_by_tp_ratio[tp_ratio] = []
if self._has_mamba:
transfer_cfg = self._transfer_configs.get(engine_id)
assert transfer_cfg is not None
if transfer_cfg.needs_split_handles:
# Mamba-HMA: FA and Mamba use different split factors.
for handle_data in transfer_cfg.compute_split_handle_data(
self.src_blocks_data, self.num_descs, abs_tp
):
descs = self.nixl_wrapper.get_xfer_descs(
handle_data, self.nixl_memory_type
)
handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs
)
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
logger.info(
"Mamba-HMA split handles: targets=%s, fa_reads=%s, "
"fa_entry=%s, mamba_reads=%s, num_descs=%s",
transfer_cfg.transfer_targets,
transfer_cfg.physical_fa_num_reads,
transfer_cfg.fa_entry_size,
transfer_cfg.mamba_num_reads,
self.num_descs,
)
else:
# Original path: uniform divide by abs_tp (non-Mamba-HMA).
for i in range(abs_tp):
blocks_data = []
for memory_region in self.src_blocks_data:
addr, local_block_len, own_tp_rank = memory_region
remote_block_len = local_block_len // abs_tp
addr = addr + i * remote_block_len
blocks_data.append((addr, remote_block_len, own_tp_rank))
descs = self.nixl_wrapper.get_xfer_descs(
blocks_data, self.nixl_memory_type
)
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
### Register remote agent memory regions
blocks_data = []
# With homogeneous TP, D pulls the whole kv cache from corresponding
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
# Register all remote blocks, but only the corresponding kv heads.
def register_remote_blocks(
blocks_data: list[tuple[int, int, int]], mamba: bool
):
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
# Read our whole local region size from remote.
local_block_len = self.get_backend_aware_kv_block_len(
layer_idx=i, first_split=True, mamba_view=mamba
)
remote_kv_block_len = local_block_len // block_size_ratio
if block_size_ratio > 1:
# using remote kv_block_len as transfer unit
local_block_len = remote_kv_block_len
if tp_ratio < 0 and not self.use_mla:
# Remote tp is bigger: read a chunk of local region from remote
local_block_len = local_block_len // (-tp_ratio)
rank_offset = (
self.tp_rank % tp_ratio * remote_kv_block_len
if indexes_into_remote
else 0
)
# Assume same num_blocks for mamba and fa
num_blocks = (
nixl_agent_meta.num_blocks
if not mamba
else nixl_agent_meta.num_blocks
// self._physical_blocks_per_logical_kv_block
)
page_size = nixl_agent_meta.block_lens[i] * (
1 if not mamba else self._physical_blocks_per_logical_kv_block
)
for block_id in range(num_blocks):
block_offset = block_id * page_size
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append(
(addr, local_block_len, nixl_agent_meta.device_id)
)
if kv_topo.is_kv_layout_blocks_first:
# With FlashInfer index V separately to allow head splitting.
second_split = self.get_backend_aware_kv_block_len(
layer_idx=i, first_split=False, mamba_view=mamba
)
# Apply the same scaling as local_block_len above for when we read
# a chunk of local V from `tp_ratio` separate remote workers.
if tp_ratio < 0 and not self.use_mla:
second_split = second_split // (-tp_ratio)
for block_id in range(num_blocks):
block_offset = block_id * page_size
addr = base_addr + block_offset + rank_offset
# Hop over the first split of remote page: either K or Conv.
if mamba:
v_addr = addr + nixl_agent_meta.ssm_sizes[0]
else:
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append(
(v_addr, second_split, nixl_agent_meta.device_id)
)
logger.debug(
"Created %s blocks for dst engine %s"
" with remote rank %s and local rank %s",
len(blocks_data),
engine_id,
remote_tp_rank,
self.tp_rank,
)
if self._has_mamba:
# Mamba-HMA: separate FA registration with GQA-aware sizing,
# plus mamba 3-read registration for the Mamba "view" of the
# same KV cache tensors.
logger.debug(
"Registering remote Mamba blocks for engine %s rank %s",
engine_id,
remote_tp_rank,
)
transfer_cfg = self._transfer_configs.get(engine_id)
assert transfer_cfg is not None
blocks_data.extend(
self._build_fa_remote_for_mamba(
nixl_agent_meta,
transfer_cfg,
block_size_ratio,
kv_topo,
)
)
blocks_data.extend(
self._build_mamba_remote(
nixl_agent_meta,
tp_ratio,
)
)
else:
register_remote_blocks(blocks_data, mamba=False)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
self.dst_xfer_side_handles[engine_id][remote_tp_rank] = (
self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)
)
if block_size_ratio > 1:
# when prefill with smaller block_size, we need to init a
# new handler with same block_len to match
self.src_xfer_handles_by_block_size[nixl_agent_meta.block_size] = (
self.register_local_xfer_handler(nixl_agent_meta.block_size)[0]
)
return remote_agent_name
    def _validate_remote_agent_handshake(
        self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int
    ):
        """
        Validate the remote agent handshake metadata ensuring the
        invariants hold true.

        Checks TP-ratio/replication compatibility, block-size ratio, KV cache
        layout, attention backend and per-layer block lengths. As a side
        effect, may enable `enable_permute_local_kv` or
        `enable_heterogeneous_attn_post_process` for supported mismatches.

        Raises:
            AssertionError: when a structural invariant is violated.
            RuntimeError: for layout/attn mismatches with no supported fallback.
        """
        remote_engine_id = nixl_agent_meta.engine_id
        assert self._tp_size[remote_engine_id] == remote_tp_size
        assert self.kv_topo is not None
        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
        block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
            remote_engine_id
        )
        # num_kv_heads > tp_size with P_TP > D_TP not supported for non-mamba.
        # Mamba models can have replicated FA KV with tp_ratio < 0.
        if not self._has_mamba:
            assert not (
                tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id)
            )
        if self._is_hma_required:
            assert block_size_ratio == 1, (
                "HMA does not support different remote block size yet"
            )
        # Compare against the host-buffer layout when staging through host.
        kv_cache_layout = (
            self.kv_cache_layout
            if not self.use_host_buffer
            else self.host_buffer_kv_cache_layout
        )
        if not self.use_mla and nixl_agent_meta.kv_cache_layout != kv_cache_layout:
            if (
                self.kv_transfer_config.enable_permute_local_kv
                and nixl_agent_meta.kv_cache_layout == "HND"
            ):
                logger.info(
                    "Remote is HND and local is NHD, enabled additional permute "
                    "on local device KV."
                )
                assert not self._is_hma_required, (
                    "HMA does not support block size post processing"
                )
                self.enable_permute_local_kv = True
            else:
                raise RuntimeError(
                    "Heterogeneous TP expects same kv_cache_layout. "
                    "Or enable experimental feature to use HND to NHD support by "
                    "setting 'enable_permute_local_kv'=True in --kv-transfer-config."
                )
        # If the remote agent's attn backend is not the same as the local one,
        # hint heterogeneous attn post process.
        if (
            nixl_agent_meta.attn_backend_name != self.backend_name
            and self.backend_name in ["CPU_ATTN"]
        ):
            if self._is_hma_required:
                raise RuntimeError(
                    "heterogeneous attn post process is not supported with HMA"
                )
            logger.info(
                "[Experimental] CPU_ATTN backend is used, "
                "hint heterogeneous attn post process"
            )
            self.enable_heterogeneous_attn_post_process = True
        # Heterogeneous TP requires head-splitting, which only works with
        # HND layout. MLA and replicated-KV cases don't split on heads.
        # Mamba doesn't support heterogeneous TP.
        if (
            abs(tp_ratio) != 1
            and not self.use_mla
            and not self.kv_topo.is_kv_replicated(remote_engine_id)
            and kv_cache_layout != "HND"
            and not self.enable_permute_local_kv
        ):
            raise RuntimeError(
                "Heterogeneous TP head-dimension splitting requires contiguous heads. "
                "Use HND layout on the prefill side."
            )
        # Block len can only vary across layers when using MLA.
        remote_block_len = nixl_agent_meta.block_lens[0]
        if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
            # With replicated KV cache, only the number of blocks can differ.
            # TODO (ZhanqiuHu): For mamba models, validate FA and mamba
            # block_lens separately.
            if not self._has_mamba:
                for i in range(len(self.block_len_per_layer)):
                    assert (
                        self.block_len_per_layer[i] // block_size_ratio
                        == nixl_agent_meta.block_lens[i]
                    ), "KV cache sizes must match between P and D when replicated"
        else:
            # When MLA is not used, this is a list of the same block length
            for block_len in nixl_agent_meta.block_lens:
                assert block_len == remote_block_len, (
                    "All remote layers must have the same block size"
                )
            # HMA hybrid models (mamba+attention) pad block_len to
            # max(attn_page, mamba_page), so the linear tp_ratio scaling
            # assumption only holds for pure-attention models.
            if not self._has_mamba:
                if tp_ratio > 0:
                    assert (
                        remote_block_len
                        == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
                    ), (
                        "Remote P worker KV layer cache must be of shape [2, N,"
                        " local_kv_heads*tp_ratio, page_size, head_dim] and "
                        "same dtype."
                    )
                else:
                    assert block_size_ratio == 1, (
                        "Different local/remote block sizes are not supported"
                        " when P TP > D TP."
                    )
                    assert remote_block_len == self.block_len_per_layer[0] // (
                        -tp_ratio
                    ), (
                        "Remote P worker KV layer cache must be of shape [2, N,"
                        " local_kv_heads/tp_ratio, page_size, head_dim] and "
                        "same dtype."
                    )
        # TP workers that handshake with the same remote have same #blocks.
        assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
        # Same number of regions/~layers.
        assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
"""copy recved kv from host buffer to device."""
assert self.use_host_buffer
assert self.copy_blocks is not None
local_block_ids = meta.local_physical_block_ids
# TODO (NickLucche) D2H<>H2D ops could benefit from coalescing io across groups
for group_block_ids in local_block_ids:
self.copy_blocks(
self.host_xfer_buffers,
self.device_kv_caches,
group_block_ids,
group_block_ids,
"h2d",
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"synced recved kv of request[%s] to device kv buffer,"
"local_block_ids: %s. ",
req_id,
",".join(map(str, local_block_ids)),
)
def save_kv_to_host(self, metadata: NixlConnectorMetadata):
"""copy kv from device to host buffer."""
assert self.use_host_buffer
assert self.copy_blocks is not None
for req_id, meta in metadata.reqs_to_save.items():
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
meta.local_block_ids
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"save_load_kv for request[%s] to host xfer buffer."
"local_block_ids: %s. ",
req_id,
",".join(map(str, meta.local_physical_block_ids)),
)
# blocking
for group_block_ids in meta.local_physical_block_ids:
self.copy_blocks(
self.device_kv_caches,
self.host_xfer_buffers,
group_block_ids,
group_block_ids,
"d2h",
)
def post_process_device_kv_on_receive(
self,
block_size_ratio: int,
block_ids_list: list[list[int]],
):
"""
Post process device kv cache after receiving from remote.
3 types of post processing supported:
* kv_cache_postprocess_layout => convert from HND to NHD
* kv_cache_postprocess_blksize => convert from small block size
to large block size
* kv_cache_postprocess_blksize_and_layout => convert from small
block size to large block size and convert from HND to NHD
"""
if len(self.device_kv_caches) == 0:
return
assert block_size_ratio >= 1, "Only nP < nD supported currently."
assert self.kv_topo is not None
if self.enable_permute_local_kv and block_size_ratio > 1:
logger.debug(
"Post-processing device kv cache on receive by converting "
"block_size with %sx bigger and permuting layout from HND"
" to NHD.",
block_size_ratio,
)
elif self.enable_permute_local_kv:
logger.debug(
"Post-processing device kv cache on receive by permuting layout"
"from HND to NHD."
)
else:
logger.debug(
"Post-processing device kv cache on receive by converting "
"block_size with %sx bigger.",
block_size_ratio,
)
split_k_and_v = self.kv_topo.split_k_and_v
for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
for _, cache_or_caches in self.device_kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
if self.enable_permute_local_kv and block_size_ratio > 1:
kv_postprocess_blksize_and_layout_on_receive(
cache, indices, block_size_ratio
)
elif self.enable_permute_local_kv:
kv_postprocess_layout_on_receive(cache, indices)
else:
kv_postprocess_blksize_on_receive(
cache, indices, block_size_ratio
)
def post_process_device_kv_on_receive_heterogeneous_attn(
self, block_ids: list[int]
):
"""
Post process device kv cache after receiving from remote
for heterogeneous attention.
"""
assert self.enable_heterogeneous_attn_post_process
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
for _, cache_or_caches in self.device_kv_caches.items():
blocks_to_update = cache_or_caches.index_select(1, indices)
current_platform.pack_kv_cache(
key=blocks_to_update[0],
value=blocks_to_update[1],
key_cache=cache_or_caches[0],
value_cache=cache_or_caches[1],
block_ids=block_ids,
indices=indices,
)
    def get_finished(self) -> tuple[set[str], set[str]]:
        """
        Get requests that are done sending or recving on this specific worker.
        The scheduler process (via the MultiprocExecutor) will use this output
        to track which workers are done.

        Returns:
            Tuple of (done_sending, done_recving) request-id sets.
        """
        assert self.kv_topo is not None
        done_sending = self._get_new_notifs()
        done_recving = self._pop_done_transfers(self._recving_transfers)
        # add requests that skipped transfer to done_recving
        done_recving.update(self._failed_recv_reqs)
        self._failed_recv_reqs.clear()
        if len(done_sending) > 0 or len(done_recving) > 0:
            logger.debug(
                "Rank %s, get_finished: %s requests done sending "
                "and %s requests done recving",
                self.tp_rank,
                len(done_sending),
                len(done_recving),
            )
        # Group finished receives by the post-processing they require, so each
        # kind runs once per batch below instead of once per request.
        block_ids_for_blocksize_post_process = defaultdict(list)
        block_ids_for_heterogeneous_attn_post_process = list[list[int]]()
        for req_id in done_recving:
            # clean up metadata for completed requests
            meta = self._recving_metadata.pop(req_id, None)
            assert meta is not None, f"{req_id} not found in recving_metadata list"
            assert meta.remote is not None
            if self.use_host_buffer:
                self.sync_recved_kv_to_device(req_id, meta)
            # post processing for heteroblocksize
            block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
                meta.remote.engine_id
            )
            if not self.use_mla and (
                block_size_ratio > 1 or self.enable_permute_local_kv
            ):
                assert not self._is_hma_required
                block_ids_for_blocksize_post_process[block_size_ratio].append(
                    meta.local_physical_block_ids[0]
                )
            # post processing for heterogeneous attention
            if self.enable_heterogeneous_attn_post_process:
                block_ids_for_heterogeneous_attn_post_process.append(
                    meta.local_physical_block_ids[0]
                )
        for (
            block_size_ratio,
            block_ids_list,
        ) in block_ids_for_blocksize_post_process.items():
            self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list)
        for block_ids in block_ids_for_heterogeneous_attn_post_process:
            self.post_process_device_kv_on_receive_heterogeneous_attn(block_ids)
        # Handle timeout to avoid stranding blocks on remote.
        now = time.perf_counter()
        while self._reqs_to_send:
            req_id, expires = next(iter(self._reqs_to_send.items()))
            # Sorted dict, oldest requests are put first so we can exit early.
            if now < expires:
                break
            count = self.consumer_notification_counts_by_req.pop(req_id, 0)
            self.xfer_stats.record_kv_expired_req()
            logger.warning(
                "Releasing expired KV blocks for request %s which were "
                "retrieved by %d decode worker(s) within %d seconds.",
                req_id,
                count,
                envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT,
            )
            self._reqs_to_process.remove(req_id)
            del self._reqs_to_send[req_id]
            # Expired requests count as "done sending" so the scheduler frees them.
            done_sending.add(req_id)
        return done_sending, done_recving
def _get_new_notifs(self) -> set[str]:
"""
Get req_ids which got a remote xfer message. When multiple consumers
are reading from the same producer (heterogeneous TP scenario), wait
for all consumers to be done pulling.
"""
assert self.kv_topo is not None
notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs:
req_id, tp_size = notif.decode("utf-8").rsplit(":", 1)
if (
req_id not in self._reqs_to_send
and req_id not in self._reqs_to_process
):
logger.error(
"Potentially invalid KV blocks for "
"unrecognized request %s were retrieved by "
"a decode worker. They may have expired.",
req_id,
)
continue
# NOTE: `tp_ratio` is the opposite when swapping local<>remote
n_consumers = int(tp_size)
tp_ratio = self.kv_topo.tp_ratio(n_consumers)
# Number of reads *per producer* to wait for.
# When remote D TP > local P TP we expect `tp_ratio` reads.
consumers_per_producer = (
-tp_ratio if n_consumers > self.world_size else 1
)
self.consumer_notification_counts_by_req[req_id] += 1
# Wait all consumers (D) to be done reading before freeing.
if (
self.consumer_notification_counts_by_req[req_id]
== consumers_per_producer
):
notified_req_ids.add(req_id)
del self.consumer_notification_counts_by_req[req_id]
self._reqs_to_process.remove(req_id)
self._reqs_to_send.pop(req_id, None)
return notified_req_ids
def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
"""
Pop completed xfers by checking for DONE state.
Args:
transfers: dict of req_id -> list[running_xfer]
Returns:
set of req_ids that have all done xfers
"""
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
in_progress = []
for handle in handles:
try:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
# Get telemetry from NIXL
res = self.nixl_wrapper.get_xfer_telemetry(handle)
self.xfer_stats.record_transfer(res)
self.nixl_wrapper.release_xfer_handle(handle)
elif xfer_state == "PROC":
in_progress.append(handle)
continue
else:
self._log_failure(
failure_type="transfer_failed",
msg="Marking blocks as invalid",
req_id=req_id,
xfer_state=xfer_state,
)
self._handle_failed_transfer(req_id, handle)
except Exception as e:
self._log_failure(
failure_type="transfer_exception",
msg="Marking blocks as invalid",
req_id=req_id,
error=e,
)
self._handle_failed_transfer(req_id, handle)
if not in_progress:
# Only report request as completed when all transfers are done.
done_req_ids.add(req_id)
del transfers[req_id]
else:
transfers[req_id] = in_progress
return done_req_ids
def _handle_failed_transfer(self, req_id: str, handle: int):
"""
Handle a failed transfer by marking all (logical) blocks as invalid and
recording the failure.
Args:
req_id: The request ID.
handle: The transfer handle.
"""
# Use .get() here as the metadata cleanup is handled by get_finished()
# TODO (NickLucche) handle failed transfer for HMA.
if (meta := self._recving_metadata.get(req_id)) and not self._is_hma_required:
self._invalid_block_ids.update(meta.local_block_ids[0])
self.nixl_wrapper.release_xfer_handle(handle)
self.xfer_stats.record_failed_transfer()
    def start_load_kv(self, metadata: NixlConnectorMetadata):
        """
        Start loading by triggering non-blocking nixl_xfer.
        We check for these transfers to complete in each step().

        Also performs handshake bookkeeping: requests whose remote engine is
        unknown trigger a background handshake and are read once it finishes;
        expiration timers are only armed for requests not yet read from D.
        """
        for req_id, meta in metadata.reqs_to_recv.items():
            # Map scheduler-logical block ids to kernel (physical) ids.
            meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
                meta.local_block_ids
            )
            assert meta.remote is not None
            # Remote block IDs are kept logical here; expanded in
            # _read_blocks_for_req using the remote engine's phys ratio.
            remote_engine_id = meta.remote.engine_id
            logger.debug(
                "start_load_kv for request %s from remote engine %s. "
                "Num local_block_ids: %s. Num remote_block_ids: %s. ",
                req_id,
                remote_engine_id,
                len(meta.local_physical_block_ids),
                len(meta.remote.block_ids),
            )
            # always store metadata for failure recovery
            self._recving_metadata[req_id] = meta
            if remote_engine_id not in self._remote_agents:
                # Initiate handshake with remote engine to exchange metadata.
                with self._handshake_lock:
                    # Double-checked under the lock: another thread may have
                    # completed the handshake in the meantime.
                    if remote_engine_id not in self._remote_agents:
                        self._background_nixl_handshake(req_id, remote_engine_id, meta)
                        continue
            # Handshake already completed, start async read xfer.
            self._read_blocks_for_req(req_id, meta)
        # Start transfers for requests whose handshakes have now finished.
        while not self._ready_requests.empty():
            self._read_blocks_for_req(*self._ready_requests.get_nowait())
        # Keep around the requests that have been part of a batch. This is
        # needed because async scheduling pushes the misalignment between the
        # moment in which requests expiration is set (P side) and the moment in
        # which blocks are read from D. As P can now more easily lag behind D
        # while processing the next batch, we make sure to only set an
        # expiration for requests that have not been read from D yet.
        for req_id in metadata.reqs_in_batch:
            self._reqs_to_process.add(req_id)
        # Remove all requests that are not to be processed (eg aborted).
        for req_id in metadata.reqs_not_processed:
            self._reqs_to_process.discard(req_id)
            # We should never get an abort after setting an expiry timer
            assert req_id not in self._reqs_to_send
        # Add to requests that are waiting to be read and track expiration.
        for req_id, expiration_time in metadata.reqs_to_send.items():
            if req_id in self._reqs_to_process:
                self._reqs_to_send[req_id] = expiration_time
    def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
        """Issue the READ transfer(s) for one request.

        Depending on the TP topology, this posts one read per target remote
        rank (multiple reads when remote P TP > local D TP), picking the
        matching local/remote side handles for each.
        """
        assert meta.remote is not None and self.kv_topo is not None
        remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
            meta.remote.engine_id
        )
        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id)
        if self._has_mamba:
            # Expand remote logical → kernel block IDs.
            meta.remote.block_ids = self._logical_to_remote_kernel_block_ids(
                meta.remote.block_ids,
                self._mamba_phys_ratio[meta.remote.engine_id],
            )
        # D may have to perform multiple reads from different remote ranks.
        for i, remote_rank in enumerate(remote_ranks):
            if self.use_mla and tp_ratio < 0 and i > 0:
                # MLA opt: when P TP > D TP, only a single read is executed for
                # the first remote rank (cache is duplicated).
                break
            remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id]
            logger.debug(
                "Remote agent %s available, calling _read_blocks"
                " on remote rank %s with remote block size %s for req %s",
                meta.remote.engine_id,
                remote_rank,
                remote_block_size,
                req_id,
            )
            # Get side handles.
            if tp_ratio < 0 and not self.use_mla:
                assert remote_block_size == self.block_size
                # Remote tp_size > local tp_size: we must perform multiple
                # reads. Get the memory chunk onto which we will write to.
                local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i]
            else:
                # Single read from remote, we write to the whole memory region.
                # Also handle remote block size different from local block size.
                local_xfer_side_handle = self.src_xfer_handles_by_block_size[
                    remote_block_size
                ]
            # Destination handle: remote_engine_id -> remote_rank -> handle.
            remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][
                remote_rank
            ]
            local_ids: BlockIds = meta.local_physical_block_ids
            remote_ids: BlockIds = meta.remote.block_ids
            if self._has_mamba:
                # Mamba-HMA: zero out FA groups for P ranks outside fa_read_targets.
                transfer_cfg = self._transfer_configs.get(meta.remote.engine_id)
                assert transfer_cfg is not None
                local_ids, remote_ids = transfer_cfg.filter_block_ids_for_rank(
                    remote_rank,
                    local_ids,
                    remote_ids,
                    self._is_mamba_group,
                )
            self._read_blocks(
                request_id=req_id,
                dst_engine_id=meta.remote.engine_id,
                remote_request_id=meta.remote.request_id,
                local_block_ids=local_ids,
                remote_block_ids=remote_ids,
                remote_rank=remote_rank,
                local_xfer_side_handle=local_xfer_side_handle,
                remote_xfer_side_handle=remote_xfer_side_handle,
            )
        if self.use_mla and tp_ratio < 0:
            # ..but we still need to notify the other remote ranks that we
            # have the blocks we need so they can update the request state.
            notif_id = f"{req_id}:{self.world_size}".encode()
            remote_agents = self._remote_agents[meta.remote.engine_id]
            for rank_to_notify, agent in remote_agents.items():
                if rank_to_notify != remote_rank:
                    self.nixl_wrapper.send_notif(agent, notif_msg=notif_id)
    def _read_blocks(
        self,
        local_block_ids: BlockIds,
        remote_block_ids: BlockIds,
        dst_engine_id: str,
        request_id: str,
        remote_request_id: str,
        remote_rank: int,
        local_xfer_side_handle: int,
        remote_xfer_side_handle: int,
    ) -> None:
        """
        Post a READ point-to-point xfer request from a single local worker to
        a single remote worker.

        Args:
            local_block_ids: per-kv-group lists of local (destination) block
                ids. An empty outer list signals a full prefix cache hit, in
                which case only a notification is sent.
            remote_block_ids: per-kv-group lists of remote (source) block ids.
            dst_engine_id: engine id of the remote instance to read from.
            request_id: local request id, used for transfer bookkeeping and
                failure reporting.
            remote_request_id: request id on the remote side; embedded into the
                completion notification so the remote worker can free blocks.
            remote_rank: remote TP rank to read from.
            local_xfer_side_handle: NIXL prepped-dlist handle for local memory.
            remote_xfer_side_handle: NIXL prepped-dlist handle for remote memory.
        """
        assert self.kv_topo is not None
        # Ratio between local and remote block sizes for this remote engine
        # (>1 when they differ; see the expansion example below).
        block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
        if block_size_ratio > 1:
            # TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups.
            assert not self._is_hma_required
            # Only a single kv group is supported here, so operate on group 0.
            local_block_ids0 = local_block_ids[0] if local_block_ids else []
            remote_block_ids0 = remote_block_ids[0]
            # Expand each local block id into `block_size_ratio` consecutive
            # remote-granularity ids.
            local_block_ids_mapped = self.get_mapped_blocks(
                np.asarray(local_block_ids0), block_size_ratio
            ).tolist()
            if len(local_block_ids_mapped) > len(remote_block_ids0):
                # NOTE:
                # get_mapped_blocks will always expand block_ids for n times.
                # ex:
                # prefill block_ids with block_size as 4:
                # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
                # Local decode block_ids with block_size as 16: [1, 2, 3]
                # expanded decode block_ids with get_mapped_blocks from [1, 2, 3] to
                # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
                # Then we clip local to align with prefill
                # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to
                # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
                local_block_ids_mapped = local_block_ids_mapped[
                    : len(remote_block_ids0)
                ]
            local_block_ids = [local_block_ids_mapped] if local_block_ids_mapped else []
            remote_block_ids = [remote_block_ids0]
        # NOTE(rob): having the staging blocks be on the READER side is
        # not going to work well (since we will have to call rearrange tensors).
        # after we detect the txn is complete (which means we cannot make the
        # read trxn async easily). If we want to make "READ" happen cleanly,
        # then we will need to have the staging blocks on the remote side.
        # NOTE(rob): according to nvidia the staging blocks are used to
        # saturate IB with heterogeneous TP sizes.
        # Number of D TP workers that will read from dst P. Propagate info
        # on notification so that dst worker can wait before freeing blocks.
        notif_id = f"{remote_request_id}:{self.world_size}".encode()
        # Full prefix cache hit: do not need to read remote blocks,
        # just notify P worker that we have the blocks we need.
        if len(local_block_ids) == 0:
            # A full prefix cache hit is indicated with an empty list.
            agent_name = self._remote_agents[dst_engine_id][remote_rank]
            try:
                self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
            except Exception as e:
                # Best-effort: a lost notification only delays freeing on the
                # remote side (timeout-based), so record and continue.
                self._log_failure(
                    failure_type="notification_failed",
                    msg="P worker blocks will be freed after timeout. "
                    "This may indicate network issues.",
                    req_id=request_id,
                    error=e,
                    dst_engine_id=dst_engine_id,
                    remote_rank=remote_rank,
                    remote_agent_name=agent_name,
                )
                self.xfer_stats.record_failed_notification()
            return
        # One id list per kv cache group is expected on both sides.
        assert (
            len(remote_block_ids)
            == len(local_block_ids)
            == len(self.kv_cache_config.kv_cache_groups)
        )
        # Copy so the per-group trimming below does not mutate the caller's list.
        remote_block_ids = list(remote_block_ids)
        for i, remote_group in enumerate(remote_block_ids):
            num_remote_blocks = len(remote_group)
            num_local_blocks = len(local_block_ids[i])
            if not self._is_mamba_group[i]:
                assert num_local_blocks <= num_remote_blocks
            # Partial prefix cache hit: just read uncomputed blocks.
            # Skip mamba groups — their blocks represent full state (conv+ssm),
            # not per-token data, so trimming would corrupt the transfer.
            if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]:
                remote_block_ids[i] = remote_group[-num_local_blocks:]
        # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
        # corresponding rank. With heterogeneous TP, fixing D>P, the D tp
        # workers will issue xfers to parts of the P worker remote kv caches.
        # Get descs ids.
        remote_block_descs_ids = self._get_block_descs_ids(
            dst_engine_id,
            remote_block_ids,
        )
        local_block_descs_ids = self._get_block_descs_ids(
            self.engine_id,
            local_block_ids,
            block_size_ratio=block_size_ratio,
        )
        # READ is 1:1: every local descriptor must pair with a remote one.
        assert len(local_block_descs_ids) == len(remote_block_descs_ids)
        # Prepare transfer with Nixl.
        handle = None
        try:
            handle = self.nixl_wrapper.make_prepped_xfer(
                "READ",
                local_xfer_side_handle,
                local_block_descs_ids,
                remote_xfer_side_handle,
                remote_block_descs_ids,
                notif_msg=notif_id,
            )
            # Begin async xfer.
            self.nixl_wrapper.transfer(handle)
            # Use handle to check completion in future step().
            self._recving_transfers[request_id].append(handle)
        except Exception as e:
            # mark all (logical) blocks for this request as invalid
            self._log_failure(
                failure_type="transfer_setup_failed",
                req_id=request_id,
                msg="Marking blocks as invalid",
                error=e,
                dst_engine_id=dst_engine_id,
                remote_rank=remote_rank,
            )
            # NOTE(review): with HMA enabled, invalid-block marking is skipped
            # here — presumably recovered elsewhere; confirm with scheduler path.
            if (
                meta := self._recving_metadata.get(request_id)
            ) and not self._is_hma_required:
                self._invalid_block_ids.update(meta.local_block_ids[0])
            self.xfer_stats.record_failed_transfer()
            if handle is not None:
                self.nixl_wrapper.release_xfer_handle(handle)
            self._failed_recv_reqs.add(request_id)
def get_mapped_blocks(
self, block_ids: np.ndarray, block_size_ratio: int
) -> np.ndarray:
"""
Calculates the new set of block IDs by mapping every element
in the (potentially sparse) input array.
Example: block_ids=[0, 2], block_size_ratio=2
get_mapped_blocks 0 1 [2 3] 4 5
# remote is |h0-b0|h1-b0||h0-b1|h1-b1||h0-b1|h1-b1||
# local is |h0-b0......||h1-b0......||h2-b0........
local_block_ids 0 [1] 2
"""
if block_ids.size == 0:
return np.array([], dtype=np.int64)
start_ids = block_ids * block_size_ratio
offsets = np.arange(block_size_ratio)
mapped_2d = start_ids[:, None] + offsets[None, :]
return mapped_2d.flatten().astype(np.int64)
def _get_block_descs_ids(
self,
engine_id: str,
block_ids: BlockIds,
block_size_ratio: float | None = None,
) -> np.ndarray:
"""
Get the descs ids for a set of block ids.
When HMA is enabled number of descriptors across kv cache groups might differ.
A single flattened array is returned for all groups anyway.
"""
region_ids = np.arange(self.num_regions)
# NOTE (NickLucche) With HMA, every kv group has the same number of layers and
# layers from different groups share the same kv tensor.
# eg block_ids=[[1, 2], [3]]->blocks [1, 2] need to be read across all regions,
# same for [3], but group0-group1 blocks will always differ (different areas).
# Therefore we can just flatten the block_ids and compute the descs ids for all
# groups at once.
num_blocks = self.dst_num_blocks[engine_id]
if block_size_ratio is not None:
num_blocks = int(num_blocks * block_size_ratio)
# Compute desc ids per group using the right stride: FA descs have
# num_blocks entries per region (kernel granularity), SSM descs have
# logical_blocks entries per region (no kernel splitting).
region_ids = region_ids[:, None]
if not self._has_mamba:
block_ids = np.concatenate(block_ids)[None, :]
descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten()
else:
# NOTE (NickLucche) SSM and Attention blocks regions can be exchanged
# arbitrarily by manager. Therefore, descs are duplicated for SSM and
# Attention like so:
# desc_handle->[descs_fa (all regions) | descs_ssm (all regions)].
# This is like having two "low-level views" of the same storage.
# `num_fa_descs` offset must be computed per-engine since P and D can
# have different num_blocks (and thus different FA descs counts).
ratio = self._mamba_phys_ratio[engine_id]
logical_blocks = num_blocks // ratio
num_fa_descs = self.num_regions * num_blocks
# 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm).
mamba_region_ids = np.arange(len(self.block_len_per_layer) * 4)[:, None]
all_descs = []
for i, group in enumerate(block_ids):
group_arr = np.asarray(group)[None, :]
if self._is_mamba_group[i]:
all_descs.append(
(
mamba_region_ids * logical_blocks + group_arr + num_fa_descs
).flatten()
)
else:
all_descs.append((region_ids * num_blocks + group_arr).flatten())
return np.concatenate(all_descs)
def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds:
"""
Convert logical block ids to kernel physical block ids.
This is required when the logical block size (the one set by the user)
does not match the one required by the attn backend.
"""
if self._physical_blocks_per_logical_kv_block == 1:
# Noop when physical and logical block sizes are the same
return block_ids
block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
1, -1
)
# Mamba blocks have no logical<>physical discrepancy
group_specs = self.kv_cache_config.kv_cache_groups
return [
BlockTable.map_to_kernel_blocks(
np.array(group),
self._physical_blocks_per_logical_kv_block,
block_arange,
).tolist()
if not isinstance(group_specs[i].kv_cache_spec, MambaSpec)
else group
for i, group in enumerate(block_ids)
]
def _logical_to_remote_kernel_block_ids(
self, block_ids: BlockIds, remote_ratio: int
) -> BlockIds:
"""Map logical block IDs to physical kernel block IDs on the remote.
Args:
block_ids: per-group lists of logical block IDs.
remote_ratio: remote engine's physical blocks per logical block.
Returns:
Same structure with FA groups expanded (each logical block L
becomes kernel blocks [L*remote_ratio .. L*remote_ratio +
local_ratio - 1]). Mamba groups are passed through unchanged.
"""
local_ratio = self._physical_blocks_per_logical_kv_block
if remote_ratio == 1:
return block_ids
local_arange = np.arange(local_ratio).reshape(1, -1)
group_specs = self.kv_cache_config.kv_cache_groups
result: list[list[int]] = []
for i, group in enumerate(block_ids):
if not isinstance(group_specs[i].kv_cache_spec, MambaSpec):
arr = np.array(group).reshape(-1, 1)
expanded = (arr * remote_ratio + local_arange).flatten()
result.append(expanded.tolist())
else:
# Mamba blocks are 1:1 logical-to-physical (no expansion).
result.append(group)
return result
def get_backend_aware_kv_block_len(
self, layer_idx: int, first_split: bool = True, mamba_view: bool = False
) -> int:
"""
Get the block length for one K/V element (K and V have the same size).
For FA and other backends, this is equal to the length of the whole
block, as K and V are in separate regions.
For FlashInfer, this is half the length of the whole block, as K and V
share the same region.
Similarly, for SSM-based models, state and conv are interleaved, but crucially
the their size differs.
Reference diagram:
KVCacheTensor (Shared)
/ \\
/ \\
/ \\
Attention (FlashInfer) View Mamba View
| |
| |
+-------------------+ +-------------------+
| KVCacheTensor | | KVCacheTensor |
| | | |
|<----- page ------>| |<----- page ------->|
| size | | size |
| Key 0 | Val 0 | |Conv 0 | SSM 0 |
| Key 1 | Val 1 | |Conv 1 | SSM 1 |
| ... | ... | | ... | ... |
| Key N-2 | Val N-2 | |Conv N-2| SSM N-2 |
| Key N-1 | Val N-1 | |Conv N-1| SSM N-1 |
+-------------------+ +--------------------+
|1st_split-2nd_split| |1st_split-2nd_split |
"""
assert self.kv_topo is not None
if self.kv_topo.is_kv_layout_blocks_first:
# For indexing only half (either just the K or V part).
if mamba_view:
# NOTE (NickLucche) Mamba Opt: this is already skipping the padding so
# we're only transferring the minimum required bytes.
block_len = self._mamba_ssm_size[not first_split]
else:
block_len = self.block_len_per_layer[layer_idx] // 2
else:
block_len = self.block_len_per_layer[layer_idx]
return block_len
def get_kv_connector_stats(self) -> KVConnectorStats | None:
"""
Get the KV transfer stats for the connector.
"""
# Clear stats for next iteration
if not self.xfer_stats.is_empty():
return self.xfer_stats.clone_and_reset()
return None
def get_block_ids_with_load_errors(self) -> set[int]:
"""
Return and clear the set of block IDs that failed to load.
This is called by the scheduler to identify blocks that need
to be retried after a NIXL transfer failure.
"""
result = self._invalid_block_ids
self._invalid_block_ids = set()
return result
    def __del__(self):
        # Best-effort cleanup when the worker is garbage collected;
        # shutdown() itself guards against partially-initialized instances.
        self.shutdown()
def shutdown(self):
"""Shutdown the connector worker."""
if not hasattr(self, "_handshake_initiation_executor"):
# error happens during init, no need to shutdown
return
self._handshake_initiation_executor.shutdown(wait=False)
for handles in self._recving_transfers.values():
for handle in handles:
self.nixl_wrapper.release_xfer_handle(handle)
self._recving_transfers.clear()
for handle in self.src_xfer_handles_by_block_size.values():
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_handles_by_block_size.clear()
for handles in self.src_xfer_handles_by_tp_ratio.values():
for handle in handles:
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_handles_by_tp_ratio.clear()
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
for dst_xfer_side_handle in dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
self.dst_xfer_side_handles.clear()
for remote_agents in self._remote_agents.values():
for agent_name in remote_agents.values():
self.nixl_wrapper.remove_remote_agent(agent_name)
self._remote_agents.clear()
for desc in self._registered_descs:
self.nixl_wrapper.deregister_memory(desc)
self._registered_descs.clear()