Skip to content

vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler

Scheduler-side logic for the NIXL connector.

NixlConnectorScheduler

Implementation of Scheduler side methods

Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
class NixlConnectorScheduler:
    """Implementation of Scheduler side methods.

    Tracks, per scheduler step, which requests must receive KV blocks from a
    remote prefill instance, which must be saved/sent, and which should stop
    being tracked, and packages that state into the ``NixlConnectorMetadata``
    handed to the worker side via ``build_connector_meta``.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        engine_id: str,
        kv_cache_config: "KVCacheConfig",
    ):
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.engine_id: EngineId = engine_id
        self.kv_cache_config = kv_cache_config
        self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
        # One side-channel port per DP rank: offset the configured base port
        # by this engine's data-parallel index.
        self.side_channel_port = (
            envs.VLLM_NIXL_SIDE_CHANNEL_PORT
            + vllm_config.parallel_config.data_parallel_index
        )
        assert vllm_config.kv_transfer_config is not None
        if current_platform.device_type == "cpu":
            self.use_host_buffer = False
        else:
            # Stage KV through a host buffer when the configured KV buffer
            # device is "cpu" (accelerator not directly supported by NIXL).
            self.use_host_buffer = (
                vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
            )
        # Hybrid-memory-allocator handling is needed only when the hybrid KV
        # cache manager is active AND at least one group is not full attention.
        self._is_hma_required = (
            not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
            # Also handle unlikely SW-only model case instead of checking num_groups>1.
            and any(
                not isinstance(g.kv_cache_spec, FullAttentionSpec)
                for g in kv_cache_config.kv_cache_groups
            )
        )
        self._has_mamba = any(
            isinstance(g.kv_cache_spec, MambaSpec)
            for g in kv_cache_config.kv_cache_groups
        )

        logger.info("Initializing NIXL Scheduler %s", engine_id)
        # NOTE(review): this logs "enabled" when the hybrid KV cache manager
        # is *disabled* — the condition or the message looks inverted; confirm
        # against the intended semantics.
        if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
            logger.info("Hybrid Memory Allocator is enabled with NIXL")

        # Background thread for handling new handshake requests.
        self._nixl_handshake_listener_t: threading.Thread | None = None
        self._stop_event = threading.Event()

        # Requests that need to start recv/send.
        # New requests are added by update_state_after_alloc in
        # the scheduler. Used to make metadata passed to Worker.
        self._reqs_need_recv: dict[ReqId, tuple[Request, BlockIds]] = {}
        self._reqs_need_save: dict[ReqId, Request] = {}
        # Reqs to send and their expiration time
        self._reqs_need_send: dict[ReqId, float] = {}
        self._reqs_in_batch: set[ReqId] = set()
        # Reqs to remove from processed set because they're not to send after
        # remote prefill or aborted.
        self._reqs_not_processed: set[ReqId] = set()

        # Gather Sliding Window sizes for each kv cache group (if any) in number of
        # blocks per KV cache group. This is used to clip the local attention window.
        # Non-SWA groups contribute (0, block_size) so they are never clipped.
        sw_sizes_tokens: list[tuple[int, int]] = [
            (g.kv_cache_spec.sliding_window, g.kv_cache_spec.block_size)
            if isinstance(g.kv_cache_spec, SlidingWindowSpec)
            else (0, self.block_size)
            for g in kv_cache_config.kv_cache_groups
        ]
        # cdiv(n_tokens, block_size) gives blocks/window; add 1 to conservatively
        # account for boundary overlap eg window isn't fully aligned with blocks.
        self.blocks_per_sw = [
            cdiv(n_tokens, block_size) + 1 if n_tokens else 0
            for n_tokens, block_size in sw_sizes_tokens
        ]

    def shutdown(self):
        # Signal the handshake listener to exit its recv loop, then join it.
        self._stop_event.set()
        if self._nixl_handshake_listener_t is not None:
            self._nixl_handshake_listener_t.join()
            self._nixl_handshake_listener_t = None

    def get_sw_clipped_blocks(self, block_ids: BlockIds) -> BlockIds:
        """
        Clip the number of blocks to the sliding window size for each kv cache group
        that employs SWA.
        This is necessary because the KV Cache manager initially allocates blocks for
        the entire sequence length, and successively cleans up blocks that are outside
        the window prior to the `request_finished_all_groups` hook.
        """
        if len(block_ids) == 0 or not self._is_hma_required:
            # No blocks to clip eg Full prefix cache hit or not a hybrid model.
            return block_ids
        # NOTE (NickLucche) This logic is currently handled at the connector level
        # because offloading connectors might want to receive the whole sequence even
        # for SWA groups. We will abstract this logic once the interface is more stable
        assert len(block_ids) == len(self.blocks_per_sw), (
            "Number of KV cache groups must match"
        )
        # For non-SWA groups, blocks_per_sw is 0 so we return all block_ids unchanged
        return tuple(
            [
                blocks[-self.blocks_per_sw[i] :]
                if self.blocks_per_sw[i] > 0
                else blocks
                for i, blocks in enumerate(block_ids)
            ]
        )

    def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]
    ) -> None:
        """
        Set the KV connector handshake metadata for this connector.

        Args:
            metadata (dict): the handshake metadata to set.

        Raises:
            ValueError: if any rank's metadata is not a NixlHandshakePayload.
        """
        encoded_data: dict[int, bytes] = {}
        encoder = msgspec.msgpack.Encoder()
        for tp_rank, rank_metadata in metadata.items():
            if not isinstance(rank_metadata, NixlHandshakePayload):
                raise ValueError(
                    "NixlConnectorScheduler expects NixlHandshakePayload for "
                    "handshake metadata."
                )
            encoded_data[tp_rank] = encoder.encode(rank_metadata)
            logger.debug(
                "Tp rank %d: encoded NixlHandshakePayload size: %s bytes",
                tp_rank,
                str(len(encoded_data[tp_rank])),
            )

        # Only start the listener when we have metadata to serve.
        # The listener is started at most once; subsequent calls are no-ops
        # for thread creation.
        if self._nixl_handshake_listener_t is None:
            ready_event = threading.Event()
            self._nixl_handshake_listener_t = threading.Thread(
                target=self._nixl_handshake_listener,
                args=(
                    encoded_data,
                    ready_event,
                    self._stop_event,
                    self.side_channel_port,
                ),
                daemon=True,
                name="nixl_handshake_listener",
            )
            self._nixl_handshake_listener_t.start()
            ready_event.wait()  # Wait for listener ZMQ socket to be ready.

    @staticmethod
    def _nixl_handshake_listener(
        encoded_data: dict[int, Any],
        ready_event: threading.Event,
        stop_event: threading.Event,
        port: int,
    ):
        """Background thread for getting new NIXL handshakes."""
        # NOTE(rob): this is a simple implementation. We will move
        # to a better approach via HTTP endpoint soon.

        # Listen for new requests for metadata.
        host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
        path = make_zmq_path("tcp", host, port)
        logger.debug("Starting listening on path: %s", path)
        with zmq_ctx(zmq.ROUTER, path) as sock:
            # 1s receive timeout so stop_event is polled between requests.
            sock.setsockopt(zmq.RCVTIMEO, 1000)
            ready_event.set()
            while True:
                try:
                    identity, _, msg = sock.recv_multipart()
                except zmq.Again:
                    if stop_event.is_set():
                        break
                    continue
                # Decode the message which contains (GET_META_MSG, rank)
                msg, target_tp_rank = msgspec.msgpack.decode(msg)
                logger.debug(
                    "Received message for tp rank %s",
                    target_tp_rank,
                )
                if msg != GET_META_MSG:
                    logger.warning("Connection listener got unexpected message %s", msg)
                # Reply with the requested rank's encoded handshake payload
                # (sent even after an unexpected message, per the code above).
                sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))

    def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int:
        """D-side only. Returns N-1 for Mamba models since the decoder
        always recomputes the last token and must start from h(N-1)."""
        if self._has_mamba and num_prompt_tokens > 1:
            return num_prompt_tokens - 1
        return num_prompt_tokens

    def _truncate_mamba_request_for_prefill(self, request: "Request") -> None:
        """P-side only: drop the last prompt token so the prefiller computes
        h(N-1) instead of h(N). The decoder recomputes the last token to
        derive h(N) correctly.

        Guarded by ``_p_side_truncated`` to avoid repeated truncation if the
        request is preempted and rescheduled."""
        params = request.kv_transfer_params
        if (
            params is not None
            # Guard against repeated truncation after preemption/reschedule.
            and not params.get("_p_side_truncated")
            and request.num_prompt_tokens > 1
        ):
            # The prompt is either token ids or embeddings; if neither is
            # present there is nothing to truncate.
            if request.prompt_token_ids is not None:
                request.prompt_token_ids.pop()
            elif request.prompt_embeds is not None:
                request.prompt_embeds = request.prompt_embeds[:-1]
            else:
                return

            request._all_token_ids.pop()
            request.num_prompt_tokens -= 1
            request.max_tokens = 1
            params["_p_side_truncated"] = True

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int, bool]:
        """
        For remote prefill, pull all prompt blocks from remote
        asynchronously relative to engine execution.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request
        Returns:
            * the number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
            * true if the external KV cache tokens will be loaded
              asynchronously (between scheduler steps).
        """

        params = request.kv_transfer_params
        logger.debug(
            "NIXLConnector get_num_new_matched_tokens: "
            "num_computed_tokens=%s, kv_transfer_params=%s",
            num_computed_tokens,
            params,
        )

        if params is not None and params.get("do_remote_prefill"):
            # Remote prefill: get all prompt blocks from remote.
            token_ids = request.prompt_token_ids or []
            actual = self._mamba_prefill_token_count(len(token_ids))
            count = actual - num_computed_tokens
            if count > 0:
                return count, True

        if params is not None and params.get("do_remote_decode") and self._has_mamba:
            self._truncate_mamba_request_for_prefill(request)

        # No remote prefill for this request.
        return 0, False

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        # Record per-request transfer intent after the KV cache manager has
        # allocated blocks; consumed later by build_connector_meta.
        params = request.kv_transfer_params
        logger.debug(
            "NIXLConnector update_state_after_alloc: "
            "num_external_tokens=%s, kv_transfer_params=%s",
            num_external_tokens,
            params,
        )

        if not params:
            return

        if params.get("do_remote_decode"):
            self._reqs_in_batch.add(request.request_id)
        if self.use_host_buffer and params.get("do_remote_decode"):
            # NOTE: when accelerator is not directly supported by Nixl,
            # prefilled blocks need to be saved to host memory before transfer.
            self._reqs_need_save[request.request_id] = request
        elif params.get("do_remote_prefill"):
            if params.get("remote_block_ids"):
                if all(
                    p in params
                    for p in (
                        "remote_engine_id",
                        "remote_request_id",
                        "remote_host",
                        "remote_port",
                    )
                ):
                    # If remote_blocks and num_external_tokens = 0, we have
                    # a full prefix cache hit on the D worker. We need to call
                    # send_notif in _read_blocks to free the memory on the P.

                    unhashed_local_block_ids: BlockIds = (
                        blocks.get_unhashed_block_ids_all_groups()
                        if num_external_tokens > 0
                        else ()
                    )
                    local_block_ids = self.get_sw_clipped_blocks(
                        unhashed_local_block_ids
                    )

                    # Get unhashed blocks to pull from remote. Mind that a full prefix
                    # cache hit is indicated with an empty list.
                    self._reqs_need_recv[request.request_id] = (
                        request,
                        local_block_ids,
                    )

                else:
                    logger.warning(
                        "Got invalid KVTransferParams: %s. This "
                        "request will not utilize KVTransfer",
                        params,
                    )
            else:
                assert num_external_tokens == 0
            # Only trigger 1 KV transfer per request.
            params["do_remote_prefill"] = False

    def _build_save_meta(
        self,
        meta: NixlConnectorMetadata,
        scheduler_output: SchedulerOutput,
    ) -> None:
        # only called when use_host_buffer is True to build the save metadata

        # NOTE: For the prefill side, there might be a chance that an early added
        # request is a chunked prefill, so we need to check if new blocks are added
        for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output):
            req_to_save = self._reqs_need_save.get(req_id)
            if req_to_save is None or new_block_id_groups is None:
                continue
            req = req_to_save

            assert req.kv_transfer_params is not None
            clipped_block_id_groups = self.get_sw_clipped_blocks(new_block_id_groups)
            meta.add_new_req_to_save(
                request_id=req_id,
                local_block_ids=clipped_block_id_groups,
                kv_transfer_params=req.kv_transfer_params,
            )
            assert scheduler_output.num_scheduled_tokens is not None
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # Partial prefill: not all prompt tokens scheduled yet this step.
            is_partial = (
                req.num_computed_tokens + num_scheduled_tokens
            ) < req.num_prompt_tokens
            if not is_partial:
                # For non-partial prefills, once new req_meta is scheduled, it
                # can be removed from _reqs_need_save.
                # For partial prefill case, we will retain the request in
                # _reqs_need_save until all blocks are scheduled with req_meta.
                # Therefore, only pop if `not is_partial`.
                self._reqs_need_save.pop(req_id)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        # Snapshot the accumulated per-step state into worker metadata, then
        # reset the accumulators for the next scheduler step.
        meta = NixlConnectorMetadata()

        # Loop through scheduled reqs and convert to ReqMeta.
        for req_id, (req, block_ids) in self._reqs_need_recv.items():
            assert req.kv_transfer_params is not None
            meta.add_new_req_to_recv(
                request_id=req_id,
                local_block_ids=block_ids,
                kv_transfer_params=req.kv_transfer_params,
            )

        if self.use_host_buffer:
            self._build_save_meta(meta, scheduler_output)

        # Hand over the containers themselves; fresh ones are created below
        # so the metadata keeps an unshared reference.
        meta.reqs_to_send = self._reqs_need_send
        meta.reqs_in_batch = self._reqs_in_batch
        meta.reqs_not_processed = self._reqs_not_processed

        # Clear the list once workers start the transfers
        self._reqs_need_recv.clear()
        self._reqs_in_batch = set()
        self._reqs_not_processed = set()
        self._reqs_need_send = {}

        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: BlockIds,
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Once a request is finished, determine whether request blocks
        should be freed now or will be sent asynchronously and freed later.

        Returns:
            * whether block freeing should be delayed (blocks still to be
              read by the remote decode instance).
            * kv_transfer_params to return to the caller, or None.
        """
        from vllm.v1.request import RequestStatus

        params = request.kv_transfer_params
        logger.debug(
            "NIXLConnector request_finished(%s), request_status=%s, "
            "kv_transfer_params=%s",
            request.request_id,
            request.status,
            params,
        )
        if not params:
            return False, None

        if params.get("do_remote_prefill"):
            # If do_remote_prefill is still True when the request is finished,
            # update_state_after_alloc must not have been called (the request
            # must have been aborted before it was scheduled).
            # To avoid stranding the prefill blocks in the prefill instance,
            # we must add empty block_ids to _reqs_need_recv so that our
            # worker side will notify and free blocks in the prefill instance.
            self._reqs_need_recv[request.request_id] = (request, [])
            params["do_remote_prefill"] = False
            return False, None

        if not params.get("do_remote_decode"):
            return False, None
        if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
            # Also include the case of a P/D Prefill request with immediate
            # block free (eg abort). Stop tracking this request.
            self._reqs_not_processed.add(request.request_id)
            # Clear _reqs_need_save if a request is aborted as partial prefill.
            self._reqs_need_save.pop(request.request_id, None)
            return False, None

        # TODO: check whether block_ids actually ever be 0. If not we could
        # remove the conditional below
        delay_free_blocks = any(len(group) > 0 for group in block_ids)

        if delay_free_blocks:
            # Prefill request on remote. It will be read from D upon completion
            logger.debug(
                "NIXLConnector request_finished(%s) waiting for %d seconds "
                "for remote decode to fetch blocks",
                request.request_id,
                envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT,
            )
            self._reqs_need_send[request.request_id] = (
                time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
            )
            # NOTE HMA will "mark" empty/null blocks in groups with 0s (eg SWA ones),
            # trimming down after allocating for the whole sequence length. Empty
            # blocks are always at the start of the list.
            # Here we "unpad" blocks to send the actual remote blocks to be read.
            block_ids = self.get_sw_clipped_blocks(block_ids)

        return delay_free_blocks, dict(
            do_remote_prefill=True,
            do_remote_decode=False,
            remote_block_ids=block_ids,
            remote_engine_id=self.engine_id,
            remote_request_id=request.request_id,
            remote_host=self.side_channel_host,
            remote_port=self.side_channel_port,
            tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
        )
_mamba_prefill_token_count

_mamba_prefill_token_count(num_prompt_tokens: int) -> int

D-side only. Returns N-1 for Mamba models since the decoder always recomputes the last token and must start from h(N-1).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int:
    """D-side only. Returns N-1 for Mamba models since the decoder
    always recomputes the last token and must start from h(N-1)."""
    if self._has_mamba and num_prompt_tokens > 1:
        return num_prompt_tokens - 1
    return num_prompt_tokens

_nixl_handshake_listener staticmethod

_nixl_handshake_listener(
    encoded_data: dict[int, Any],
    ready_event: Event,
    stop_event: Event,
    port: int,
)

Background thread for getting new NIXL handshakes.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
@staticmethod
def _nixl_handshake_listener(
    encoded_data: dict[int, Any],
    ready_event: threading.Event,
    stop_event: threading.Event,
    port: int,
):
    """Background thread for getting new NIXL handshakes.

    Serves pre-encoded per-TP-rank handshake payloads over a ZMQ ROUTER
    socket until ``stop_event`` is set. Sets ``ready_event`` once the
    socket is bound and ready to accept requests.
    """
    # NOTE(rob): this is a simple implementation. We will move
    # to a better approach via HTTP endpoint soon.

    # Listen for new requests for metadata.
    host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
    path = make_zmq_path("tcp", host, port)
    logger.debug("Starting listening on path: %s", path)
    with zmq_ctx(zmq.ROUTER, path) as sock:
        # 1s receive timeout so stop_event is polled between requests.
        sock.setsockopt(zmq.RCVTIMEO, 1000)
        ready_event.set()
        while True:
            try:
                identity, _, msg = sock.recv_multipart()
            except zmq.Again:
                if stop_event.is_set():
                    break
                continue
            # Decode the message which contains (GET_META_MSG, rank)
            msg, target_tp_rank = msgspec.msgpack.decode(msg)
            logger.debug(
                "Received message for tp rank %s",
                target_tp_rank,
            )
            if msg != GET_META_MSG:
                logger.warning("Connection listener got unexpected message %s", msg)
            # Reply with the requested rank's payload (sent even after an
            # unexpected message type, only a warning is emitted above).
            sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))

_truncate_mamba_request_for_prefill

_truncate_mamba_request_for_prefill(
    request: Request,
) -> None

P-side only: drop the last prompt token so the prefiller computes h(N-1) instead of h(N). The decoder recomputes the last token to derive h(N) correctly.

Guarded by _p_side_truncated to avoid repeated truncation if the request is preempted and rescheduled.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
def _truncate_mamba_request_for_prefill(self, request: "Request") -> None:
    """P-side only: drop the last prompt token so the prefiller computes
    h(N-1) instead of h(N). The decoder recomputes the last token to
    derive h(N) correctly.

    Guarded by ``_p_side_truncated`` to avoid repeated truncation if the
    request is preempted and rescheduled."""
    params = request.kv_transfer_params
    if (
        params is not None
        # Guard against repeated truncation after preemption/reschedule.
        and not params.get("_p_side_truncated")
        and request.num_prompt_tokens > 1
    ):
        # The prompt is either token ids or embeddings; if neither is set
        # there is nothing to truncate, so bail out without marking.
        if request.prompt_token_ids is not None:
            request.prompt_token_ids.pop()
        elif request.prompt_embeds is not None:
            request.prompt_embeds = request.prompt_embeds[:-1]
        else:
            return

        # Keep the request's aggregate token bookkeeping consistent with
        # the shortened prompt.
        request._all_token_ids.pop()
        request.num_prompt_tokens -= 1
        # NOTE(review): presumably the prefiller only needs to emit a single
        # token before handing off to the remote decoder — confirm.
        request.max_tokens = 1
        params["_p_side_truncated"] = True

get_num_new_matched_tokens

get_num_new_matched_tokens(
    request: Request, num_computed_tokens: int
) -> tuple[int, bool]

For remote prefill, pull all prompt blocks from remote asynchronously relative to engine execution.

Parameters:

Name Type Description Default
request Request

the request object.

required
num_computed_tokens int

the number of locally computed tokens for this request

required

Returns: * the number of tokens that can be loaded from the external KV cache beyond what is already computed. * true if the external KV cache tokens will be loaded asynchronously (between scheduler steps).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
def get_num_new_matched_tokens(
    self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
    """
    For remote prefill, pull all prompt blocks from remote
    asynchronously relative to engine execution.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request
    Returns:
        * the number of tokens that can be loaded from the
          external KV cache beyond what is already computed.
        * true if the external KV cache tokens will be loaded
          asynchronously (between scheduler steps).
    """

    params = request.kv_transfer_params
    logger.debug(
        "NIXLConnector get_num_new_matched_tokens: "
        "num_computed_tokens=%s, kv_transfer_params=%s",
        num_computed_tokens,
        params,
    )

    if params is not None:
        if params.get("do_remote_prefill"):
            # Remote prefill: everything beyond the locally computed prefix
            # can be fetched from the remote instance.
            prompt_len = len(request.prompt_token_ids or [])
            remote_tokens = self._mamba_prefill_token_count(prompt_len)
            pending = remote_tokens - num_computed_tokens
            if pending > 0:
                return pending, True

        if params.get("do_remote_decode") and self._has_mamba:
            # P-side Mamba: shorten the prompt before prefill.
            self._truncate_mamba_request_for_prefill(request)

    # No remote prefill for this request.
    return 0, False

get_sw_clipped_blocks

get_sw_clipped_blocks(block_ids: BlockIds) -> BlockIds

Clip the number of blocks to the sliding window size for each kv cache group that employs SWA. This is necessary because the KV Cache manager initially allocates blocks for the entire sequence length, and successively cleans up blocks that are outside the window prior to the request_finished_all_groups hook.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
def get_sw_clipped_blocks(self, block_ids: BlockIds) -> BlockIds:
    """Restrict each KV cache group's block list to its sliding window.

    The KV cache manager first allocates blocks for the whole sequence and
    only later frees the ones falling outside an SWA group's window (just
    before the `request_finished_all_groups` hook), so only the tail of
    each SWA group's list actually needs to be transferred.
    """
    if not block_ids or not self._is_hma_required:
        # Nothing to clip: full prefix cache hit or not a hybrid model.
        return block_ids
    # NOTE (NickLucche) This logic is currently handled at the connector level
    # because offloading connectors might want to receive the whole sequence even
    # for SWA groups. We will abstract this logic once the interface is more stable
    assert len(block_ids) == len(self.blocks_per_sw), (
        "Number of KV cache groups must match"
    )
    clipped = []
    for window_blocks, group_blocks in zip(self.blocks_per_sw, block_ids):
        # window_blocks == 0 marks a non-SWA group: keep all of its blocks.
        if window_blocks > 0:
            clipped.append(group_blocks[-window_blocks:])
        else:
            clipped.append(group_blocks)
    return tuple(clipped)

request_finished

request_finished(
    request: Request, block_ids: BlockIds
) -> tuple[bool, dict[str, Any] | None]

Once a request is finished, determine whether request blocks should be freed now or will be sent asynchronously and freed later.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
def request_finished(
    self,
    request: "Request",
    block_ids: BlockIds,
) -> tuple[bool, dict[str, Any] | None]:
    """
    Once a request is finished, determine whether request blocks
    should be freed now or will be sent asynchronously and freed later.

    Returns:
        * whether freeing of the blocks should be delayed (remote decode
          still has to read them).
        * kv_transfer_params to hand back to the decode side, or None.
    """
    # Local import to avoid a circular dependency at module load time.
    from vllm.v1.request import RequestStatus

    params = request.kv_transfer_params
    logger.debug(
        "NIXLConnector request_finished(%s), request_status=%s, "
        "kv_transfer_params=%s",
        request.request_id,
        request.status,
        params,
    )
    if not params:
        return False, None

    if params.get("do_remote_prefill"):
        # If do_remote_prefill is still True when the request is finished,
        # update_state_after_alloc must not have been called (the request
        # must have been aborted before it was scheduled).
        # To avoid stranding the prefill blocks in the prefill instance,
        # we must add empty block_ids to _reqs_need_recv so that our
        # worker side will notify and free blocks in the prefill instance.
        self._reqs_need_recv[request.request_id] = (request, [])
        params["do_remote_prefill"] = False
        return False, None

    if not params.get("do_remote_decode"):
        return False, None
    if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
        # Also include the case of a P/D Prefill request with immediate
        # block free (eg abort). Stop tracking this request.
        self._reqs_not_processed.add(request.request_id)
        # Clear _reqs_need_save if a request is aborted as partial prefill.
        self._reqs_need_save.pop(request.request_id, None)
        return False, None

    # TODO: check whether block_ids actually ever be 0. If not we could
    # remove the conditional below
    delay_free_blocks = any(len(group) > 0 for group in block_ids)

    if delay_free_blocks:
        # Prefill request on remote. It will be read from D upon completion
        logger.debug(
            "NIXLConnector request_finished(%s) waiting for %d seconds "
            "for remote decode to fetch blocks",
            request.request_id,
            envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT,
        )
        # Expiration deadline after which the send is considered aborted.
        self._reqs_need_send[request.request_id] = (
            time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
        )
        # NOTE HMA will "mark" empty/null blocks in groups with 0s (eg SWA ones),
        # trimming down after allocating for the whole sequence length. Empty
        # blocks are always at the start of the list.
        # Here we "unpad" blocks to send the actual remote blocks to be read.
        block_ids = self.get_sw_clipped_blocks(block_ids)

    # These params are consumed by the decode instance to pull the blocks.
    return delay_free_blocks, dict(
        do_remote_prefill=True,
        do_remote_decode=False,
        remote_block_ids=block_ids,
        remote_engine_id=self.engine_id,
        remote_request_id=request.request_id,
        remote_host=self.side_channel_host,
        remote_port=self.side_channel_port,
        tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
    )

set_xfer_handshake_metadata

set_xfer_handshake_metadata(
    metadata: dict[int, KVConnectorHandshakeMetadata],
) -> None

Set the KV connector handshake metadata for this connector.

Parameters:

Name Type Description Default
metadata dict

the handshake metadata to set.

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
def set_xfer_handshake_metadata(
    self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
    """
    Set the KV connector handshake metadata for this connector.

    Args:
        metadata (dict): the handshake metadata to set.

    Raises:
        ValueError: if any rank's metadata is not a NixlHandshakePayload.
    """
    # Pre-encode each TP rank's payload so the listener thread only has to
    # send bytes.
    encoded_data: dict[int, bytes] = {}
    encoder = msgspec.msgpack.Encoder()
    for tp_rank, rank_metadata in metadata.items():
        if not isinstance(rank_metadata, NixlHandshakePayload):
            raise ValueError(
                "NixlConnectorScheduler expects NixlHandshakePayload for "
                "handshake metadata."
            )
        encoded_data[tp_rank] = encoder.encode(rank_metadata)
        logger.debug(
            "Tp rank %d: encoded NixlHandshakePayload size: %s bytes",
            tp_rank,
            str(len(encoded_data[tp_rank])),
        )

    # Only start the listener when we have metadata to serve.
    # The thread is created at most once; later calls skip this branch.
    if self._nixl_handshake_listener_t is None:
        ready_event = threading.Event()
        self._nixl_handshake_listener_t = threading.Thread(
            target=self._nixl_handshake_listener,
            args=(
                encoded_data,
                ready_event,
                self._stop_event,
                self.side_channel_port,
            ),
            daemon=True,
            name="nixl_handshake_listener",
        )
        self._nixl_handshake_listener_t.start()
        ready_event.wait()  # Wait for listener ZMQ socket to be ready.