pegaflow-llm 0.0.2__cp310-cp310-manylinux_2_34_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pegaflow/__init__.py ADDED
@@ -0,0 +1,21 @@
+ """PegaFlow - High-performance key-value storage engine with Python bindings.
+
+ This package provides:
+ 1. PegaEngine: Rust-based high-performance KV storage (via PyO3)
+ 2. PegaKVConnector: vLLM KV connector for distributed inference
+ """
+
+ # Import Rust-based PegaEngine from the compiled extension
+ try:
+     from .pegaflow import EngineRpcClient, PegaEngine, PyLoadState
+ except ImportError:
+     # Fallback for development when the Rust extension is not built
+     EngineRpcClient = None
+     PegaEngine = None
+     PyLoadState = None
+
+ # Import Python-based vLLM connector
+ from .connector import PegaKVConnector, KVConnectorRole
+
+ __version__ = "0.0.1"
+ __all__ = ["PegaEngine", "EngineRpcClient", "PyLoadState", "PegaKVConnector", "KVConnectorRole"]
pegaflow/_server.py ADDED
@@ -0,0 +1,44 @@
+ #!/usr/bin/env python3
+ """
+ Wrapper script to launch pegaflow-server binary from the installed package.
+ """
+ import os
+ import subprocess
+ import sys
+ from pathlib import Path
+
+
+ def get_server_binary():
+     """Locate the pegaflow-server-py binary in the installed package."""
+     # The binary is in the same directory as this Python module
+     module_dir = Path(__file__).parent
+     binary_path = module_dir / "pegaflow-server-py"
+
+     if binary_path.exists() and binary_path.is_file():
+         return str(binary_path)
+
+     # Fallback: try to find in PATH
+     return "pegaflow-server-py"
+
+
+ def main():
+     """Launch pegaflow-server with command-line arguments."""
+     server_binary = get_server_binary()
+
+     try:
+         # Pass through all command-line arguments
+         result = subprocess.run(
+             [server_binary] + sys.argv[1:],
+             check=False,
+         )
+         sys.exit(result.returncode)
+     except FileNotFoundError:
+         print(f"Error: pegaflow-server-py binary not found at {server_binary}", file=sys.stderr)
+         print("Please ensure pegaflow is properly installed.", file=sys.stderr)
+         sys.exit(1)
+     except KeyboardInterrupt:
+         sys.exit(0)
+
+
+ if __name__ == "__main__":
+     main()
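 
The wrapper prefers the binary bundled next to the module and only falls back to a bare name for PATH lookup. A small sketch of checking which path will be used (nothing is launched; this only calls the helper defined above):

    from pegaflow._server import get_server_binary

    # Prints either <site-packages>/pegaflow/pegaflow-server-py (bundled binary)
    # or the bare "pegaflow-server-py" when PATH lookup is the fallback.
    print(get_server_binary())
 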
@@ -0,0 +1,235 @@
+ """
+ Facade for the PegaFlow vLLM connector, split into scheduler/worker implementations.
+ """
+ from __future__ import annotations
+
+ import os
+ import torch
+ from typing import Any, Iterable, Optional, Tuple
+
+ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+     KVConnectorBase_V1,
+     KVConnectorRole,
+ )
+ from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
+
+ from pegaflow.connector.common import (
+     ConnectorContext,
+     ENGINE_ENDPOINT,
+     PegaConnectorMetadata,
+     RequestPhase,
+     RequestTracker,
+     derive_namespace,
+     logger,
+     resolve_instance_id,
+ )
+ from pegaflow.connector.scheduler import SchedulerConnector
+ from pegaflow.connector.worker import WorkerConnector
+ from pegaflow.logging_utils import timing_wrapper
+ from pegaflow.pegaflow import EngineRpcClient
+
+
+ class PegaKVConnector(KVConnectorBase_V1):
+     """v1 KV connector for PegaFlow with separated scheduler/worker logic."""
+
+     def __init__(self, vllm_config, role: KVConnectorRole):
+         super().__init__(vllm_config, role)
+
+         instance_id = resolve_instance_id(vllm_config)
+         tp_size = vllm_config.parallel_config.tensor_parallel_size
+         namespace = derive_namespace(vllm_config, tp_size)
+         num_layers = getattr(vllm_config.model_config.hf_text_config,
+                              "num_hidden_layers", 0)
+         block_size = vllm_config.cache_config.block_size
+
+         tp_rank: Optional[int] = None
+         device_id: Optional[int] = None
+         if role == KVConnectorRole.WORKER:
+             tp_rank = get_tensor_model_parallel_rank()
+             if torch.cuda.is_available():
+                 device_id = _resolve_device_id()
+
+         self._engine_endpoint = ENGINE_ENDPOINT
+         engine_client = EngineRpcClient(self._engine_endpoint)
+         logger.info("[PegaKVConnector] Connected to engine server at %s",
+                     self._engine_endpoint)
+
+         self._ctx = ConnectorContext(
+             instance_id=instance_id,
+             namespace=namespace,
+             block_size=block_size,
+             num_layers=num_layers,
+             tp_size=tp_size,
+             tp_rank=tp_rank,
+             device_id=device_id,
+             engine_client=engine_client,
+         )
+
+         self._scheduler: SchedulerConnector | None = None
+         self._worker: WorkerConnector | None = None
+         if role == KVConnectorRole.SCHEDULER:
+             self._scheduler = SchedulerConnector(self._ctx)
+         else:
+             self._worker = WorkerConnector(self._ctx)
+
+         logger.info(
+             "[PegaKVConnector] Initialized role=%s instance_id=%s device=%s tp_rank=%s tp_size=%d layers=%d namespace=%s",
+             role.name,
+             instance_id,
+             device_id if device_id is not None else "cpu",
+             tp_rank if tp_rank is not None else "N/A",
+             tp_size,
+             num_layers,
+             namespace,
+         )
+
+     # ==============================
+     # Worker-side methods
+     # ==============================
+     @timing_wrapper
+     def start_load_kv(self, forward_context, **kwargs: Any) -> None:
+         if not self._worker:
+             return
+         metadata = self._get_connector_metadata()
+         if metadata is None:
+             return
+         self._worker.start_load_kv(metadata, forward_context, **kwargs)
+
+     def wait_for_layer_load(self, layer_name: str) -> None:
+         if not self._worker:
+             return
+         self._worker.wait_for_layer_load(layer_name)
+
+     def save_kv_layer(
+         self,
+         layer_name: str,
+         kv_layer: "torch.Tensor",
+         attn_metadata,
+         **kwargs: Any,
+     ) -> None:
+         if not self._worker:
+             return
+         metadata = self._get_connector_metadata()
+         if metadata is None:
+             return
+         self._worker.save_kv_layer(metadata, layer_name, kv_layer,
+                                    attn_metadata, **kwargs)
+
+     @timing_wrapper
+     def wait_for_save(self) -> None:
+         if not self._worker:
+             return
+         self._worker.wait_for_save()
+
+     def get_finished(
+         self, finished_req_ids: set[str]
+     ) -> tuple[set[str] | None, set[str] | None]:
+         if not self._worker:
+             return (None, None)
+         return self._worker.get_finished(finished_req_ids)
+
+     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+         if not self._worker:
+             return
+         self._worker.register_kv_caches(kv_caches)
+
+     def unregister_context(self) -> None:
+         if self._worker:
+             self._worker.unregister_context()
+
+     # ==============================
+     # Scheduler-side methods
+     # ==============================
+     def update_connector_output(self, connector_output) -> None:
+         if self._scheduler:
+             self._scheduler.update_connector_output(connector_output)
+
+     def request_finished(
+         self,
+         request,
+         block_ids: list[int],
+     ) -> tuple[bool, dict[str, Any] | None]:
+         if self._scheduler:
+             return self._scheduler.request_finished(request, block_ids)
+         return (False, None)
+
+     def take_events(self) -> Iterable:
+         return ()
+
+     @timing_wrapper
+     def get_num_new_matched_tokens(
+         self,
+         request,
+         num_computed_tokens: int,
+     ) -> Tuple[Optional[int], bool]:
+         if not self._scheduler:
+             return (0, False)
+         return self._scheduler.get_num_new_matched_tokens(
+             request, num_computed_tokens)
+
+     @timing_wrapper
+     def update_state_after_alloc(
+         self,
+         request,
+         blocks,
+         num_external_tokens: int,
+     ) -> None:
+         if self._scheduler:
+             self._scheduler.update_state_after_alloc(
+                 request, blocks, num_external_tokens)
+
+     @timing_wrapper
+     def build_connector_meta(self, scheduler_output) -> PegaConnectorMetadata:
+         if not self._scheduler:
+             return PegaConnectorMetadata()
+         return self._scheduler.build_connector_meta(scheduler_output)
+
+     # ==============================
+     # Defaults and shutdown
+     # ==============================
+     def get_block_ids_with_load_errors(self) -> set[int]:
+         return set()
+
+     def get_kv_connector_stats(self):
+         return None
+
+     def get_handshake_metadata(self):
+         return None
+
+     def set_host_xfer_buffer_ops(self, copy_operation):
+         return
+
+     def get_finished_count(self) -> int | None:
+         return None
+
+     def shutdown(self):
+         if self._worker:
+             self._worker.shutdown()
+
+
+ def _resolve_device_id() -> int:
+     """
+     Return the global CUDA device id even when CUDA_VISIBLE_DEVICES masks GPUs.
+
+     torch.cuda.current_device() returns the local index within the visible set,
+     but we need the actual global device ID for operations like CUDA IPC.
+     This function maps the local index back to the global device ID.
+     """
+     local_id = torch.cuda.current_device()
+     visible = os.environ.get("CUDA_VISIBLE_DEVICES")
+     if not visible:
+         return local_id
+
+     slots = [slot.strip() for slot in visible.split(",") if slot.strip()]
+     try:
+         mapped = slots[local_id]
+     except IndexError:
+         return local_id
+
+     try:
+         return int(mapped)
+     except ValueError:
+         return local_id
+
+
+ __all__ = ["PegaKVConnector", "KVConnectorRole"]
@@ -0,0 +1,308 @@
+ """
+ Shared types and helpers for the PegaFlow vLLM connector.
+ """
+
+ import enum
+ import hashlib
+ import os
+ import uuid
+ from dataclasses import dataclass, field
+
+ from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
+
+ from pegaflow.logging_utils import get_connector_logger
+ from pegaflow.pegaflow import EngineRpcClient
+
+ logger = get_connector_logger()
+
+ # Engine server endpoint (gRPC URL)
+ ENGINE_ENDPOINT = os.environ.get("PEGAFLOW_ENGINE_ENDPOINT",
+                                  "http://127.0.0.1:50055")
+
+
+ @dataclass(frozen=True)
+ class ConnectorContext:
+     """Shared configuration for scheduler/worker connectors."""
+     instance_id: str
+     namespace: str
+     block_size: int
+     num_layers: int
+     tp_size: int
+     tp_rank: int | None
+     device_id: int | None
+     engine_client: EngineRpcClient
+
+
+ class RequestPhase(enum.Enum):
+     """Lifecycle phase of a request in the KV connector."""
+     LOOKUP = "lookup"      # Waiting for lookup result from external storage
+     LOADING = "loading"    # Need to load KV from external storage
+     ACTIVE = "active"      # Actively generating (may be saving concurrently)
+     DRAINING = "draining"  # Generation done, waiting for async save to complete
+     DONE = "done"          # Fully completed
+
+
+ @dataclass(frozen=True)
+ class LoadIntent:
+     """Intent for a KV load operation."""
+     block_ids: tuple[int, ...]
+     block_hashes: tuple[bytes, ...]
+     num_tokens: int
+
+
+ @dataclass(slots=True)
+ class LoadState:
+     """
+     Mutable state for an in-progress load operation.
+
+     Lifecycle:
+     - Created by on_lookup() when cache hit is detected
+     - Updated by on_alloc() with allocated block IDs
+     - Consumed by consume_load_intent() which returns LoadIntent and clears state
+     """
+     hit_blocks: int
+     computed_blocks: int
+     allocated_blocks: list[int] = field(default_factory=list)
+     external_tokens: int = 0
+
+     def to_intent(
+         self,
+         block_hashes: tuple[bytes, ...],
+         block_size: int,
+     ) -> LoadIntent | None:
+         """
+         Convert to LoadIntent if conditions are met.
+
+         Returns None if:
+         - No external tokens to load
+         - All hits are already computed (prefix cache)
+         - No blocks to load after accounting for computed blocks
+         """
+         if self.external_tokens <= 0 or self.hit_blocks <= self.computed_blocks:
+             return None
+
+         num_blocks = min(self.hit_blocks, len(self.allocated_blocks), len(block_hashes))
+         load_blocks = num_blocks - self.computed_blocks
+         if load_blocks <= 0:
+             return None
+
+         start = self.computed_blocks
+         return LoadIntent(
+             block_ids=tuple(self.allocated_blocks[start:start + load_blocks]),
+             block_hashes=block_hashes[start:start + load_blocks],
+             num_tokens=load_blocks * block_size,
+         )
+
+
+ @dataclass(frozen=True)
+ class SaveIntent:
+     """Intent for a KV save operation."""
+     block_ids: tuple[int, ...]
+     block_hashes: tuple[bytes, ...]
+
+
+ class RequestTracker:
+     """
+     Tracks the KV cache state for a single request.
+
+     Load state lifecycle:
+     - on_lookup() creates LoadState (or None if no hit)
+     - on_alloc() updates LoadState with allocated blocks
+     - consume_load_intent() returns LoadIntent and clears LoadState
+     - Preemption: next on_lookup() replaces stale LoadState
+     """
+
+     __slots__ = (
+         'request_id',
+         '_block_hashes',
+         '_block_size',
+         '_load',
+         '_allocated_blocks',
+         '_scheduled_tokens',
+         '_stored_blocks',
+         '_total_layers',
+         '_saved_layers',
+         '_finished',
+     )
+
+     def __init__(
+         self,
+         request_id: str,
+         block_hashes: list[bytes],
+         block_size: int,
+         num_layers: int,
+     ):
+         self.request_id = request_id
+         self._block_hashes = tuple(block_hashes)
+         self._block_size = block_size
+         self._load: LoadState | None = None
+         self._allocated_blocks: list[int] = []
+         self._scheduled_tokens: int = 0
+         self._stored_blocks: int = 0
+         self._total_layers = num_layers
+         self._saved_layers: int = 0
+         self._finished: bool = False
+
+     @property
+     def phase(self) -> RequestPhase:
+         if self._load is not None:
+             return RequestPhase.LOADING
+         if not self._finished:
+             return RequestPhase.ACTIVE
+         if self._saved_layers < self._total_layers:
+             return RequestPhase.DRAINING
+         return RequestPhase.DONE
+
+     @property
+     def num_blocks(self) -> int:
+         return len(self._block_hashes)
+
+     def on_lookup(self, hit_blocks: int, computed_blocks: int) -> None:
+         """New lookup = fresh load state. Handles preemption implicitly."""
+         self._load = (
+             LoadState(hit_blocks=hit_blocks, computed_blocks=computed_blocks)
+             if hit_blocks > computed_blocks
+             else None
+         )
+         self._allocated_blocks = []
+
+     def on_alloc(self, block_ids: list[int], num_external_tokens: int) -> None:
+         self._allocated_blocks.extend(block_ids)
+         if self._load is not None:
+             self._load.allocated_blocks.extend(block_ids)
+             if num_external_tokens > 0:
+                 self._load.external_tokens = num_external_tokens
+
+     def consume_load_intent(self) -> LoadIntent | None:
+         load, self._load = self._load, None
+         if load is None:
+             return None
+         return load.to_intent(self._block_hashes, self._block_size)
+
+     def on_scheduled(self, num_tokens: int) -> None:
+         self._scheduled_tokens += num_tokens
+
+     def on_layer_saved(self) -> None:
+         self._saved_layers += 1
+
+     def on_finished(self) -> None:
+         self._finished = True
+
+     def consume_save_intent(self) -> SaveIntent | None:
+         saveable = min(
+             len(self._block_hashes),
+             len(self._allocated_blocks),
+             self._scheduled_tokens // self._block_size,
+         )
+         new_blocks = saveable - self._stored_blocks
+         if new_blocks <= 0:
+             return None
+
+         start = self._stored_blocks
+         self._stored_blocks = start + new_blocks
+         return SaveIntent(
+             block_ids=tuple(self._allocated_blocks[start:self._stored_blocks]),
+             block_hashes=self._block_hashes[start:self._stored_blocks],
+         )
+
+     def should_hold_blocks(self) -> bool:
+         return (self._finished and self._stored_blocks > 0 and
+                 self._saved_layers < self._total_layers)
+
+     def is_done(self) -> bool:
+         return self.phase == RequestPhase.DONE
+
+     def __repr__(self) -> str:
+         return (
+             f"RequestTracker({self.request_id}, {self.phase.value}, "
+             f"load={self._load}, alloc={len(self._allocated_blocks)}, "
+             f"stored={self._stored_blocks}, saved={self._saved_layers}/{self._total_layers})"
+         )
+
+
+ class PegaConnectorMetadata(KVConnectorMetadata):
+     """Metadata passed from scheduler to worker for KV cache operations."""
+
+     def __init__(
+         self,
+         load_intents: dict[str, LoadIntent] | None = None,
+         save_intents: dict[str, SaveIntent] | None = None,
+     ):
+         super().__init__()
+         # Maps request_id -> intent
+         self.load_intents: dict[str, LoadIntent] = load_intents or {}
+         self.save_intents: dict[str, SaveIntent] = save_intents or {}
+
+     def __repr__(self) -> str:
+         return (f"PegaConnectorMetadata(loads={len(self.load_intents)}, "
+                 f"saves={len(self.save_intents)})")
+
+
+ def resolve_instance_id(vllm_config, dp_rank_suffix: bool = True) -> str:
+     """Resolve or generate connector instance_id with optional DP rank suffix."""
+     instance_id = vllm_config.kv_transfer_config.engine_id
+     if instance_id:
+         logger.info(
+             "[PegaKVConnector] Using kv_transfer_config.engine_id: %s",
+             instance_id)
+         return instance_id
+
+     instance_id = vllm_config.instance_id or os.environ.get(
+         "PEGAFLOW_INSTANCE_ID", "")
+     if not instance_id:
+         instance_id = uuid.uuid4().hex
+         logger.info(
+             "[PegaKVConnector] No instance_id from vLLM; generated fallback %s",
+             instance_id)
+
+     if dp_rank_suffix:
+         parallel_config = vllm_config.parallel_config
+         if parallel_config.data_parallel_size > 1:
+             local_dp_rank = parallel_config.data_parallel_rank_local
+             if local_dp_rank is not None:
+                 instance_id = f"{instance_id}_dp{local_dp_rank}"
+                 logger.info(
+                     "[PegaKVConnector] Appended DP rank to instance_id: %s (dp_size=%d, local_dp_rank=%d)",
+                     instance_id,
+                     parallel_config.data_parallel_size,
+                     local_dp_rank,
+                 )
+
+     return instance_id
+
+
+ def derive_namespace(vllm_config, tp_size: int) -> str:
+     """
+     Derive namespace for storage isolation.
+     """
+     model_config = vllm_config.model_config
+     cache_config = vllm_config.cache_config
+
+     factors = {
+         "model": model_config.model,
+         "dtype": str(model_config.dtype),
+         "tp_size": tp_size,
+         "num_kv_heads": model_config.get_total_num_kv_heads(),
+         "head_size": model_config.get_head_size(),
+         "num_hidden_layers": model_config.get_total_num_hidden_layers(),
+         "cache_dtype": str(cache_config.cache_dtype),
+     }
+
+     factor_str = str(sorted(factors.items()))
+     hash_suffix = hashlib.sha256(factor_str.encode()).hexdigest()[:8]
+     return hash_suffix
+
+
+ __all__ = [
+     "ConnectorContext",
+     "ENGINE_ENDPOINT",
+     "LoadIntent",
+     "LoadState",
+     "PegaConnectorMetadata",
+     "RequestPhase",
+     "RequestTracker",
+     "SaveIntent",
+     "derive_namespace",
+     "logger",
+     "resolve_instance_id",
+ ]
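 
To make the intent/lifecycle plumbing in RequestTracker concrete, here is a sketch of one request moving through it with made-up hashes and sizes (two 16-token blocks, two layers). It assumes vLLM and the compiled extension are importable, since this module imports both:

    from pegaflow.connector.common import RequestTracker, RequestPhase

    tracker = RequestTracker(
        request_id="req-0",                         # illustrative id
        block_hashes=[b"\x00" * 32, b"\x01" * 32],  # made-up content hashes
        block_size=16,
        num_layers=2,
    )

    # Scheduler side: external lookup hit block 0, prefix cache computed nothing.
    tracker.on_lookup(hit_blocks=1, computed_blocks=0)
    tracker.on_alloc(block_ids=[10, 11], num_external_tokens=16)
    load = tracker.consume_load_intent()
    assert load is not None and load.block_ids == (10,) and load.num_tokens == 16

    # Both blocks get scheduled, so the next save intent covers blocks 10 and 11.
    tracker.on_scheduled(num_tokens=32)
    save = tracker.consume_save_intent()
    assert save is not None and save.block_ids == (10, 11)

    # Generation finishes, but blocks are held until every layer's save lands.
    tracker.on_finished()
    assert tracker.phase is RequestPhase.DRAINING and tracker.should_hold_blocks()
    tracker.on_layer_saved()
    tracker.on_layer_saved()
    assert tracker.is_done()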