class NixlConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(
self,
vllm_config: "VllmConfig",
engine_id: str,
kv_cache_config: "KVCacheConfig",
):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.engine_id: EngineId = engine_id
self.kv_cache_config = kv_cache_config
self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
self.side_channel_port = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
+ vllm_config.parallel_config.data_parallel_index
)
assert vllm_config.kv_transfer_config is not None
if current_platform.device_type == "cpu":
self.use_host_buffer = False
else:
self.use_host_buffer = (
vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
)
self._is_hma_required = (
not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
# Also handle unlikely SW-only model case instead of checking num_groups>1.
and any(
not isinstance(g.kv_cache_spec, FullAttentionSpec)
for g in kv_cache_config.kv_cache_groups
)
)
self._has_mamba = any(
isinstance(g.kv_cache_spec, MambaSpec)
for g in kv_cache_config.kv_cache_groups
)
logger.info("Initializing NIXL Scheduler %s", engine_id)
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
logger.info("Hybrid Memory Allocator is enabled with NIXL")
# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: threading.Thread | None = None
self._stop_event = threading.Event()
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, BlockIds]] = {}
self._reqs_need_save: dict[ReqId, Request] = {}
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
self._reqs_in_batch: set[ReqId] = set()
# Reqs to remove from processed set because they're not to send after
# remote prefill or aborted.
self._reqs_not_processed: set[ReqId] = set()
# Gather Sliding Window sizes for each kv cache group (if any) in number of
# blocks per KV cache group. This is used to clip the local attention window.
sw_sizes_tokens: list[tuple[int, int]] = [
(g.kv_cache_spec.sliding_window, g.kv_cache_spec.block_size)
if isinstance(g.kv_cache_spec, SlidingWindowSpec)
else (0, self.block_size)
for g in kv_cache_config.kv_cache_groups
]
# cdiv(n_tokens, block_size) gives blocks/window; add 1 to conservatively
# account for boundary overlap eg window isn't fully aligned with blocks.
self.blocks_per_sw = [
cdiv(n_tokens, block_size) + 1 if n_tokens else 0
for n_tokens, block_size in sw_sizes_tokens
]
def shutdown(self):
self._stop_event.set()
if self._nixl_handshake_listener_t is not None:
self._nixl_handshake_listener_t.join()
self._nixl_handshake_listener_t = None
def get_sw_clipped_blocks(self, block_ids: BlockIds) -> BlockIds:
"""
Clip the number of blocks to the sliding window size for each kv cache group
that employs SWA.
This is necessary because the KV Cache manager initially allocates blocks for
the entire sequence length, and successively cleans up blocks that are outside
the window prior to the `request_finished_all_groups` hook.
"""
if len(block_ids) == 0 or not self._is_hma_required:
# No blocks to clip eg Full prefix cache hit or not a hybrid model.
return block_ids
# NOTE (NickLucche) This logic is currently handled at the connector level
# because offloading connectors might want to receive the whole sequence even
# for SWA groups. We will abstract this logic once the interface is more stable
assert len(block_ids) == len(self.blocks_per_sw), (
"Number of KV cache groups must match"
)
# For non-SWA groups, blocks_per_sw is 0 so we return all block_ids unchanged
return tuple(
[
blocks[-self.blocks_per_sw[i] :]
if self.blocks_per_sw[i] > 0
else blocks
for i, blocks in enumerate(block_ids)
]
)
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
"""
Set the KV connector handshake metadata for this connector.
Args:
metadata (dict): the handshake metadata to set.
"""
encoded_data: dict[int, bytes] = {}
encoder = msgspec.msgpack.Encoder()
for tp_rank, rank_metadata in metadata.items():
if not isinstance(rank_metadata, NixlHandshakePayload):
raise ValueError(
"NixlConnectorScheduler expects NixlHandshakePayload for "
"handshake metadata."
)
encoded_data[tp_rank] = encoder.encode(rank_metadata)
logger.debug(
"Tp rank %d: encoded NixlHandshakePayload size: %s bytes",
tp_rank,
str(len(encoded_data[tp_rank])),
)
# Only start the listener when we have metadata to serve.
if self._nixl_handshake_listener_t is None:
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(
encoded_data,
ready_event,
self._stop_event,
self.side_channel_port,
),
daemon=True,
name="nixl_handshake_listener",
)
self._nixl_handshake_listener_t.start()
ready_event.wait() # Wait for listener ZMQ socket to be ready.
@staticmethod
def _nixl_handshake_listener(
encoded_data: dict[int, Any],
ready_event: threading.Event,
stop_event: threading.Event,
port: int,
):
"""Background thread for getting new NIXL handshakes."""
# NOTE(rob): this is a simple implementation. We will move
# to a better approach via HTTP endpoint soon.
# Listen for new requests for metadata.
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
path = make_zmq_path("tcp", host, port)
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
sock.setsockopt(zmq.RCVTIMEO, 1000)
ready_event.set()
while True:
try:
identity, _, msg = sock.recv_multipart()
except zmq.Again:
if stop_event.is_set():
break
continue
# Decode the message which contains (GET_META_MSG, rank)
msg, target_tp_rank = msgspec.msgpack.decode(msg)
logger.debug(
"Received message for tp rank %s",
target_tp_rank,
)
if msg != GET_META_MSG:
logger.warning("Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))
def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int:
"""D-side only. Returns N-1 for Mamba models since the decoder
always recomputes the last token and must start from h(N-1)."""
if self._has_mamba and num_prompt_tokens > 1:
return num_prompt_tokens - 1
return num_prompt_tokens
def _truncate_mamba_request_for_prefill(self, request: "Request") -> None:
"""P-side only: drop the last prompt token so the prefiller computes
h(N-1) instead of h(N). The decoder recomputes the last token to
derive h(N) correctly.
Guarded by ``_p_side_truncated`` to avoid repeated truncation if the
request is preempted and rescheduled."""
params = request.kv_transfer_params
if (
params is not None
# Guard against repeated truncation after preemption/reschedule.
and not params.get("_p_side_truncated")
and request.num_prompt_tokens > 1
):
if request.prompt_token_ids is not None:
request.prompt_token_ids.pop()
elif request.prompt_embeds is not None:
request.prompt_embeds = request.prompt_embeds[:-1]
else:
return
request._all_token_ids.pop()
request.num_prompt_tokens -= 1
request.max_tokens = 1
params["_p_side_truncated"] = True
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
"""
For remote prefill, pull all prompt blocks from remote
asynchronously relative to engine execution.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
* the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* true if the external KV cache tokens will be loaded
asynchronously (between scheduler steps).
"""
params = request.kv_transfer_params
logger.debug(
"NIXLConnector get_num_new_matched_tokens: "
"num_computed_tokens=%s, kv_transfer_params=%s",
num_computed_tokens,
params,
)
if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
token_ids = request.prompt_token_ids or []
actual = self._mamba_prefill_token_count(len(token_ids))
count = actual - num_computed_tokens
if count > 0:
return count, True
if params is not None and params.get("do_remote_decode") and self._has_mamba:
self._truncate_mamba_request_for_prefill(request)
# No remote prefill for this request.
return 0, False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
params = request.kv_transfer_params
logger.debug(
"NIXLConnector update_state_after_alloc: "
"num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens,
params,
)
if not params:
return
if params.get("do_remote_decode"):
self._reqs_in_batch.add(request.request_id)
if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.
self._reqs_need_save[request.request_id] = request
elif params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(
p in params
for p in (
"remote_engine_id",
"remote_request_id",
"remote_host",
"remote_port",
)
):
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
unhashed_local_block_ids: BlockIds = (
blocks.get_unhashed_block_ids_all_groups()
if num_external_tokens > 0
else ()
)
local_block_ids = self.get_sw_clipped_blocks(
unhashed_local_block_ids
)
# Get unhashed blocks to pull from remote. Mind that a full prefix
# cache hit is indicated with an empty list.
self._reqs_need_recv[request.request_id] = (
request,
local_block_ids,
)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer",
params,
)
else:
assert num_external_tokens == 0
# Only trigger 1 KV transfer per request.
params["do_remote_prefill"] = False
def _build_save_meta(
self,
meta: NixlConnectorMetadata,
scheduler_output: SchedulerOutput,
) -> None:
# only called when use_host_buffer is True to build the save metadata
# NOTE: For the prefill side, there might be a chance that an early added
# request is a chunked prefill, so we need to check if new blocks are added
for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output):
req_to_save = self._reqs_need_save.get(req_id)
if req_to_save is None or new_block_id_groups is None:
continue
req = req_to_save
assert req.kv_transfer_params is not None
clipped_block_id_groups = self.get_sw_clipped_blocks(new_block_id_groups)
meta.add_new_req_to_save(
request_id=req_id,
local_block_ids=clipped_block_id_groups,
kv_transfer_params=req.kv_transfer_params,
)
assert scheduler_output.num_scheduled_tokens is not None
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
is_partial = (
req.num_computed_tokens + num_scheduled_tokens
) < req.num_prompt_tokens
if not is_partial:
# For non-partial prefills, once new req_meta is scheduled, it
# can be removed from _reqs_need_save.
# For partial prefill case, we will retain the request in
# _reqs_need_save until all blocks are scheduled with req_meta.
# Therefore, only pop if `not is_partial`.
self._reqs_need_save.pop(req_id)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
meta = NixlConnectorMetadata()
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
meta.add_new_req_to_recv(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
)
if self.use_host_buffer:
self._build_save_meta(meta, scheduler_output)
meta.reqs_to_send = self._reqs_need_send
meta.reqs_in_batch = self._reqs_in_batch
meta.reqs_not_processed = self._reqs_not_processed
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
self._reqs_in_batch = set()
self._reqs_not_processed = set()
self._reqs_need_send = {}
return meta
def request_finished(
self,
request: "Request",
block_ids: BlockIds,
) -> tuple[bool, dict[str, Any] | None]:
"""
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
from vllm.v1.request import RequestStatus
params = request.kv_transfer_params
logger.debug(
"NIXLConnector request_finished(%s), request_status=%s, "
"kv_transfer_params=%s",
request.request_id,
request.status,
params,
)
if not params:
return False, None
if params.get("do_remote_prefill"):
# If do_remote_prefill is still True when the request is finished,
# update_state_after_alloc must not have been called (the request
# must have been aborted before it was scheduled).
# To avoid stranding the prefill blocks in the prefill instance,
# we must add empty block_ids to _reqs_need_recv so that our
# worker side will notify and free blocks in the prefill instance.
self._reqs_need_recv[request.request_id] = (request, [])
params["do_remote_prefill"] = False
return False, None
if not params.get("do_remote_decode"):
return False, None
if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
# Also include the case of a P/D Prefill request with immediate
# block free (eg abort). Stop tracking this request.
self._reqs_not_processed.add(request.request_id)
# Clear _reqs_need_save if a request is aborted as partial prefill.
self._reqs_need_save.pop(request.request_id, None)
return False, None
# TODO: check whether block_ids actually ever be 0. If not we could
# remove the conditional below
delay_free_blocks = any(len(group) > 0 for group in block_ids)
if delay_free_blocks:
# Prefill request on remote. It will be read from D upon completion
logger.debug(
"NIXLConnector request_finished(%s) waiting for %d seconds "
"for remote decode to fetch blocks",
request.request_id,
envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT,
)
self._reqs_need_send[request.request_id] = (
time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
)
# NOTE HMA will "mark" empty/null blocks in groups with 0s (eg SWA ones),
# trimming down after allocating for the whole sequence length. Empty
# blocks are always at the start of the list.
# Here we "unpad" blocks to send the actual remote blocks to be read.
block_ids = self.get_sw_clipped_blocks(block_ids)
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
remote_block_ids=block_ids,
remote_engine_id=self.engine_id,
remote_request_id=request.request_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
)