pegaflow-llm 0.0.2__cp310-cp310-manylinux_2_34_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,941 @@
1
"""Baseline vLLM v1 KV connector backed by the PegaFlow engine.

This module defines :class:`PegaKVConnector`, a subclass of
``vllm.distributed.kv_transfer.kv_connector.v1.base.KVConnectorBase_V1``
that saves KV cache blocks to, and loads them back from, an independent
PegaFlow engine process reached over ZMQ.

Environment variables:

* ``PEGAFLOW_ENGINE_ENDPOINT`` -- ZMQ endpoint of the engine server
  (default ``ipc:///tmp/pega_engine.sock``).
* ``PEGAFLOW_KV_LOOKUP_ENDPOINT`` -- endpoint of the worker-hosted lookup
  server that the scheduler queries (default: an IPC path under ``/tmp``
  keyed by the user id).
* ``PEGAFLOW_ENABLE_TIMING`` -- ``"1"`` (the default) logs the duration of
  the main connector hooks; any other value disables that logging.

Usage example (scheduler/worker side)::

    from pegaflow import PegaKVConnector, KVConnectorRole

    connector = PegaKVConnector(vllm_config, KVConnectorRole.WORKER)

The class can be registered with vLLM as a dynamic connector by referencing
it via its full import path.
"""

# The docstring must precede this import: anything placed after it is a plain
# expression statement and would not become the module's ``__doc__``.
from __future__ import annotations
22
+
23
+ import functools
24
+ import logging
25
+ import os
26
+ import pickle
27
+ import threading
28
+ import time
29
+ from typing import Any, Dict, List, Optional, Tuple
30
+
31
+ import msgpack
32
+ import torch
33
+ import zmq
34
+
35
+ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
36
+ KVConnectorBase_V1,
37
+ KVConnectorMetadata,
38
+ KVConnectorRole,
39
+ )
40
+
41
+ # Import CUDA IPC wrapper for cross-process tensor sharing
42
+ from pegaflow.ipc_wrapper import CudaIPCWrapper
43
+
44
# Module logger. Its level is pinned to INFO so that operational messages are
# emitted even when the host application never configures logging.
logger = logging.getLogger(name=__name__)
logger.setLevel("INFO")

# Per-call timing logs from ``timing_wrapper`` are on unless the environment
# variable PEGAFLOW_ENABLE_TIMING holds something other than the exact string "1".
_ENABLE_TIMING = "1" == os.getenv("PEGAFLOW_ENABLE_TIMING", "1")
51
+
52
+
53
def timing_wrapper(func):
    """Log the wall-clock duration of every call to *func* at INFO level.

    Timing is enabled by default; set the environment variable
    ``PEGAFLOW_ENABLE_TIMING`` to any value other than ``"1"`` (e.g. ``0``) to
    turn it off. The duration is logged even when *func* raises, because it is
    measured in a ``finally`` clause; the exception still propagates.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Skip the clock reads entirely when nothing would be logged.
        if not _ENABLE_TIMING or not logger.isEnabledFor(logging.INFO):
            return func(*args, **kwargs)

        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            elapsed_ms = (time.perf_counter() - start) * 1000
            logger.info(
                "[PegaKVConnector] %s took %.2f ms",
                func.__name__,
                elapsed_ms,
            )

    return wrapper
75
+
76
# Fallback output: when neither this logger nor any of its ancestors has a
# handler, attach a bare message-only stderr handler. Propagation is switched
# off so that a later root handler does not print every record a second time.
if not logger.hasHandlers():
    _handler = logging.StreamHandler()
    _handler.setFormatter(logging.Formatter("%(message)s"))
    _handler.setLevel(logging.NOTSET)
    logger.propagate = False
    logger.addHandler(_handler)
83
+
84
# Scheduler <-> worker lookup channel. An empty environment value is treated
# the same as an unset one (an empty endpoint string cannot be bound).
# The default IPC path is keyed by the user id, or by the process id where
# os.getuid does not exist.
# NOTE(review): with the pid fallback the scheduler and worker processes derive
# different paths, and two connectors run by the same user share the default
# path -- set PEGAFLOW_KV_LOOKUP_ENDPOINT explicitly in those setups.
_LOOKUP_ENDPOINT = os.environ.get("PEGAFLOW_KV_LOOKUP_ENDPOINT") or None
if _LOOKUP_ENDPOINT is None:
    unique_id = getattr(os, "getuid", os.getpid)()
    _LOOKUP_ENDPOINT = f"ipc:///tmp/pegaflow_kv_lookup_{unique_id}.sock"

# Engine server endpoint (independent process); an empty value also falls back
# to the default.
_ENGINE_ENDPOINT = (
    os.environ.get("PEGAFLOW_ENGINE_ENDPOINT") or "ipc:///tmp/pega_engine.sock"
)
91
+
92
+
93
class PegaConnectorMetadata(KVConnectorMetadata):
    """Per-step payload that the scheduler-side connector hands to workers.

    Attributes:
        block_hashes: maps a scheduled request id to the list of block hashes
            the scheduler has recorded for that request.
        requests_to_load: maps a request id to a dict describing the blocks
            to restore for it (``block_ids``, ``block_hashes``, ``num_tokens``).
    """

    def __init__(
        self,
        block_hashes: Optional[Dict[str, List[bytes]]] = None,
        requests_to_load: Optional[Dict[str, Dict]] = None,
    ):
        super().__init__()
        # A falsy argument (None or an empty mapping) is replaced by a new dict.
        self.requests_to_load = requests_to_load if requests_to_load else {}
        self.block_hashes = block_hashes if block_hashes else {}
109
+
110
class PegaKVConnector(KVConnectorBase_V1):
    """vLLM v1 KV connector that offloads KV cache blocks to a PegaFlow engine.

    The connector runs in one of two roles (see :class:`KVConnectorRole`):

    * Scheduler role: asks a worker-hosted lookup server how many leading
      blocks of a prompt the engine already stores
      (:meth:`get_num_new_matched_tokens`) and packages per-step load/save
      information into :class:`PegaConnectorMetadata`
      (:meth:`build_connector_meta`).
    * Worker role: registers the paged KV tensors with the engine through
      CUDA IPC handles (:meth:`register_kv_caches`), issues a LOAD request
      before the forward pass (:meth:`start_load_kv`) and SAVE requests after
      it (:meth:`wait_for_save`).

    Engine traffic is a pickled ``(command, payload)`` tuple sent over a ZMQ
    REQ socket; lookup traffic is msgpack over a separate REQ/REP pair.
    """

    # Receive timeout (ms) of the lookup server's REP socket. It bounds how
    # long the server thread takes to notice that it has been asked to stop.
    _LOOKUP_SERVER_POLL_MS = 200

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        """Create a new PegaKVConnector.

        Args:
            vllm_config: vLLM configuration object; ``cache_config.block_size``
                and ``parallel_config.data_parallel_rank`` are read here.
            role: Whether this connector instance runs in the scheduler
                process or the worker process.
        """
        super().__init__(vllm_config, role)

        # ZMQ REQ client for the engine server (independent process). The lock
        # serializes socket use between the caller's thread and the lookup
        # server thread, which both issue engine requests.
        self._engine_endpoint = _ENGINE_ENDPOINT
        self._engine_context: Optional[zmq.Context] = None
        self._engine_socket = None
        self._engine_lock = threading.Lock()

        # Scheduler side: req_id -> the request's block hash list (list[bytes]).
        self._request_block_hashes: Dict[str, List[bytes]] = {}

        # Scheduler side: req_id -> locally computed (prefix-cache) token count
        # seen by get_num_new_matched_tokens(); update_state_after_alloc() uses
        # it to know where the externally loaded range starts.
        self._local_computed_tokens: Dict[str, int] = {}

        # Scheduler side: req_id -> load info recorded in update_state_after_alloc().
        self._requests_to_load: Dict[str, Dict[str, Any]] = {}

        # Worker side: per-layer save requests collected by save_kv_layer().
        self._pending_saves: List[Dict[str, Any]] = []

        # Worker side: layers covered by a LOAD issued in the current step that
        # wait_for_layer_load() still has to wait on.
        self._layers_pending_load: set[str] = set()

        # Worker side: whether the "ambiguous batch, save skipped" warning has
        # already been emitted (it is logged only once).
        self._warned_ambiguous_save = False

        # Names of the KV cache layers registered with the engine.
        self._registered_layers: list[str] = []

        # Scheduler/worker lookup channel state.
        self._lookup_endpoint = _LOOKUP_ENDPOINT
        self._lookup_context: Optional[zmq.Context] = None
        self._lookup_server_socket = None
        self._lookup_server_thread: Optional[threading.Thread] = None
        self._lookup_stop_event = threading.Event()
        self._lookup_client = None

        # Tokens per KV block. The KV cache layout itself is detected per layer
        # in register_kv_caches().
        self._block_size = vllm_config.cache_config.block_size

        parallel_config = getattr(vllm_config, "parallel_config", None)
        data_parallel_rank = getattr(parallel_config, "data_parallel_rank", 0)
        # Workers of data-parallel rank 0 host the lookup server.
        # NOTE(review): the condition does not check the tensor-parallel rank,
        # so presumably every TP worker of DP rank 0 starts a server; each one
        # unlinks and rebinds the same IPC path, and the last one wins.
        if role == KVConnectorRole.WORKER and data_parallel_rank == 0:
            self._start_lookup_server()

    # ==============================
    # Engine client helper methods
    # ==============================

    def _ensure_engine_socket(self) -> None:
        """Create and connect the engine REQ socket if it does not exist yet.

        Called with ``self._engine_lock`` held.
        """
        if self._engine_socket is not None:
            return

        if self._engine_context is None:
            self._engine_context = zmq.Context()

        sock = self._engine_context.socket(zmq.REQ)
        sock.setsockopt(zmq.RCVTIMEO, 5000)  # 5 second timeout
        sock.setsockopt(zmq.SNDTIMEO, 5000)
        # Drop unsent messages on close so Context.term() cannot block while
        # the engine is unreachable.
        sock.setsockopt(zmq.LINGER, 0)
        sock.connect(self._engine_endpoint)
        self._engine_socket = sock
        logger.info("[PegaKVConnector] Connected to engine server at %s", self._engine_endpoint)

    def _reset_engine_socket(self) -> None:
        """Close and forget the engine socket so the next request reconnects.

        A REQ socket whose send/recv cycle was interrupted refuses further
        sends, so it is discarded instead of reused.
        """
        sock, self._engine_socket = self._engine_socket, None
        if sock is not None:
            try:
                sock.close(0)
            except Exception:
                pass

    def _send_engine_request(self, command: str, payload: dict) -> dict:
        """Send one request to the engine server and return its reply.

        Args:
            command: Command name (e.g. REGISTER_CONTEXT, SAVE, LOAD, QUERY,
                WAIT_LAYER, UNREGISTER_CONTEXT).
            payload: Command payload dict.

        Returns:
            The reply dict; its ``status`` field equals ``'success'``.

        Raises:
            RuntimeError: If the request times out, the transport fails, or
                the engine replies with a non-success status.
        """
        with self._engine_lock:
            try:
                self._ensure_engine_socket()
                # NOTE: replies are unpickled, so the engine endpoint must be a
                # trusted local process.
                self._engine_socket.send(pickle.dumps((command, payload)))
                response = pickle.loads(self._engine_socket.recv())
            except zmq.error.Again as e:
                self._reset_engine_socket()
                raise RuntimeError(f"Engine request timeout: {command}") from e
            except Exception as e:
                self._reset_engine_socket()
                raise RuntimeError(f"Engine request error: {e}") from e

        # The request/reply cycle completed, so the socket stays usable even
        # when the engine reports an application-level failure.
        if not isinstance(response, dict):
            raise RuntimeError(
                f"Engine request failed: unexpected response type {type(response).__name__}"
            )
        if response.get('status') != 'success':
            error_msg = response.get('message', 'Unknown error')
            raise RuntimeError(f"Engine request failed: {error_msg}")
        return response

    # ==============================
    # Worker-side methods
    # ==============================

    @timing_wrapper
    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        """Ask the engine to load this step's external KV blocks.

        Called before the forward pass. All blocks listed in the connector
        metadata are sent in a single LOAD request that covers every layer in
        ``forward_context.no_compile_layers`` having a ``kv_cache`` attribute;
        :meth:`wait_for_layer_load` later waits for each of those layers.
        Errors are logged, not raised.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation (unused).
        """
        # Forget the previous step's layers so waits only cover this step's LOAD.
        self._layers_pending_load.clear()

        metadata = self._get_connector_metadata()
        if not isinstance(metadata, PegaConnectorMetadata):
            return
        if not metadata.requests_to_load:
            return

        total_requests = len(metadata.requests_to_load)

        try:
            load_start = time.perf_counter()

            # Aggregate all blocks from all requests into one batched request.
            all_block_ids: List[int] = []
            all_block_hashes: List[bytes] = []
            for load_info in metadata.requests_to_load.values():
                all_block_ids.extend(load_info['block_ids'])
                all_block_hashes.extend(load_info['block_hashes'])

            if not all_block_ids:
                return

            # KV cache layers in the order provided by vLLM.
            target_layers: List[str] = [
                layer_name
                for layer_name, layer in forward_context.no_compile_layers.items()
                if hasattr(layer, 'kv_cache')
            ]
            if not target_layers:
                return

            # Registered before sending so that even a failed/timed-out LOAD is
            # still followed by WAIT_LAYER calls, as before.
            self._layers_pending_load.update(target_layers)

            response = self._send_engine_request('LOAD', {
                'layer_names': target_layers,
                'block_ids': all_block_ids,
                'block_hashes': all_block_hashes,
            })

            num_layers_loaded = response.get('num_layers_loaded', 0)
            total_bytes = response.get('total_bytes', 0)

            elapsed_s = time.perf_counter() - load_start
            bandwidth_gbps = (total_bytes / 1e9) / elapsed_s if elapsed_s > 0 else 0.0

            logger.info(
                "[PegaKVConnector] queued %d blocks (%.2f GB) across %d layers for %d reqs, "
                "schedule %.0f us (%.2f GB/s)",
                len(all_block_ids) * num_layers_loaded,
                total_bytes / 1e9,
                num_layers_loaded,
                total_requests,
                elapsed_s * 1e6,
                bandwidth_gbps,
            )

        except Exception as e:
            # vLLM already treats these tokens as computed, so a failed load is
            # worth surfacing above DEBUG.
            logger.warning(
                "[PegaKVConnector] Error in start_load_kv: %s",
                e,
                exc_info=True,
            )

    def wait_for_layer_load(self, layer_name: str) -> None:
        """Block until the engine has finished loading KV for *layer_name*.

        Called from within the attention layer. Only layers covered by a LOAD
        issued in this step's :meth:`start_load_kv` trigger a WAIT_LAYER
        request; for all others there is nothing to wait for, which avoids one
        engine round trip per layer per step. Errors are logged, not raised.

        Args:
            layer_name: the name of that layer
        """
        if layer_name not in self._layers_pending_load:
            return
        self._layers_pending_load.discard(layer_name)

        try:
            self._send_engine_request('WAIT_LAYER', {'layer_name': layer_name})
        except Exception as exc:
            logger.warning(
                "[PegaKVConnector] wait_for_layer_load failed for %s: %s",
                layer_name,
                exc,
                exc_info=True,
            )

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: "torch.Tensor",  # type: ignore[name-defined]
        attn_metadata: "AttentionMetadata",
        **kwargs: Any,
    ) -> None:
        """Record a layer for saving; the actual SAVE happens in wait_for_save().

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM (unused; the engine addresses the buffer
                registered in register_kv_caches()).
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        self._pending_saves.append({
            'layer_name': layer_name,
            'attn_metadata': attn_metadata,
        })

    @timing_wrapper
    def wait_for_save(self) -> None:
        """Save the KV blocks written during this forward pass to the engine.

        Called as the forward context exits. Waits for all work queued on the
        current CUDA stream, then sends one SAVE request per layer collected
        by :meth:`save_kv_layer`. Failures are logged, never raised, and the
        pending list is always cleared.
        """
        if not self._pending_saves:
            return

        try:
            total_start = time.perf_counter()

            metadata = self._get_connector_metadata()
            if not isinstance(metadata, PegaConnectorMetadata):
                return

            # The engine reads the paged buffers through CUDA IPC from another
            # process, so every kernel that wrote this step's KV must finish
            # before the first SAVE is sent.
            torch.cuda.current_stream().synchronize()

            total_blocks_saved = 0
            total_layers_saved = 0
            failed_layers = 0

            for save_info in self._pending_saves:
                layer_name = save_info['layer_name']
                plan = self._plan_layer_save(save_info['attn_metadata'], metadata)
                if plan is None:
                    continue
                block_ids, block_hashes = plan

                try:
                    self._send_engine_request('SAVE', {
                        'layer_name': layer_name,
                        'block_ids': block_ids,
                        'block_hashes': block_hashes,
                    })
                except Exception:
                    failed_layers += 1
                    logger.debug(
                        "[PegaKVConnector] SAVE failed for %s",
                        layer_name,
                        exc_info=True,
                    )
                    continue

                total_blocks_saved += len(block_hashes)
                total_layers_saved += 1

            total_time_ms = (time.perf_counter() - total_start) * 1000

            if failed_layers:
                logger.warning(
                    "[PegaKVConnector] SAVE failed for %d layers",
                    failed_layers,
                )
            if total_blocks_saved > 0:
                logger.debug(
                    "[PegaKVConnector] saved %d blocks across %d layers (%.2f ms)",
                    total_blocks_saved,
                    total_layers_saved,
                    total_time_ms,
                )

        except Exception:
            logger.warning("[PegaKVConnector] Error in wait_for_save", exc_info=True)
        finally:
            self._pending_saves.clear()

    def _plan_layer_save(
        self,
        attn_metadata: "AttentionMetadata",
        metadata: PegaConnectorMetadata,
    ) -> Optional[Tuple[List[int], List[bytes]]]:
        """Return ``(block_ids, block_hashes)`` to save for one layer, or None.

        The attention metadata does not say which request a block-table row
        belongs to, so blocks are saved only when the step has exactly one
        sequence and the metadata tracks exactly one request. Guessing the
        pairing would store one request's KV under another request's hashes.
        """
        block_table = getattr(attn_metadata, 'block_table', None)
        if block_table is None:
            return None

        num_seqs = block_table.shape[0]
        if num_seqs == 0 or not metadata.block_hashes:
            return None

        if num_seqs != 1 or len(metadata.block_hashes) != 1:
            # TODO: carry per-request block ids in PegaConnectorMetadata so
            # multi-request batches can be saved safely.
            if not self._warned_ambiguous_save:
                self._warned_ambiguous_save = True
                logger.warning(
                    "[PegaKVConnector] KV save skipped for a step with %d sequences "
                    "and %d tracked requests: batch rows cannot be mapped to request "
                    "ids, so only single-request steps are saved",
                    num_seqs,
                    len(metadata.block_hashes),
                )
            return None

        (hashes,) = metadata.block_hashes.values()
        if not hashes:
            return None

        seq_lens = getattr(attn_metadata, 'seq_lens', None)
        if seq_lens is not None:
            # Ceil-divide: a partially filled last block still occupies a block.
            num_blocks = (int(seq_lens[0]) + self._block_size - 1) // self._block_size
        else:
            # Fallback: count non-zero block ids in the row.
            num_blocks = int((block_table[0] != 0).sum().item())

        # Blocks beyond the tracked hashes cannot be keyed, so they are skipped.
        num_use = min(num_blocks, len(hashes))
        if num_use <= 0:
            return None

        block_ids = block_table[0, :num_use].cpu().tolist()
        return block_ids, list(hashes[:num_use])

    # ==============================
    # Scheduler-side methods
    # ==============================

    @timing_wrapper
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> Tuple[Optional[int], bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - The number of tokens that can be loaded from the
                  external KV cache beyond what is already computed
                  (this implementation never returns None).
                - Always `False`: loads happen synchronously within the step.

        Notes:
            Only the contiguous prefix of blocks reported by the lookup server
            is considered, and at least one prompt token is always left for
            the scheduler to compute.
        """
        req_id = request.request_id
        num_tokens = len(request.prompt_token_ids or [])

        matched_blocks = self._send_lookup_request(req_id, request.block_hashes)
        if matched_blocks <= 0:
            return (0, False)

        available_tokens = min(matched_blocks * self._block_size, num_tokens)
        # Always leave at least one prompt token for the scheduler to compute.
        num_new_tokens = (available_tokens - 1) - num_computed_tokens
        if available_tokens <= 1 or num_new_tokens <= 0:
            return (0, False)

        # Remember where the local prefix ends so the external range can be
        # placed correctly within the request's blocks.
        self._local_computed_tokens[req_id] = num_computed_tokens
        return (num_new_tokens, False)

    @timing_wrapper
    def update_state_after_alloc(
        self,
        request: "Request",
        blocks: "KVCacheBlocks",
        num_external_tokens: int,
    ) -> None:
        """
        Update KVConnector state after block allocation.

        Records the request's block hashes and, when external tokens will be
        loaded, queues the request for build_connector_meta().

        Args:
            request (Request): the request object.
            blocks (KVCacheBlocks): the blocks allocated for the request.
            num_external_tokens (int): the number of tokens that will be
                loaded from the external KV cache.
        """
        req_id = request.request_id

        # Stored by reference: this is the request's own block_hashes list.
        self._request_block_hashes[req_id] = request.block_hashes

        num_local_tokens = self._local_computed_tokens.pop(req_id, 0)

        if num_external_tokens > 0:
            self._requests_to_load[req_id] = {
                'request': request,
                'blocks': blocks,
                'num_external_tokens': num_external_tokens,
                'num_local_computed_tokens': num_local_tokens,
            }

    @timing_wrapper
    def build_connector_meta(self, scheduler_output: "SchedulerOutput") -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        Does not modify ``scheduler_output``. Clears the pending-load state
        recorded by update_state_after_alloc().

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        # STEP 1/2: block hashes for every scheduled request (new, then cached)
        # that we track. Each stored list is the request's own ``block_hashes``
        # object, so it reflects hashes vLLM appended later -- assuming vLLM
        # extends that list in place.
        scheduled_req_ids = [req.req_id for req in scheduler_output.scheduled_new_reqs]
        scheduled_req_ids.extend(scheduler_output.scheduled_cached_reqs.req_ids)
        block_hashes = {
            req_id: self._request_block_hashes[req_id]
            for req_id in scheduled_req_ids
            if req_id in self._request_block_hashes
        }

        # STEP 3: requests that need KV loaded from the engine.
        new_reqs_by_id = {req.req_id: req for req in scheduler_output.scheduled_new_reqs}
        requests_to_load: Dict[str, Dict[str, Any]] = {}

        for req_id, load_info in self._requests_to_load.items():
            new_req = new_reqs_by_id.get(req_id)
            if new_req is None:
                logger.warning(
                    "[PegaKVConnector] request %s expected an external KV load but is "
                    "not a newly scheduled request; load skipped",
                    req_id,
                )
                continue
            entry = self._plan_request_load(req_id, new_req, load_info)
            if entry is not None:
                requests_to_load[req_id] = entry

        self._requests_to_load.clear()

        # STEP 4: build and return metadata.
        return PegaConnectorMetadata(
            block_hashes=block_hashes,
            requests_to_load=requests_to_load,
        )

    def _plan_request_load(
        self,
        req_id: str,
        new_req: Any,
        load_info: Dict[str, Any],
    ) -> Optional[Dict[str, Any]]:
        """Map a request's external token range onto block ids and hashes.

        Returns None when there is nothing to load or the request does not
        have enough allocated blocks for the range.
        """
        block_ids = list(new_req.block_ids[0]) if new_req.block_ids else []
        num_external_tokens = load_info['num_external_tokens']
        num_local_tokens = load_info.get('num_local_computed_tokens', 0)
        block_size = self._block_size

        # External tokens start right after the locally cached prefix, so the
        # blocks that prefix occupies are skipped; the end is ceil-divided.
        start_block = num_local_tokens // block_size
        end_block = (num_local_tokens + num_external_tokens + block_size - 1) // block_size

        saved_hashes = self._request_block_hashes.get(req_id, [])
        end_block = min(end_block, len(saved_hashes))

        if end_block <= start_block or len(block_ids) < end_block:
            return None

        return {
            'block_ids': block_ids[start_block:end_block],
            'block_hashes': list(saved_hashes[start_block:end_block]),
            'num_tokens': num_external_tokens,
        }

    def request_finished(self, request: "Request", *args: Any, **kwargs: Any):
        """Drop per-request scheduler state, then defer to the base class.

        Without this, the per-request dictionaries would grow for every
        request ever scheduled.
        """
        req_id = request.request_id
        self._request_block_hashes.pop(req_id, None)
        self._requests_to_load.pop(req_id, None)
        self._local_computed_tokens.pop(req_id, None)
        return super().request_finished(request, *args, **kwargs)

    # ==============================
    # Scheduler <-> worker lookup channel
    # ==============================

    def _start_lookup_server(self) -> None:
        """Bind the REP socket and start the background lookup thread."""
        if self._lookup_server_thread is not None:
            return

        if self._lookup_context is None:
            self._lookup_context = zmq.Context()
        self._lookup_stop_event.clear()

        sock = self._lookup_context.socket(zmq.REP)
        sock.setsockopt(zmq.LINGER, 0)
        # Bounded recv lets the server thread observe the stop event.
        sock.setsockopt(zmq.RCVTIMEO, self._LOOKUP_SERVER_POLL_MS)

        if self._lookup_endpoint.startswith("ipc://"):
            # Remove any existing socket file at this path (e.g. from a
            # previous run) so bind() succeeds.
            ipc_path = self._lookup_endpoint[len("ipc://"):]
            try:
                os.unlink(ipc_path)
            except (FileNotFoundError, PermissionError):
                pass

        try:
            sock.bind(self._lookup_endpoint)
        except Exception:
            sock.close(0)
            raise
        self._lookup_server_socket = sock

        thread = threading.Thread(target=self._lookup_server_loop, daemon=True)
        thread.start()
        self._lookup_server_thread = thread
        logger.info(
            "[PegaKVConnector] Lookup server started at %s",
            self._lookup_endpoint,
        )

    def _lookup_server_loop(self) -> None:
        """Answer scheduler lookups until the stop event is set.

        Each request is a msgpack list of block hashes; each reply is a
        msgpack int (contiguous hit count, 0 on any error). The socket is
        owned by this thread and closed here, since ZMQ sockets must not be
        used from several threads.
        """
        sock = self._lookup_server_socket
        assert sock is not None

        try:
            while not self._lookup_stop_event.is_set():
                try:
                    message = sock.recv()
                except zmq.error.Again:
                    continue  # poll timeout; re-check the stop event
                except zmq.error.ZMQError:
                    if self._lookup_stop_event.is_set():
                        break
                    continue

                try:
                    hit_blocks = self._count_available_block_prefix(msgpack.unpackb(message))
                except Exception:
                    hit_blocks = 0

                # A REP socket must answer each request before receiving again.
                try:
                    sock.send(msgpack.packb(hit_blocks))
                except zmq.error.ZMQError:
                    if self._lookup_stop_event.is_set():
                        break
        finally:
            try:
                sock.close(0)
            except Exception:
                pass

    def _count_available_block_prefix(self, block_hashes: List[bytes]) -> int:
        """Return the engine's contiguous hit count for *block_hashes* (0 on error)."""
        if not block_hashes:
            return 0

        try:
            response = self._send_engine_request('QUERY', {
                'block_hashes': block_hashes,
            })
            return response.get('hit_blocks', 0)
        except Exception:
            return 0

    def _ensure_lookup_client(self) -> None:
        """Create the scheduler's REQ socket to the lookup server on demand."""
        if self._lookup_client is not None:
            return
        if self._lookup_context is None:
            self._lookup_context = zmq.Context()
        client = self._lookup_context.socket(zmq.REQ)
        # Avoid hanging forever if the worker is not reachable.
        client.setsockopt(zmq.RCVTIMEO, 2000)
        client.setsockopt(zmq.SNDTIMEO, 2000)
        client.setsockopt(zmq.LINGER, 0)
        client.connect(self._lookup_endpoint)
        self._lookup_client = client

    def _close_lookup_client(self) -> None:
        """Close and forget the lookup REQ socket (idempotent)."""
        client, self._lookup_client = self._lookup_client, None
        if client is not None:
            try:
                client.close(0)
            except Exception:
                pass

    def _send_lookup_request(self, req_id: str, block_hashes: List[bytes]) -> int:
        """Query the worker for the contiguous cached prefix length (in blocks).

        Returns 0 on any transport or decoding error.
        """
        if not block_hashes:
            return 0

        try:
            self._ensure_lookup_client()
        except Exception:
            return 0

        payload = msgpack.packb(block_hashes)
        client = self._lookup_client

        lookup_start = time.perf_counter()
        try:
            client.send(payload)
            reply = client.recv()
        except zmq.error.ZMQError:
            # After a missed reply a REQ socket refuses to send again; drop it
            # so the next lookup starts with a fresh socket.
            self._close_lookup_client()
            return 0
        lookup_end = time.perf_counter()

        try:
            hit_blocks = int(msgpack.unpackb(reply))
        except Exception:
            return 0

        total_blocks = len(block_hashes)
        logger.info(
            "[PegaKVConnector] scheduler_lookup req=%s hit_blocks=%d/%d (%.1f%%) cost=%.0f us",
            req_id,
            hit_blocks,
            total_blocks,
            (hit_blocks / total_blocks * 100) if total_blocks > 0 else 0.0,
            (lookup_end - lookup_start) * 1e6,
        )
        return hit_blocks

    def _stop_lookup_server(self) -> None:
        """Stop the lookup thread and terminate the lookup ZMQ context."""
        if self._lookup_server_thread is None:
            return
        self._lookup_stop_event.set()
        # The thread polls the stop event and closes its own socket on exit.
        self._lookup_server_thread.join(timeout=1.0)
        self._lookup_server_thread = None
        self._lookup_server_socket = None

        # Context.term() blocks until every socket on the context is closed,
        # so the client sharing this context is closed first.
        self._close_lookup_client()
        if self._lookup_context is not None:
            try:
                self._lookup_context.term()
            except Exception:
                logger.debug("[PegaKVConnector] lookup context term failed", exc_info=True)
            self._lookup_context = None

    # ==============================
    # Registration and lifecycle
    # ==============================

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register every layer's paged KV tensor with the engine server.

        Each tensor is shared through a pickled :class:`CudaIPCWrapper`,
        together with the geometry the engine needs to address single blocks.

        Args:
            kv_caches: Dictionary mapping layer names to KV cache tensors.

        Raises:
            RuntimeError: If a tensor has a non-zero storage offset, a zero
                block stride, or the engine rejects the registration.
        """
        self._registered_layers = list(kv_caches.keys())
        layout = "unknown"

        for layer_name, kv_cache in kv_caches.items():
            if kv_cache.storage_offset() != 0:
                raise RuntimeError(
                    f"KV cache for {layer_name} must have zero storage offset"
                )

            # Create CUDA IPC wrapper for cross-process sharing.
            wrapper_bytes = pickle.dumps(CudaIPCWrapper(kv_cache))

            shape = tuple(kv_cache.shape)
            stride = tuple(kv_cache.stride())
            element_size = kv_cache.element_size()

            # Layout detection from the leading dimension:
            # - KV-first (2, num_blocks, ...): K and V are two segments
            #   stride[0] elements apart; one block spans stride[1] elements.
            # - blocks-first (num_blocks, ...): a single segment; one block
            #   spans stride[0] elements.
            # NOTE(review): a blocks-first cache with exactly 2 blocks would be
            # classified as KV-first by this check.
            if len(shape) >= 2 and shape[0] == 2:
                num_blocks = shape[1]
                bytes_per_block = stride[1] * element_size
                kv_stride_bytes = stride[0] * element_size
                segments = 2
                layout = "KV-first"
            else:
                num_blocks = shape[0]
                bytes_per_block = stride[0] * element_size
                kv_stride_bytes = 0
                segments = 1
                layout = "blocks-first"

            if bytes_per_block == 0:
                raise RuntimeError(
                    f"Invalid bytes_per_block for {layer_name}: stride={stride}"
                )

            try:
                self._send_engine_request('REGISTER_CONTEXT', {
                    'layer_name': layer_name,
                    'wrapper_bytes': wrapper_bytes,
                    'num_blocks': num_blocks,
                    'bytes_per_block': bytes_per_block,
                    'kv_stride_bytes': kv_stride_bytes,
                    'segments': segments,
                })
            except Exception as e:
                raise RuntimeError(
                    f"Failed to register layer {layer_name} with engine: {e}"
                ) from e

        # The reported layout is that of the last registered layer.
        logger.info(
            "[PegaKVConnector] Registered %d KV cache layers (%s layout)",
            len(kv_caches),
            layout,
        )

    def unregister_context(self) -> None:
        """Unregister the active inference context from the engine server.

        Best effort: failures are logged at DEBUG and the local layer list is
        cleared either way.
        """
        if not self._registered_layers:
            return

        try:
            self._send_engine_request('UNREGISTER_CONTEXT', {})
        except Exception as exc:
            logger.debug(
                "[PegaKVConnector] Failed to unregister context: %s",
                exc,
                exc_info=True,
            )
        finally:
            self._registered_layers.clear()

    def shutdown(self):
        """Unregister KV caches and release every socket, thread and context.

        The engine process itself is left running; only the connection to it
        is closed.
        """
        self.unregister_context()
        self._stop_lookup_server()
        self._close_lookup_client()

        # Take the lock so a request still in flight on another thread
        # finishes before its socket is closed.
        with self._engine_lock:
            self._reset_engine_socket()

        if self._engine_context is not None:
            try:
                self._engine_context.term()
            except Exception:
                pass
            self._engine_context = None

        if self._lookup_context is not None:
            try:
                self._lookup_context.term()
            except Exception:
                pass
            self._lookup_context = None
939
+
940
+
941
# Public API of this module; KVConnectorRole is re-exported from vLLM for
# callers that construct the connector.
__all__ = [
    "PegaKVConnector",
    "KVConnectorRole",
]