Skip to content

vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector

NixlConnector – thin facade that delegates to scheduler / worker.

NixlConnector

Bases: KVConnectorBase_V1, SupportsHMA

Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py
class NixlConnector(KVConnectorBase_V1, SupportsHMA):
    """Thin facade that delegates to a scheduler-side or worker-side half.

    Exactly one of ``connector_scheduler`` / ``connector_worker`` is
    instantiated, depending on the ``KVConnectorRole`` passed at
    construction; each public method asserts the relevant half exists and
    forwards the call to it.
    """

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        """Return True when the cross-layer KV block layout should be used.

        Requires all of: no Mamba cache groups, a supported attention
        backend, the HND KV-cache layout, and the
        ``enable_cross_layers_blocks`` extra-config flag set to "true"
        (case-insensitive).
        """
        # Generator expression lets any() short-circuit on the first
        # Mamba group instead of materializing a full list first.
        if any(
            isinstance(group.kv_cache_spec, MambaSpec)
            for group in self.kv_cache_config.kv_cache_groups
        ):
            # Hybrid SSM models do not yet support cross-layer layout
            return False

        backend = get_current_attn_backend(self._vllm_config)
        if backend.get_name() not in (
            "FLASH_ATTN",
            "FLASHINFER",
            "TRITON_ATTN",
        ):
            return False

        # For now there is no benefit to running cross-layer blocks when
        # the backend does not use the HND layout.
        if get_kv_cache_layout() != "HND":
            return False

        # Opt-in flag: anything other than a (case-insensitive) "true"
        # keeps the feature disabled.
        extra_config = self.kv_transfer_config.kv_connector_extra_config
        return (
            str(extra_config.get("enable_cross_layers_blocks", "False")).lower()
            == "true"
        )

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig",
    ):
        """Create the scheduler- or worker-side half for the given role.

        Args:
            vllm_config: full vLLM configuration; must carry a
                ``kv_transfer_config`` with an ``engine_id``.
            role: SCHEDULER or WORKER; selects which half is built.
            kv_cache_config: KV-cache layout/grouping configuration.

        Raises:
            ValueError: if ``role`` is neither SCHEDULER nor WORKER.
        """
        super().__init__(vllm_config, role, kv_cache_config)
        assert vllm_config.kv_transfer_config is not None
        assert vllm_config.kv_transfer_config.engine_id is not None
        self.kv_cache_config = kv_cache_config
        self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
        self.kv_transfer_config = vllm_config.kv_transfer_config
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler: NixlConnectorScheduler | None = (
                NixlConnectorScheduler(vllm_config, self.engine_id, kv_cache_config)
            )
            self.connector_worker: NixlConnectorWorker | None = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = NixlConnectorWorker(
                vllm_config, self.engine_id, kv_cache_config
            )
        else:
            # Fail fast instead of leaving both halves unset, which would
            # otherwise surface later as an AttributeError on the first
            # delegated call.
            raise ValueError(f"Unsupported KVConnectorRole: {role}")

    ############################################################
    # Class Methods
    ############################################################
    @classmethod
    def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
        """Return the KV-cache layout NIXL needs, or None for the default."""
        if vllm_config.model_config is None:
            logger.warning_once(
                "Unable to detect current VLLM config. "
                "Fallback to default kv cache layout."
            )
            return None
        use_mla = vllm_config.model_config.use_mla
        if use_mla:
            # Return None for MLA: the layout should not matter in that
            # case, so fall back to the default behavior.
            return None
        logger.info_once(
            "NixlConnector setting KV cache layout to HND for better xfer performance."
        )
        return "HND"

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        """Delegate to the scheduler half (scheduler role only)."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens
        )

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Delegate to the scheduler half (scheduler role only)."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens
        )

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Delegate to the scheduler half (scheduler role only)."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Single-group variant: wraps ``block_ids`` in a 1-tuple and
        delegates to the scheduler half."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, (block_ids,))

    def request_finished_all_groups(
        self,
        request: "Request",
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Multi-group variant: forwards the per-group block-id tuple
        unchanged to the scheduler half."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]
    ) -> None:
        """
        Set the KV connector handshake metadata for this connector.

        Args:
            metadata (dict): the handshake metadata to set.
        """
        assert self.connector_scheduler is not None
        self.connector_scheduler.set_xfer_handshake_metadata(metadata)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register per-layer KV caches with the worker half."""
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        """Register the cross-layer KV cache with the worker half.

        NOTE: ``attn_backend`` is accepted for interface compatibility but
        is not forwarded to the worker.
        """
        assert self.connector_worker is not None
        self.connector_worker.register_cross_layers_kv_caches(kv_cache)

    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
        """Install the device<->host block-copy callback on the worker."""
        assert self.connector_worker is not None
        self.connector_worker.set_host_xfer_buffer_ops(copy_operation)

    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Get the finished recving and sending requests.

        NOTE: ``finished_req_ids`` is not forwarded; the worker tracks
        completion state internally.
        """
        assert self.connector_worker is not None
        return self.connector_worker.get_finished()

    def get_block_ids_with_load_errors(self) -> set[int]:
        """Get block IDs that failed to load via NIXL."""
        assert self.connector_worker is not None
        return self.connector_worker.get_block_ids_with_load_errors()

    def get_kv_connector_stats(self) -> KVConnectorStats | None:
        """Return transfer stats from the worker half, or None when this
        instance is scheduler-side."""
        if self.connector_worker is None:
            return None
        return self.connector_worker.get_kv_connector_stats()

    @classmethod
    def build_kv_connector_stats(
        cls, data: dict[str, Any] | None = None
    ) -> KVConnectorStats | None:
        """Build a NixlKVConnectorStats, seeded with ``data`` if given."""
        return (
            NixlKVConnectorStats(data=data)
            if data is not None
            else NixlKVConnectorStats()
        )

    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: VllmConfig,
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ) -> KVConnectorPromMetrics:
        """Build the NIXL-specific Prometheus metrics wrapper."""
        return NixlPromMetrics(
            vllm_config, metric_types, labelnames, per_engine_labelvalues
        )

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        """Kick off async KV loads described by the connector metadata."""
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata, NixlConnectorMetadata)
        self.connector_worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """NixlConnector does not do layerwise saving."""
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> None:
        """NixlConnector does not save explicitly."""
        pass

    def wait_for_save(self):
        """Copy KV blocks to the host buffer when host staging is enabled;
        otherwise a no-op."""
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata, NixlConnectorMetadata)
        if self.connector_worker.use_host_buffer and self.connector_worker.copy_blocks:
            self.connector_worker.save_kv_to_host(self._connector_metadata)

    def shutdown(self):
        """Shut down whichever half (worker and/or scheduler) exists."""
        if self.connector_worker is not None:
            self.connector_worker.shutdown()
        if self.connector_scheduler is not None:
            self.connector_scheduler.shutdown()

    def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
        """
        Get the KVConnector handshake metadata for this connector.
        This metadata is used for out-of-band connector handshake
        between P/D workers.

        Returns:
            KVConnectorHandshakeMetadata: the handshake metadata.
            None if no handshake metadata is available.
        """
        assert self.connector_worker is not None
        return self.connector_worker.xfer_handshake_metadata

get_block_ids_with_load_errors

get_block_ids_with_load_errors() -> set[int]

Get block IDs that failed to load via NIXL.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py
def get_block_ids_with_load_errors(self) -> set[int]:
    """Return the IDs of blocks whose NIXL transfer failed to load."""
    worker = self.connector_worker
    assert worker is not None
    return worker.get_block_ids_with_load_errors()

get_finished

get_finished(
    finished_req_ids: set[str],
) -> tuple[set[str], set[str]]

Get the finished recving and sending requests.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
    """Return the sets of finished receiving and sending requests.

    ``finished_req_ids`` is not forwarded; the worker tracks completion
    on its own.
    """
    worker = self.connector_worker
    assert worker is not None
    return worker.get_finished()

get_handshake_metadata

get_handshake_metadata() -> (
    KVConnectorHandshakeMetadata | None
)

Get the KVConnector handshake metadata for this connector. This metadata is used for out-of-band connector handshake between P/D workers.

Returns:

Name: KVConnectorHandshakeMetadata
Type: KVConnectorHandshakeMetadata | None
Description: the handshake metadata, or None if no handshake metadata is available.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """Return the handshake metadata exchanged out-of-band between
    P/D workers, or None when the worker has none available."""
    worker = self.connector_worker
    assert worker is not None
    return worker.xfer_handshake_metadata

save_kv_layer

save_kv_layer(
    layer_name: str,
    kv_layer: Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs,
) -> None

NixlConnector does not save explicitly.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py
def save_kv_layer(
    self,
    layer_name: str,
    kv_layer: torch.Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs,
) -> None:
    """No-op: NixlConnector does not save KV layers explicitly."""
    return None

set_xfer_handshake_metadata

set_xfer_handshake_metadata(
    metadata: dict[int, KVConnectorHandshakeMetadata],
) -> None

Set the KV connector handshake metadata for this connector.

Parameters:

Name: metadata
Type: dict[int, KVConnectorHandshakeMetadata]
Description: the handshake metadata to set.
Default: required (no default)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py
def set_xfer_handshake_metadata(
    self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
    """Forward the KV connector handshake metadata to the scheduler half.

    Args:
        metadata (dict): the handshake metadata to set.
    """
    scheduler = self.connector_scheduler
    assert scheduler is not None
    scheduler.set_xfer_handshake_metadata(metadata)

wait_for_layer_load

wait_for_layer_load(layer_name: str) -> None

NixlConnector does not do layerwise saving.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py
def wait_for_layer_load(self, layer_name: str) -> None:
    """No-op: NixlConnector performs no layerwise saving, so there is
    nothing to wait on per layer."""
    return None