Skip to content

vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration

Modules:

Name Description
multi_process_adapter
utils
vllm_v1_adapter

LMCacheMPSchedulerAdapter

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
class LMCacheMPSchedulerAdapter:
    """Scheduler-side adapter for the LMCache multi-process integration.

    Talks to the LMCache server over a ZMQ-based message queue to submit
    KV-cache lookup requests and poll their results. One in-flight lookup
    future is tracked per vLLM request id.
    """

    def __init__(
        self,
        server_url: str,
        context: zmq.Context,
        model_name: str,
        vllm_block_size: int,
        parallel_strategy: ParallelStrategy,
    ):
        """
        Args:
            server_url: The server URL for the LMCache message queue
            context: The ZMQ context
            model_name: The model name used for LMCache keys
            vllm_block_size: The block size used in vLLM
            parallel_strategy:
                The parallel strategy, which includes `use_mla`,
                `world_size`, `worker_id` and so on
        """
        self.mq_client = MessageQueueClient(server_url, context)

        # request_id -> in-flight lookup future
        self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {}

        self.model_name = model_name
        self.parallel_strategy = parallel_strategy

        # The data-chunk size is owned by the LMCache server; query it once
        # here so chunk/block conversions below are consistent with it.
        self.chunk_size = get_lmcache_chunk_size(self.mq_client)
        assert self.chunk_size % vllm_block_size == 0, (
            "LMCache chunk size should be a multiple of vLLM block size"
        )
        self.blocks_in_chunk = self.chunk_size // vllm_block_size

    @property
    def world_size(self) -> int:
        """The world size."""
        return self.parallel_strategy.kv_world_size

    @property
    def worker_id(self) -> int:
        """The worker id."""
        return self.parallel_strategy.kv_worker_id

    @property
    def tp_size(self) -> int:
        """The tensor parallel size."""
        return self.parallel_strategy.tp_size

    @_lmcache_nvtx_annotate
    def maybe_submit_lookup_request(
        self,
        request_id: str,
        block_hashes: list[bytes] | None = None,
        token_ids: list[int] | None = None,
    ) -> None:
        """
        Submit a new lookup request to LMCache if there is no ongoing request.

        Supports both token-based and hash-based vLLM:
        - token_ids: token IDs (token-based vLLM) -> single token-mode key
        - block_hashes: block hashes (hash-based vLLM) -> strided hash-mode keys

        Exactly one of block_hashes or token_ids must be provided.

        Args:
            request_id: The ID of the lookup request. The same ID indicates it's
                from the same request
            block_hashes: Block hashes to lookup from LMCache (hash mode)
            token_ids: Token IDs to lookup from LMCache (token mode)

        Returns:
            None

        Notes:
            This function has a side effect: it submits a lookup request to
            LMCache, which essentially 'locks' the KV cache chunks in LMCache
            for later retrieve operations.
            It also records the lookup request; the status of the request can
            be polled with `check_lookup_result`.
        """
        if request_id in self.lookup_futures:
            # A lookup for this request is already in flight; do nothing.
            return

        assert (block_hashes is None) != (token_ids is None), (
            "Exactly one of block_hashes or token_ids must be provided"
        )

        if block_hashes is None:
            # Token mode: only the chunk-aligned prefix can be looked up.
            assert token_ids is not None
            aligned_end = len(token_ids) - len(token_ids) % self.chunk_size
            if aligned_end == 0:
                # Fewer tokens than one chunk -> nothing to look up.
                return
            token_key = self._create_key(
                token_ids,
                start=0,
                end=aligned_end,
                request_id=request_id,
            )
            keys = [token_key.no_worker_id_version()]
        else:
            # Hash mode: one hash-mode key per LMCache data chunk.
            keys = [
                self._create_hash_key(chunk_hash, request_id=request_id)
                for chunk_hash in striding_block_hashes(
                    block_hashes, self.blocks_in_chunk
                )
            ]

        self.lookup_futures[request_id] = send_lmcache_request(
            self.mq_client,
            RequestType.LOOKUP,
            [keys],
        )

    @_lmcache_nvtx_annotate
    def check_lookup_result(self, request_id: str) -> int | None:
        """
        Check the result of a previously submitted lookup request.

        Args:
            request_id: The ID of the lookup request submitted in
                `maybe_submit_lookup_request`

        Returns:
            An integer representing the total number of tokens matched
            in LMCache (prefix matching), or
            None if the lookup request is not finished yet.
        """
        assert request_id in self.lookup_futures, (
            f"Lookup request for request_id={request_id} has not been submitted"
        )

        future = self.lookup_futures[request_id]
        if future.query():
            # The result is a matched-chunk count; convert it to tokens.
            return future.result() * self.chunk_size
        return None

    def num_blocks_per_chunk(self) -> int:
        """
        Returns:
            The number of vllm blocks in a LMCache data chunk
        """
        return self.blocks_in_chunk

    def cleanup_lookup_result(self, request_id: str) -> None:
        """
        Clean up the lookup future for a finished request to prevent a
        memory leak.

        Args:
            request_id: The ID of the finished request.
        """
        self.lookup_futures.pop(request_id, None)

    def end_session(self, request_id: str) -> None:
        """
        Notify the LMCache server to remove the session for a finished request.

        Args:
            request_id: The ID of the finished request.
        """
        send_lmcache_request(
            self.mq_client,
            RequestType.END_SESSION,
            [request_id],
        )

    # Helper functions
    def _create_key(
        self,
        token_ids: list[int],
        start: int = 0,
        end: int = 0,
        request_id: str | None = None,
    ) -> IPCCacheEngineKey:
        """Convert token IDs to an IPC cache engine key"""
        return IPCCacheEngineKey(
            model_name=self.model_name,
            world_size=self.world_size,
            worker_id=self.worker_id,
            token_ids=tuple(token_ids),
            start=start,
            end=end,
            request_id=request_id,
            tp_size=self.tp_size,
        )

    def _create_hash_key(
        self, chunk_hash: bytes, request_id: str | None = None
    ) -> IPCCacheEngineKey:
        """Create a hash-mode IPC cache engine key"""
        return IPCCacheEngineKey(
            model_name=self.model_name,
            world_size=self.world_size,
            # Hash-mode scheduler keys are worker-agnostic.
            worker_id=None,
            chunk_hash=chunk_hash,
            request_id=request_id,
            tp_size=self.tp_size,
        )

tp_size property

tp_size: int

The tensor parallel size.

worker_id property

worker_id: int

The worker id.

world_size property

world_size: int

The world size.

__init__

__init__(
    server_url: str,
    context: Context,
    model_name: str,
    vllm_block_size: int,
    parallel_strategy: ParallelStrategy,
)

Parameters:

Name Type Description Default
server_url str

The server URL for the LMCache message queue

required
context Context

The ZMQ context

required
model_name str

The model name used for LMCache keys

required
vllm_block_size int

The block size used in vLLM

required
parallel_strategy ParallelStrategy

The parallel strategy, which includes use_mla, world_size, worker_id and so on

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def __init__(
    self,
    server_url: str,
    context: zmq.Context,
    model_name: str,
    vllm_block_size: int,
    parallel_strategy: ParallelStrategy,
):
    """
    Args:
        server_url: The server URL for the LMCache message queue
        context: The ZMQ context

        model_name: The model name used for LMCache keys
        vllm_block_size: The block size used in vLLM
        parallel_strategy:
            The parallel strategy, which includes `use_mla`,
            `world_size`, `worker_id` and so on
    """
    self.mq_client = MessageQueueClient(server_url, context)

    # Request futures
    self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {}

    self.model_name = model_name
    self.parallel_strategy = parallel_strategy

    # Read chunk size from lmcache
    self.chunk_size = get_lmcache_chunk_size(self.mq_client)
    assert self.chunk_size % vllm_block_size == 0, (
        "LMCache chunk size should be a multiple of vLLM block size"
    )
    self.blocks_in_chunk = self.chunk_size // vllm_block_size

_create_hash_key

_create_hash_key(
    chunk_hash: bytes, request_id: str | None = None
) -> IPCCacheEngineKey

Create a hash-mode IPC cache engine key

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def _create_hash_key(
    self, chunk_hash: bytes, request_id: str | None = None
) -> IPCCacheEngineKey:
    """Create a hash-mode IPC cache engine key"""
    return IPCCacheEngineKey(
        model_name=self.model_name,
        world_size=self.world_size,
        worker_id=None,
        chunk_hash=chunk_hash,
        request_id=request_id,
        tp_size=self.tp_size,
    )

_create_key

_create_key(
    token_ids: list[int],
    start: int = 0,
    end: int = 0,
    request_id: str | None = None,
) -> IPCCacheEngineKey

Convert token IDs to an IPC cache engine key

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def _create_key(
    self,
    token_ids: list[int],
    start: int = 0,
    end: int = 0,
    request_id: str | None = None,
) -> IPCCacheEngineKey:
    """Convert token IDs to an IPC cache engine key"""
    return IPCCacheEngineKey(
        model_name=self.model_name,
        world_size=self.world_size,
        worker_id=self.worker_id,
        token_ids=tuple(token_ids),
        start=start,
        end=end,
        request_id=request_id,
        tp_size=self.tp_size,
    )

check_lookup_result

check_lookup_result(request_id: str) -> int | None

Check the result of a previously submitted lookup request.

Parameters:

Name Type Description Default
request_id str

The ID of the lookup request submitted in maybe_submit_lookup_request

required

Returns:

Type Description
int | None

An integer representing the total number of tokens matched

int | None

in LMCache (prefix matching), or

int | None

None if the lookup request is not finished yet.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def check_lookup_result(self, request_id: str) -> int | None:
    """
    Check the result of a previously submitted lookup request.

    Args:
        request_id: The ID of the lookup request submitted in
            `maybe_submit_lookup_request`

    Returns:
        An integer representing the total number of tokens matched
        in LMCache (prefix matching), or
        None if the lookup request is not finished yet.
    """
    assert request_id in self.lookup_futures, (
        f"Lookup request for request_id={request_id} has not been submitted"
    )

    future = self.lookup_futures[request_id]
    if not future.query():
        return None

    result = future.result()
    num_chunks = result
    return num_chunks * self.chunk_size

cleanup_lookup_result

cleanup_lookup_result(request_id: str) -> None

Clean up the lookup future for a finished request to prevent a memory leak. Args: request_id: the ID of the finished request.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def cleanup_lookup_result(self, request_id: str) -> None:
    """
    Clean up lookup future for a finished request to prevent memory leak.
    Args:
        request_id: The ID of the finished request.
    """
    self.lookup_futures.pop(request_id, None)

end_session

end_session(request_id: str) -> None

Notify the LMCache server to remove the session for a finished request. Args: request_id: the ID of the finished request.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def end_session(self, request_id: str) -> None:
    """
    Notify LMCache server to remove the session for a finished request.
    Args:
        request_id: The ID of the finished request.
    """
    send_lmcache_request(
        self.mq_client,
        RequestType.END_SESSION,
        [request_id],
    )

maybe_submit_lookup_request

maybe_submit_lookup_request(
    request_id: str,
    block_hashes: list[bytes] | None = None,
    token_ids: list[int] | None = None,
) -> None

Submit a new lookup request to LMCache if there is no ongoing request.

Supports both token-based and hash-based vLLM:

- token_ids: token IDs (token-based vLLM) -> single token-mode key
- block_hashes: block hashes (hash-based vLLM) -> strided hash-mode keys

Exactly one of block_hashes or token_ids must be provided.

Parameters:

Name Type Description Default
request_id str

The ID of the lookup request. The same ID indicates it's from the same request

required
block_hashes list[bytes] | None

Block hashes to lookup from LMCache (hash mode)

None
token_ids list[int] | None

Token IDs to lookup from LMCache (token mode)

None

Returns:

Type Description
None

None

Notes

This function has a side effect: it submits a lookup request to LMCache, which essentially 'locks' the KV cache chunks in LMCache for later retrieve operations. It also records the lookup request; the status of the lookup request can be checked via check_lookup_result.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def maybe_submit_lookup_request(
    self,
    request_id: str,
    block_hashes: list[bytes] | None = None,
    token_ids: list[int] | None = None,
) -> None:
    """
    Submit a new lookup request to LMCache if there is no ongoing request.

    Supports both token-based and hash-based vLLM:
    - token_ids: token IDs (token-based vLLM) -> single token-mode key
    - block_hashes: block hashes (hash-based vLLM) -> strided hash-mode keys

    Exactly one of block_hashes or token_ids must be provided.

    Args:
        request_id: The ID of the lookup request. The same ID indicates it's
            from the same request
        block_hashes: Block hashes to lookup from LMCache (hash mode)
        token_ids: Token IDs to lookup from LMCache (token mode)

    Returns:
        None

    Notes:
        This function will have a side-effect: submitting a look up request to
        LMCache, which will essentially 'lock' the KV cache chunks in the LMCache
        for later retrieve operations.
        In the meantime, this function will record the lookup request, and the
        status of the look up request can be checked by `check_lookup_result`.
    """
    if request_id in self.lookup_futures:
        # Skip if there is already a lookup request
        return

    assert (block_hashes is None) != (token_ids is None), (
        "Exactly one of block_hashes or token_ids must be provided"
    )

    if block_hashes is not None:
        # Hash mode: stride block hashes -> N hash-mode keys
        chunk_hashes = list(
            striding_block_hashes(block_hashes, self.blocks_in_chunk)
        )
        keys = [
            self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
        ]
    else:
        # Token mode: truncate to chunk-aligned length
        assert token_ids is not None
        aligned_end = (len(token_ids) // self.chunk_size) * self.chunk_size
        if aligned_end == 0:
            return
        keys = [
            self._create_key(
                token_ids,
                start=0,
                end=aligned_end,
                request_id=request_id,
            ).no_worker_id_version()
        ]

    future = send_lmcache_request(
        self.mq_client,
        RequestType.LOOKUP,
        [keys],
    )
    self.lookup_futures[request_id] = future

num_blocks_per_chunk

num_blocks_per_chunk() -> int

Returns:

Type Description
int

The number of vllm blocks in a LMCache data chunk

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def num_blocks_per_chunk(self) -> int:
    """
    Returns:
        The number of vllm blocks in a LMCache data chunk
    """
    return self.blocks_in_chunk

LMCacheMPWorkerAdapter

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
class LMCacheMPWorkerAdapter:
    """Worker-side adapter for the LMCache multi-process integration.

    Registers the GPU worker's KV caches with the LMCache server and
    submits/tracks KV-cache store and retrieve requests over a ZMQ-based
    message queue. Futures are tracked per request id; batched submissions
    are merged under the first request id with the remaining ids attached.
    """

    def __init__(
        self,
        server_url: str,
        context: zmq.Context,
        model_name: str,
        vllm_block_size: int,
        parallel_strategy: ParallelStrategy,
    ):
        """
        Args:
            server_url: The server URL for the LMCache message queue
            context: The ZMQ context
            model_name: The model name used for LMCache keys
            vllm_block_size: The block size used in vLLM
            parallel_strategy: The parallel strategy (world size, worker id,
                MLA usage and so on)
        """
        self.mq_client = MessageQueueClient(server_url, context)

        # Instance id for GPU worker (the OS process id identifies this
        # worker to the LMCache server)
        self.instance_id = os.getpid()

        # Registered kv caches from vLLM (layer name -> tensor)
        self.kv_caches: dict[str, torch.Tensor] = {}

        # Request futures
        # request_id -> (future, other merged requests)
        # "other merged requests" are the additional request ids that were
        # batched into the same LMCache request (see batched_submit_*).
        self.store_futures: dict[
            str, tuple[MessagingFuture[StoreResult], list[str]]
        ] = {}
        self.retrieve_futures: dict[
            str, tuple[MessagingFuture[RetrieveResult], list[str]]
        ] = {}

        # The store requests that have finished execution in LMCache
        self.finished_stores: set[str] = set()
        # The finished request ids that are passed via vLLM and also
        # have corresponding store requests submitted to LMCache before
        self.previously_finished: set[str] = set()

        self.model_name = model_name
        self.parallel_strategy = parallel_strategy

        # Read chunk size from lmcache (the server owns this value)
        chunk_size = get_lmcache_chunk_size(self.mq_client)
        assert chunk_size % vllm_block_size == 0, (
            "LMCache chunk size should be a multiple of vLLM block size"
        )
        self.blocks_in_chunk = chunk_size // vllm_block_size

    @property
    def world_size(self) -> int:
        """The world size."""
        return self.parallel_strategy.kv_world_size

    @property
    def worker_id(self) -> int:
        """The worker id."""
        return self.parallel_strategy.kv_worker_id

    @property
    def use_mla(self) -> bool:
        """Whether to use MLA."""
        return self.parallel_strategy.use_mla

    @property
    def is_first_rank_of_pp_group(self) -> bool:
        """Is the first rank of the pipeline parallel group."""
        # NOTE(review): assumes TP ranks are contiguous within a PP group,
        # i.e. the first rank of each group has a worker id divisible by
        # tp_size — confirm against the rank layout in ParallelStrategy.
        return (
            self.parallel_strategy.actual_worker_id % self.parallel_strategy.tp_size
            == 0
        )

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """
        Register the kv caches with LMCache server

        Args:
            kv_caches: A dict of kv caches to register. The keys are the
                layer names and the values are the corresponding tensors.
        """
        # Register kv cache and send the request
        self.kv_caches = kv_caches
        logger.info("Registering kv caches")
        future = send_lmcache_request(
            self.mq_client,
            RequestType.REGISTER_KV_CACHE,
            [self.instance_id, wrap_kv_caches(kv_caches)],
        )
        # Block until the server acknowledges the registration.
        future.result()

    @_lmcache_nvtx_annotate
    def submit_store_request(
        self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
    ):
        """
        Submit a KV cache store request to LMCache

        Args:
            request_id: The ID of the request
            op: The LoadStoreOp describing the store operation.
            event: The CUDA event that is recorded after the current
                model inference step
        """
        if op.block_hashes is not None:
            # Hash mode: one hash-mode key per LMCache data chunk
            chunk_hashes = list(
                striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
            )
            keys = [
                self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
            ]
        else:
            # Token mode: a single key covering [op.start, op.end)
            assert op.token_ids is not None
            keys = [
                self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
            ]
        # The CUDA event handle lets the server wait for the KV data to be
        # ready on the GPU before copying it out.
        future = send_lmcache_request(
            self.mq_client,
            RequestType.STORE,
            [keys, self.instance_id, op.block_ids, event.ipc_handle()],
        ).to_cuda_future()
        # Single (non-batched) submission -> no merged request ids.
        self.store_futures[request_id] = (future, [])

    @_lmcache_nvtx_annotate
    def submit_retrieve_request(
        self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
    ):
        """
        Submit a KV cache retrieve request to LMCache

        Args:
            request_id: The ID of the request
            op: The LoadStoreOp describing the retrieve operation.
            event: The CUDA event that is recorded after the current
                model inference step
        """
        if op.block_hashes is not None:
            # Hash mode: one hash-mode key per LMCache data chunk
            chunk_hashes = list(
                striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
            )
            keys = [
                self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
            ]
        else:
            # Token mode: a single key covering [op.start, op.end)
            assert op.token_ids is not None
            keys = [
                self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
            ]
        future = send_lmcache_request(
            self.mq_client,
            RequestType.RETRIEVE,
            [keys, self.instance_id, op.block_ids, event.ipc_handle()],
        ).to_cuda_future()
        # Single (non-batched) submission -> no merged request ids.
        self.retrieve_futures[request_id] = (future, [])

    @_lmcache_nvtx_annotate
    def batched_submit_store_requests(
        self,
        request_ids: list[str],
        ops: list[LoadStoreOp],
        event: torch.cuda.Event,
    ):
        """
        Submit a batched store request to LMCache

        Args:
            request_ids: The IDs of the requests
            ops: The LoadStoreOps describing the store operations. Should have
                the same length as request_ids
            event: The CUDA event that is recorded after the current
                model inference step
        """
        # Flatten all keys and block ids into a single LMCache request.
        all_keys: list[IPCCacheEngineKey] = []
        block_ids: list[int] = []
        for request_id, op in zip(request_ids, ops, strict=False):
            if op.block_hashes is not None:
                chunk_hashes = list(
                    striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
                )
                keys = [
                    self._create_hash_key(ch, request_id=request_id)
                    for ch in chunk_hashes
                ]
                all_keys.extend(keys)
            else:
                assert op.token_ids is not None
                all_keys.append(
                    self._create_key(
                        op.token_ids, op.start, op.end, request_id=request_id
                    )
                )
            block_ids.extend(op.block_ids)
        future = send_lmcache_request(
            self.mq_client,
            RequestType.STORE,
            [
                all_keys,
                self.instance_id,
                block_ids,
                event.ipc_handle(),
            ],
        ).to_cuda_future()
        # Track the whole batch under the first request id; the rest are
        # recorded as "merged" so get_finished can report all of them.
        self.store_futures[request_ids[0]] = (future, list(request_ids[1:]))

    @_lmcache_nvtx_annotate
    def batched_submit_retrieve_requests(
        self,
        request_ids: list[str],
        ops: list[LoadStoreOp],
        event: torch.cuda.Event,
    ):
        """
        Submit a batched retrieve request to LMCache

        Args:
            request_ids: The IDs of the requests
            ops: The LoadStoreOps describing the retrieve operations. Should have
                the same length as request_ids
            event: The CUDA event that is recorded after the current
                model inference step
        """
        # Flatten all keys and block ids into a single LMCache request.
        all_keys: list[IPCCacheEngineKey] = []
        block_ids: list[int] = []
        for request_id, op in zip(request_ids, ops, strict=False):
            if op.block_hashes is not None:
                chunk_hashes = list(
                    striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
                )
                keys = [
                    self._create_hash_key(ch, request_id=request_id)
                    for ch in chunk_hashes
                ]
                all_keys.extend(keys)
            else:
                assert op.token_ids is not None
                all_keys.append(
                    self._create_key(
                        op.token_ids, op.start, op.end, request_id=request_id
                    )
                )
            block_ids.extend(op.block_ids)
        future = send_lmcache_request(
            self.mq_client,
            RequestType.RETRIEVE,
            [
                all_keys,
                self.instance_id,
                block_ids,
                event.ipc_handle(),
            ],
        ).to_cuda_future()
        # Track the whole batch under the first request id; the rest are
        # recorded as "merged" so get_finished can report all of them.
        self.retrieve_futures[request_ids[0]] = (future, list(request_ids[1:]))

    @_lmcache_nvtx_annotate
    def get_finished(
        self, finished_req_ids_from_engine: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Check and get the finished store and retrieve requests.

        Args:
            finished_req_ids_from_engine: the set of request ids that are
                reported as finished from the vLLM engine side.

        Returns:
            A tuple of two sets:
            - The first set contains the finished store request ids. The returned
                store request ids MUST be seen before in the
                `finished_req_ids_from_engine`.
            - The second set contains the finished retrieve request ids.

        Notes:
            When enabling async scheduling in vLLM, the same request ID may appear
            multiple times in `finished_req_ids_from_engine`. The adapter should
            take care of deduplicating the request IDs and only return the request
            IDs that have not been returned before.
        """
        # Poll every outstanding store future; a completed future finishes
        # both its own request and any requests merged into the same batch.
        finished_stores = set()
        finished_retrieves = set()
        for request_id, (s_future, other_reqs) in self.store_futures.items():
            if not s_future.query():
                continue

            s_result = s_future.result()
            finished_stores.add(request_id)
            finished_stores.update(other_reqs)

            if not s_result:
                # TODO: add error handling here
                logger.error(
                    "Something went wrong when processing the "
                    "store request for request_id=%s",
                    request_id,
                )

        # Same polling for retrieve futures.
        for request_id, (r_future, other_reqs) in self.retrieve_futures.items():
            if not r_future.query():
                continue

            r_result = r_future.result()
            finished_retrieves.add(request_id)
            finished_retrieves.update(other_reqs)

            if not all(r_result):
                # TODO: add error handling here
                logger.error(
                    "Something went wrong when processing the "
                    "retrieve request for request_id=%s, result=%s",
                    request_id,
                    r_result,
                )

        # Remove the finished requests from the tracking dicts
        for request_id in finished_stores:
            self.store_futures.pop(request_id, None)
        for request_id in finished_retrieves:
            self.retrieve_futures.pop(request_id, None)

        # Update the internal states
        self.finished_stores.update(finished_stores)

        # A store id may only be returned after the engine has reported the
        # request finished. Ids with a store still tracked (finished or
        # pending) are deferred to a later call; ids with no tracked store
        # can be returned immediately.
        ret_stores = set()
        for req_id in finished_req_ids_from_engine:
            if req_id in self.finished_stores or req_id in self.store_futures:
                self.previously_finished.add(req_id)
            else:
                ret_stores.add(req_id)

        # Calculate the final finished stores
        ret_stores.update(self._update_and_get_finished_store())

        return ret_stores, finished_retrieves

    def num_blocks_per_chunk(self) -> int:
        """
        Returns:
            The number of vllm blocks in a LMCache data chunk
        """
        return self.blocks_in_chunk

    def shutdown(self):
        """
        Shutdown the LMCache MP worker adapter
        """
        logger.info("Unregistering kv caches")
        # Block until the server confirms the unregistration before closing
        # the message queue client.
        send_lmcache_request(
            self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id]
        ).result()

        self.mq_client.close()

    # Helper functions
    def _update_and_get_finished_store(
        self,
    ) -> set[str]:
        """Converge the internal states about finished stores
        and returns the 'safe finished store request ids' back

        A store id is "safe" once both sides agree: LMCache has finished
        the store AND the engine has reported the request finished. Safe
        ids are removed from both tracking sets; the remaining entries
        keep waiting for the other side.
        """
        safe_finished_s = self.finished_stores.intersection(self.previously_finished)
        self.finished_stores.difference_update(self.previously_finished)
        self.previously_finished.difference_update(safe_finished_s)

        return safe_finished_s

    def _create_key(
        self,
        token_ids: list[int],
        start: int = 0,
        end: int = 0,
        request_id: str | None = None,
    ) -> IPCCacheEngineKey:
        """Convert token IDs to an IPC cache engine key"""
        return IPCCacheEngineKey(
            model_name=self.model_name,
            world_size=self.world_size,
            worker_id=self.worker_id,
            token_ids=tuple(token_ids),
            start=start,
            end=end,
            request_id=request_id,
        )

    def _create_hash_key(
        self, chunk_hash: bytes, request_id: str | None = None
    ) -> IPCCacheEngineKey:
        """Create a hash-mode IPC cache engine key"""
        # Unlike the scheduler adapter, worker-side hash keys carry the
        # worker id, since the KV data is per-worker.
        return IPCCacheEngineKey(
            model_name=self.model_name,
            world_size=self.world_size,
            worker_id=self.worker_id,
            chunk_hash=chunk_hash,
            request_id=request_id,
        )

is_first_rank_of_pp_group property

is_first_rank_of_pp_group: bool

Is the first rank of the pipeline parallel group.

use_mla property

use_mla: bool

Whether to use MLA.

worker_id property

worker_id: int

The worker id.

world_size property

world_size: int

The world size.

_create_hash_key

_create_hash_key(
    chunk_hash: bytes, request_id: str | None = None
) -> IPCCacheEngineKey

Create a hash-mode IPC cache engine key

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def _create_hash_key(
    self, chunk_hash: bytes, request_id: str | None = None
) -> IPCCacheEngineKey:
    """Build a hash-mode IPC cache engine key for this worker."""
    return IPCCacheEngineKey(
        request_id=request_id,
        chunk_hash=chunk_hash,
        worker_id=self.worker_id,
        world_size=self.world_size,
        model_name=self.model_name,
    )

_create_key

_create_key(
    token_ids: list[int],
    start: int = 0,
    end: int = 0,
    request_id: str | None = None,
) -> IPCCacheEngineKey

Convert token IDs to an IPC cache engine key

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def _create_key(
    self,
    token_ids: list[int],
    start: int = 0,
    end: int = 0,
    request_id: str | None = None,
) -> IPCCacheEngineKey:
    """Build a token-mode IPC cache engine key for this worker."""
    key_fields = dict(
        model_name=self.model_name,
        world_size=self.world_size,
        worker_id=self.worker_id,
        token_ids=tuple(token_ids),
        start=start,
        end=end,
        request_id=request_id,
    )
    return IPCCacheEngineKey(**key_fields)

_update_and_get_finished_store

_update_and_get_finished_store() -> set[str]

Reconciles the internal state about finished stores and returns the 'safe' finished store request ids.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def _update_and_get_finished_store(self) -> set[str]:
    """Reconcile the bookkeeping sets for finished store requests.

    Returns the request ids present in both ``finished_stores`` and
    ``previously_finished`` -- these are safe to report as finished.
    Both sets are updated in place to drop the reconciled ids.
    """
    safe_ids = self.finished_stores & self.previously_finished
    self.finished_stores -= self.previously_finished
    self.previously_finished -= safe_ids
    return safe_ids

batched_submit_retrieve_requests

batched_submit_retrieve_requests(
    request_ids: list[str],
    ops: list[LoadStoreOp],
    event: Event,
)

Submit a batched retrieve request to LMCache

Parameters:

Name Type Description Default
request_ids list[str]

The IDs of the requests

required
ops list[LoadStoreOp]

The LoadStoreOps describing the retrieve operations. Should have the same length as request_ids

required
event Event

The CUDA event that is recorded after the current model inference step

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def batched_submit_retrieve_requests(
    self,
    request_ids: list[str],
    ops: list[LoadStoreOp],
    event: torch.cuda.Event,
):
    """
    Submit a batched retrieve request to LMCache

    Args:
        request_ids: The IDs of the requests
        ops: The LoadStoreOps describing the retrieve operations. Must have
            the same length as request_ids
        event: The CUDA event that is recorded after the current
            model inference step
    """
    if not request_ids:
        # Nothing to retrieve; also avoids indexing request_ids[0] below.
        return
    all_keys: list[IPCCacheEngineKey] = []
    block_ids: list[int] = []
    # strict=True enforces the documented contract that request_ids and
    # ops have the same length instead of silently truncating the longer.
    for request_id, op in zip(request_ids, ops, strict=True):
        if op.block_hashes is not None:
            # Hash mode: one key per chunk-sized stride of block hashes.
            chunk_hashes = striding_block_hashes(
                op.block_hashes, self.blocks_in_chunk
            )
            all_keys.extend(
                self._create_hash_key(ch, request_id=request_id)
                for ch in chunk_hashes
            )
        else:
            # Token mode: a single key covering [op.start, op.end).
            assert op.token_ids is not None
            all_keys.append(
                self._create_key(
                    op.token_ids, op.start, op.end, request_id=request_id
                )
            )
        block_ids.extend(op.block_ids)
    future = send_lmcache_request(
        self.mq_client,
        RequestType.RETRIEVE,
        [
            all_keys,
            self.instance_id,
            block_ids,
            event.ipc_handle(),
        ],
    ).to_cuda_future()
    # The whole batch is tracked under the first request id; the remaining
    # ids are resolved together when this single future completes.
    self.retrieve_futures[request_ids[0]] = (future, list(request_ids[1:]))

batched_submit_store_requests

batched_submit_store_requests(
    request_ids: list[str],
    ops: list[LoadStoreOp],
    event: Event,
)

Submit a batched store request to LMCache

Parameters:

Name Type Description Default
request_ids list[str]

The IDs of the requests

required
ops list[LoadStoreOp]

The LoadStoreOps describing the store operations. Should have the same length as request_ids

required
event Event

The CUDA event that is recorded after the current model inference step

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def batched_submit_store_requests(
    self,
    request_ids: list[str],
    ops: list[LoadStoreOp],
    event: torch.cuda.Event,
):
    """
    Submit a batched store request to LMCache

    Args:
        request_ids: The IDs of the requests
        ops: The LoadStoreOps describing the store operations. Must have
            the same length as request_ids
        event: The CUDA event that is recorded after the current
            model inference step
    """
    if not request_ids:
        # Nothing to store; also avoids indexing request_ids[0] below.
        return
    all_keys: list[IPCCacheEngineKey] = []
    block_ids: list[int] = []
    # strict=True enforces the documented contract that request_ids and
    # ops have the same length instead of silently truncating the longer.
    for request_id, op in zip(request_ids, ops, strict=True):
        if op.block_hashes is not None:
            # Hash mode: one key per chunk-sized stride of block hashes.
            chunk_hashes = striding_block_hashes(
                op.block_hashes, self.blocks_in_chunk
            )
            all_keys.extend(
                self._create_hash_key(ch, request_id=request_id)
                for ch in chunk_hashes
            )
        else:
            # Token mode: a single key covering [op.start, op.end).
            assert op.token_ids is not None
            all_keys.append(
                self._create_key(
                    op.token_ids, op.start, op.end, request_id=request_id
                )
            )
        block_ids.extend(op.block_ids)
    future = send_lmcache_request(
        self.mq_client,
        RequestType.STORE,
        [
            all_keys,
            self.instance_id,
            block_ids,
            event.ipc_handle(),
        ],
    ).to_cuda_future()
    # The whole batch is tracked under the first request id; the remaining
    # ids are resolved together when this single future completes.
    self.store_futures[request_ids[0]] = (future, list(request_ids[1:]))

get_finished

get_finished(
    finished_req_ids_from_engine: set[str],
) -> tuple[set[str] | None, set[str] | None]

Check and get the finished store and retrieve requests.

Parameters:

Name Type Description Default
finished_req_ids_from_engine set[str]

the set of request ids that are reported as finished from the vLLM engine side.

required

Returns:

Type Description
set[str] | None

A tuple of two sets:

set[str] | None
  • The first set contains the finished store request ids. The returned store request ids MUST have previously appeared in finished_req_ids_from_engine.
tuple[set[str] | None, set[str] | None]
  • The second set contains the finished retrieve request ids.
Notes

When enabling async scheduling in vLLM, the same request ID may appear multiple times in finished_req_ids_from_engine. The adapter should take care of deduplicating the request IDs and only return the request IDs that have not been returned before.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def get_finished(
    self, finished_req_ids_from_engine: set[str]
) -> tuple[set[str] | None, set[str] | None]:
    """
    Check and get the finished store and retrieve requests.

    Args:
        finished_req_ids_from_engine: the set of request ids that are
            reported as finished from the vLLM engine side.

    Returns:
        A tuple of two sets:
        - The first set contains the finished store request ids. The returned
            store request ids MUST be seen before in the
            `finished_req_ids_from_engine`.
        - The second set contains the finished retrieve request ids.

    Notes:
        When enabling async scheduling in vLLM, the same request ID may appear
        multiple times in `finished_req_ids_from_engine`. The adapter should
        take care of deduplicating the request IDs and only return the request
        IDs that have not been returned before.
    """
    finished_stores = set()
    finished_retrieves = set()
    # Poll store futures. A completed batch future finishes its leader
    # request id plus every other request id submitted with it.
    for request_id, (s_future, other_reqs) in self.store_futures.items():
        if not s_future.query():
            continue

        s_result = s_future.result()
        finished_stores.add(request_id)
        finished_stores.update(other_reqs)

        if not s_result:
            # TODO: add error handling here
            logger.error(
                "Something went wrong when processing the "
                "store request for request_id=%s",
                request_id,
            )

    # Same polling for retrieve futures.
    for request_id, (r_future, other_reqs) in self.retrieve_futures.items():
        if not r_future.query():
            continue

        r_result = r_future.result()
        finished_retrieves.add(request_id)
        finished_retrieves.update(other_reqs)

        if not all(r_result):
            # TODO: add error handling here
            logger.error(
                "Something went wrong when processing the "
                "retrieve request for request_id=%s, result=%s",
                request_id,
                r_result,
            )

    # Remove the finished requests from the tracking dicts
    for request_id in finished_stores:
        self.store_futures.pop(request_id, None)
    for request_id in finished_retrieves:
        self.retrieve_futures.pop(request_id, None)

    # Update the internal states
    self.finished_stores.update(finished_stores)

    ret_stores = set()
    for req_id in finished_req_ids_from_engine:
        if req_id in self.finished_stores or req_id in self.store_futures:
            # A store is still tracked (pending or just completed) for this
            # request: defer reporting until both sides have finished.
            self.previously_finished.add(req_id)
        else:
            # No store tracked for this request; safe to report now.
            ret_stores.add(req_id)

    # Calculate the final finished stores: ids finished on both the engine
    # side and the store side, reconciled by the helper.
    ret_stores.update(self._update_and_get_finished_store())

    return ret_stores, finished_retrieves

num_blocks_per_chunk

num_blocks_per_chunk() -> int

Returns:

Type Description
int

The number of vllm blocks in a LMCache data chunk

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def num_blocks_per_chunk(self) -> int:
    """Return how many vLLM blocks make up one LMCache data chunk."""
    return self.blocks_in_chunk

register_kv_caches

register_kv_caches(kv_caches: dict[str, Tensor])

Register the kv caches with LMCache server

Parameters:

Name Type Description Default
kv_caches dict[str, Tensor]

A dict of kv caches to register. The keys are the layer names and the values are the corresponding tensors.

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Register the given kv caches with the LMCache server.

    Args:
        kv_caches: Mapping from layer name to that layer's kv-cache
            tensor.
    """
    # Keep a local reference, then block until the server acknowledges
    # the registration request.
    self.kv_caches = kv_caches
    logger.info("Registering kv caches")
    send_lmcache_request(
        self.mq_client,
        RequestType.REGISTER_KV_CACHE,
        [self.instance_id, wrap_kv_caches(kv_caches)],
    ).result()

shutdown

shutdown()

Shutdown the LMCache MP worker adapter

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
def shutdown(self):
    """Shut down the LMCache MP worker adapter.

    Unregisters the kv caches from the server (blocking until
    acknowledged), then closes the message queue client.
    """
    logger.info("Unregistering kv caches")
    unregister_future = send_lmcache_request(
        self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id]
    )
    unregister_future.result()

    self.mq_client.close()

    self.mq_client.close()

submit_retrieve_request

submit_retrieve_request(
    request_id: str, op: LoadStoreOp, event: Event
)

Submit a KV cache retrieve request to LMCache

Parameters:

Name Type Description Default
request_id str

The ID of the request

required
op LoadStoreOp

The LoadStoreOp describing the retrieve operation.

required
event Event

The CUDA event that is recorded after the current model inference step

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def submit_retrieve_request(
    self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
):
    """Submit a single KV cache retrieve request to LMCache.

    Args:
        request_id: The ID of the request.
        op: The LoadStoreOp describing the retrieve operation.
        event: The CUDA event that is recorded after the current
            model inference step.
    """
    if op.block_hashes is None:
        # Token mode: a single key covering [op.start, op.end).
        assert op.token_ids is not None
        keys = [
            self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
        ]
    else:
        # Hash mode: one key per chunk-sized stride of block hashes.
        keys = [
            self._create_hash_key(ch, request_id=request_id)
            for ch in striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
        ]
    future = send_lmcache_request(
        self.mq_client,
        RequestType.RETRIEVE,
        [keys, self.instance_id, op.block_ids, event.ipc_handle()],
    ).to_cuda_future()
    self.retrieve_futures[request_id] = (future, [])

submit_store_request

submit_store_request(
    request_id: str, op: LoadStoreOp, event: Event
)

Submit a KV cache store request to LMCache

Parameters:

Name Type Description Default
request_id str

The ID of the request

required
op LoadStoreOp

The LoadStoreOp describing the store operation.

required
event Event

The CUDA event that is recorded after the current model inference step

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@_lmcache_nvtx_annotate
def submit_store_request(
    self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
):
    """Submit a single KV cache store request to LMCache.

    Args:
        request_id: The ID of the request.
        op: The LoadStoreOp describing the store operation.
        event: The CUDA event that is recorded after the current
            model inference step.
    """
    if op.block_hashes is None:
        # Token mode: a single key covering [op.start, op.end).
        assert op.token_ids is not None
        keys = [
            self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
        ]
    else:
        # Hash mode: one key per chunk-sized stride of block hashes.
        keys = [
            self._create_hash_key(ch, request_id=request_id)
            for ch in striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
        ]
    future = send_lmcache_request(
        self.mq_client,
        RequestType.STORE,
        [keys, self.instance_id, op.block_ids, event.ipc_handle()],
    ).to_cuda_future()
    self.store_futures[request_id] = (future, [])

LoadStoreOp dataclass

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@dataclass
class LoadStoreOp:
    block_ids: list[int]
    """Block ids for the load/store operation"""

    token_ids: list[int] | None = None
    """Token IDs for the load/store operation (token mode)"""

    block_hashes: list[bytes] | None = None
    """Block hashes for the load/store operation (hash mode)"""

    start: int = 0
    """Start token index (token mode only)"""

    end: int = 0
    """End token index (token mode only)"""

    def __len__(self) -> int:
        return len(self.block_ids)

block_hashes class-attribute instance-attribute

block_hashes: list[bytes] | None = None

Block hashes for the load/store operation (hash mode)

block_ids instance-attribute

block_ids: list[int]

Block ids for the load/store operation

end class-attribute instance-attribute

end: int = 0

End token index (token mode only)

start class-attribute instance-attribute

start: int = 0

Start token index (token mode only)

token_ids class-attribute instance-attribute

token_ids: list[int] | None = None

Token IDs for the load/store operation (token mode)

ParallelStrategy dataclass

Source code in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@dataclass
class ParallelStrategy:
    """Describes the parallel layout used for LMCache kv transfers."""

    use_mla: bool
    """Whether MLA is used."""

    kv_world_size: int
    """
    The kv world size. It may differ from ``actual_world_size``: in the
    MLA case the effect of TP is 'excluded'. The value is computed by
    `extract_world_size_and_kv_rank` in `lmcache_mp_connector.py`.
    """

    kv_worker_id: int
    """
    The kv worker id of the sub-process. It may differ from
    ``actual_worker_id``: in the MLA case the effect of TP is 'excluded'.
    The value is computed by `extract_world_size_and_kv_rank` in
    `lmcache_mp_connector.py`.
    """

    actual_world_size: int
    """The actual world size."""

    actual_worker_id: int
    """The actual worker id of the sub-process."""

    tp_size: int
    """The tensor parallel size."""

    pp_size: int
    """The pipeline parallel size."""

actual_worker_id instance-attribute

actual_worker_id: int

The actual worker id of the sub-process.

actual_world_size instance-attribute

actual_world_size: int

The actual world size.

kv_worker_id instance-attribute

kv_worker_id: int

The kv worker id of the sub-process, kv_worker_id may not be equal to the actual_worker_id, in the case of mla, it will 'exclude' the effect of TP, the value is calculated by extract_world_size_and_kv_rank in lmcache_mp_connector.py.

kv_world_size instance-attribute

kv_world_size: int

The kv world size, kv_world_size may not be equal to the actual_world_size, in the case of mla, it will 'exclude' the effect of TP, the value is calculated by extract_world_size_and_kv_rank in lmcache_mp_connector.py.

pp_size instance-attribute

pp_size: int

The pipeline parallel size.

tp_size instance-attribute

tp_size: int

The tensor parallel size.

use_mla instance-attribute

use_mla: bool

Whether to use the MLA.