pegaflow-llm 0.0.2__cp310-cp310-manylinux_2_34_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pegaflow/__init__.py +21 -0
- pegaflow/_server.py +44 -0
- pegaflow/connector/__init__.py +235 -0
- pegaflow/connector/common.py +308 -0
- pegaflow/connector/scheduler.py +225 -0
- pegaflow/connector/worker.py +451 -0
- pegaflow/connector copy.py +941 -0
- pegaflow/ipc_wrapper.py +183 -0
- pegaflow/logging_utils.py +61 -0
- pegaflow/pegaflow-server-py +0 -0
- pegaflow/pegaflow.cpython-310-x86_64-linux-gnu.so +0 -0
- pegaflow_llm-0.0.2.dist-info/METADATA +100 -0
- pegaflow_llm-0.0.2.dist-info/RECORD +15 -0
- pegaflow_llm-0.0.2.dist-info/WHEEL +4 -0
- pegaflow_llm-0.0.2.dist-info/entry_points.txt +2 -0
pegaflow/connector/scheduler.py
@@ -0,0 +1,225 @@
+"""
+Scheduler-side connector logic.
+"""
+
+import time
+from collections.abc import Iterable
+from typing import TYPE_CHECKING
+
+from pegaflow.connector.common import (
+    ConnectorContext,
+    LoadIntent,
+    PegaConnectorMetadata,
+    RequestTracker,
+    SaveIntent,
+    logger,
+)
+from pegaflow.logging_utils import timing_wrapper
+
+if TYPE_CHECKING:
+    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.outputs import KVConnectorOutput
+    from vllm.v1.request import Request
+
+
+class SchedulerConnector:
+    """Holds scheduler-only state and behaviors."""
+
+    def __init__(self, context: ConnectorContext):
+        self._ctx = context
+        self._trackers: dict[str, RequestTracker] = {}
+        self._pending_load_intents: dict[str, LoadIntent] = {}
+        self._held_requests: set[str] = set()
+
+    @timing_wrapper
+    def get_num_new_matched_tokens(
+        self,
+        request: "Request",
+        num_computed_tokens: int,
+    ) -> tuple[int | None, bool]:
+        req_id = request.request_id
+        num_tokens = request.num_tokens
+        block_hashes = request.block_hashes
+
+        tracker = self._get_or_create_tracker(request)
+
+        lookup_start = time.perf_counter()
+        hit_blocks = self._count_available_block_prefix(block_hashes)
+        lookup_end = time.perf_counter()
+        elapsed_us = (lookup_end - lookup_start) * 1e6
+
+        computed_blocks = num_computed_tokens // self._ctx.block_size
+
+        tracker.on_lookup(hit_blocks, computed_blocks)
+
+        num_hit_tokens = hit_blocks * self._ctx.block_size - num_computed_tokens
+        if num_hit_tokens <= 0:
+            return (0, False)
+
+        if num_hit_tokens >= num_tokens:
+            num_hit_tokens = num_tokens - 1
+
+        need_to_compute_tokens = num_tokens - num_hit_tokens
+
+        logger.info(
+            "[PegaKVConnector] hit_blocks=%d computed_blocks=%d need_to_compute_tokens=%d "
+            "hit_tokens=%d elapsed_us=%.0f for request %s",
+            hit_blocks,
+            computed_blocks,
+            need_to_compute_tokens,
+            num_hit_tokens,
+            elapsed_us,
+            req_id,
+        )
+
+        return (num_hit_tokens, True)
+
+    @timing_wrapper
+    def update_state_after_alloc(
+        self,
+        request: "Request",
+        blocks: "KVCacheBlocks",
+        num_external_tokens: int,
+    ) -> None:
+        req_id = request.request_id
+        tracker = self._trackers.get(req_id)
+        if tracker is None:
+            logger.warning(
+                "[PegaKVConnector] No tracker for request %s in update_state_after_alloc",
+                req_id)
+            return
+
+        block_ids = list(blocks.get_block_ids()[0]) if blocks else []
+        tracker.on_alloc(block_ids, num_external_tokens)
+
+        # Always consume to clear _load state, avoiding stale state on preemption
+        load_intent = tracker.consume_load_intent()
+        if load_intent is not None:
+            self._pending_load_intents[req_id] = load_intent
+            logger.debug(
+                "[PegaKVConnector] update_state_after_alloc req=%s created LoadIntent: "
+                "%d blocks, %d tokens",
+                req_id,
+                len(load_intent.block_ids),
+                load_intent.num_tokens,
+            )
+
+        logger.debug(
+            "[PegaKVConnector] update_state_after_alloc req=%s blocks=%d external_tokens=%d phase=%s",
+            req_id,
+            len(block_ids),
+            num_external_tokens,
+            tracker.phase.value,
+        )
+
+    @timing_wrapper
+    def build_connector_meta(
+            self, scheduler_output: "SchedulerOutput") -> PegaConnectorMetadata:
+        save_intents: dict[str, SaveIntent] = {}
+
+        load_intents = self._pending_load_intents
+        self._pending_load_intents = {}
+
+        for req in scheduler_output.scheduled_new_reqs:
+            req_id = req.req_id
+            tracker = self._trackers.get(req_id)
+            if tracker is None:
+                continue
+
+            num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
+            tracker.on_scheduled(num_tokens)
+
+            if req_id not in load_intents:
+                if load_intent := tracker.consume_load_intent():
+                    load_intents[req_id] = load_intent
+
+            if save_intent := tracker.consume_save_intent():
+                save_intents[req_id] = save_intent
+
+        cached_reqs = scheduler_output.scheduled_cached_reqs
+        for idx, req_id in enumerate(cached_reqs.req_ids):
+            tracker = self._trackers.get(req_id)
+            if tracker is None:
+                continue
+
+            num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
+            tracker.on_scheduled(num_tokens)
+
+            new_block_ids = cached_reqs.new_block_ids[idx]
+            if new_block_ids:
+                tracker.on_alloc(list(new_block_ids[0]), 0)
+
+            if save_intent := tracker.consume_save_intent():
+                save_intents[req_id] = save_intent
+
+        logger.debug(
+            "[PegaKVConnector] build_connector_meta: %d loads, %d saves",
+            len(load_intents),
+            len(save_intents),
+        )
+
+        return PegaConnectorMetadata(
+            load_intents=load_intents,
+            save_intents=save_intents,
+        )
+
+    def update_connector_output(self,
+                                connector_output: "KVConnectorOutput") -> None:
+        for req_id in connector_output.finished_sending or []:
+            tracker = self._trackers.get(req_id)
+            if tracker:
+                while tracker._saved_layers < tracker._total_layers:
+                    tracker.on_layer_saved()
+                logger.debug(
+                    "[PegaKVConnector] Request %s save completed, phase=%s",
+                    req_id,
+                    tracker.phase.value,
+                )
+
+                if tracker.is_done():
+                    del self._trackers[req_id]
+                    logger.debug(
+                        "[PegaKVConnector] Cleaned up tracker for %s", req_id)
+
+    def request_finished(
+        self,
+        request: "Request",
+        block_ids: list[int],
+    ) -> tuple[bool, dict | None]:
+        req_id = request.request_id
+        tracker = self._trackers.get(req_id)
+
+        if tracker:
+            tracker.on_finished()
+
+            if tracker.should_hold_blocks():
+                self._held_requests.add(req_id)
+                logger.debug(
+                    "[PegaKVConnector] Request %s blocks held for async save",
+                    req_id,
+                )
+                return (True, None)
+
+        return (False, None)
+
+    def _get_or_create_tracker(self, request: "Request") -> RequestTracker:
+        req_id = request.request_id
+        if req_id not in self._trackers:
+            self._trackers[req_id] = RequestTracker(
+                request_id=req_id,
+                block_hashes=list(request.block_hashes),
+                block_size=self._ctx.block_size,
+                num_layers=self._ctx.num_layers,
+            )
+        return self._trackers[req_id]
+
+    def _count_available_block_prefix(self, block_hashes: Iterable[bytes]) -> int:
+        ok, message, hit_blocks = self._ctx.engine_client.query(
+            self._ctx.instance_id, list(block_hashes))
+        if not ok:
+            raise RuntimeError(f"Query failed: {message}")
+        return hit_blocks
+
+
+__all__ = ["SchedulerConnector"]
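The scheduler-side connector above is driven once per scheduling step. The following sketch is not part of the wheel; it only illustrates the call order the methods imply. The run_scheduler_step helper and its request, num_computed_tokens, blocks, and scheduler_output arguments are hypothetical stand-ins for objects the host engine (vLLM v1 here) would supply; only SchedulerConnector and its methods come from the package.

# Hypothetical sketch (assumptions noted above); not shipped in the wheel.
from pegaflow.connector.scheduler import SchedulerConnector

def run_scheduler_step(conn: SchedulerConnector, request, num_computed_tokens,
                       blocks, scheduler_output):
    # 1. Ask the external KV store how many prefix tokens it can supply beyond
    #    what the local prefix cache has already computed.
    hit_tokens, load_async = conn.get_num_new_matched_tokens(
        request, num_computed_tokens)

    # 2. After KV blocks are allocated, record their ids so a LoadIntent can be
    #    handed to the worker side.
    conn.update_state_after_alloc(request, blocks, hit_tokens or 0)

    # 3. Bundle the pending load/save intents into the metadata shipped to workers.
    return conn.build_connector_meta(scheduler_output)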
pegaflow/connector/worker.py
@@ -0,0 +1,451 @@
+"""
+Worker-side connector logic.
+"""
+
+import pickle
+import queue
+import threading
+import time
+from collections.abc import Iterable
+from dataclasses import dataclass
+from typing import Any, TYPE_CHECKING
+
+import torch
+
+from pegaflow.connector.common import ConnectorContext, PegaConnectorMetadata, logger
+from pegaflow.logging_utils import timing_wrapper
+from pegaflow.ipc_wrapper import CudaIPCWrapper
+from pegaflow.pegaflow import PyLoadState
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionMetadata
+    from vllm.forward_context import ForwardContext
+
+
+@dataclass
+class SaveTask:
+    layer_name: str
+    attn_metadata: "AttentionMetadata"
+    metadata: PegaConnectorMetadata
+    request_ids: list[str]
+
+
+class WorkerConnector:
+    """Holds worker-only state and behaviors."""
+
+    def __init__(self, context: ConnectorContext):
+        self._ctx = context
+
+        self._save_queue = queue.Queue()
+        self._save_thread = threading.Thread(target=self._save_worker,
+                                             daemon=True,
+                                             name="PegaSaveWorker")
+        self._save_thread.start()
+
+        self._req_pending_layers: dict[str, int] = {}
+        self._completed_saves: set[str] = set()
+        self._save_completion_lock = threading.Lock()
+
+        self._current_save_intents: set[str] = set()
+
+        self._pending_loads: dict[str, PyLoadState] = {}
+        self._pending_load_reqs: dict[str, set[str]] = {}
+        self._load_completion_lock = threading.Lock()
+
+        self._registered_layers: list[str] = []
+        self._layer_name_to_id: dict[str, int] = {}
+
+        self._finished_requests: set[str] = set()
+
+    def shutdown(self) -> None:
+        self.unregister_context()
+        self._save_queue.put(None)
+        self._save_thread.join()
+
+    def unregister_context(self) -> None:
+        if not self._registered_layers:
+            return
+
+        if self._ctx.tp_rank == 0:
+            ok, message = self._ctx.engine_client.unregister_context(
+                self._ctx.instance_id)
+            if not ok:
+                logger.warning(
+                    "[PegaKVConnector] Unregister context failed: %s", message)
+
+        self._registered_layers.clear()
+
+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        assert self._ctx.device_id is not None, "CUDA device id is unknown; cannot register KV caches"
+
+        self._registered_layers = list(kv_caches.keys())
+
+        self._layer_name_to_id.clear()
+        for layer_id, layer_name in enumerate(kv_caches.keys()):
+            self._layer_name_to_id[layer_name] = layer_id
+
+        layout = "unknown"
+        for layer_name, kv_cache in kv_caches.items():
+            assert kv_cache.storage_offset(
+            ) == 0, f"KV cache for {layer_name} must have zero storage offset"
+
+            wrapper = CudaIPCWrapper(kv_cache)
+            wrapper_bytes = pickle.dumps(wrapper)
+
+            shape = tuple(kv_cache.shape)
+            stride = tuple(kv_cache.stride())
+            element_size = kv_cache.element_size()
+
+            if len(shape) >= 2 and shape[0] == 2:
+                num_blocks = shape[1]
+                bytes_per_block = stride[1] * element_size
+                kv_stride_bytes = stride[0] * element_size
+                segments = 2
+                layout = "KV-first"
+            else:
+                num_blocks = shape[0]
+                bytes_per_block = stride[0] * element_size
+                kv_stride_bytes = 0
+                segments = 1
+                layout = "blocks-first"
+
+            assert bytes_per_block != 0, f"Invalid bytes_per_block for {layer_name}: stride={stride}"
+
+            ok, message = self._ctx.engine_client.register_context(
+                self._ctx.instance_id,
+                self._ctx.namespace,
+                self._ctx.tp_rank,
+                self._ctx.tp_size,
+                self._ctx.device_id,
+                self._ctx.num_layers,
+                layer_name,
+                wrapper_bytes,
+                num_blocks,
+                bytes_per_block,
+                kv_stride_bytes,
+                segments,
+            )
+
+            if not ok:
+                raise RuntimeError(
+                    f"Register context failed for {layer_name}: {message}")
+
+        logger.info(
+            "[PegaKVConnector] Registered %d KV cache layers (%s layout) instance=%s",
+            len(kv_caches),
+            layout,
+            self._ctx.instance_id,
+        )
+
+    def get_finished(
+        self, finished_req_ids: set[str]
+    ) -> tuple[set[str] | None, set[str] | None]:
+        finished_sending: set[str] | None = None
+        finished_recving: set[str] | None = None
+
+        with self._save_completion_lock:
+            # 1. Add newly finished requests (if they have pending saves) to tracking
+            self._finished_requests.update(finished_req_ids & self._req_pending_layers.keys())
+            # 2. Identify requests whose saves have completed
+            done_saves = self._completed_saves & self._finished_requests
+            done_saves.update(self._completed_saves & finished_req_ids)
+
+            if done_saves:
+                # 3. Clean up completed requests
+                self._completed_saves -= done_saves
+                self._finished_requests -= done_saves
+                finished_sending = done_saves
+
+
+        with self._load_completion_lock:
+            completed_reqs: set[str] = set()
+            completed_shms: list[str] = []
+
+            for shm_name, req_ids in self._pending_load_reqs.items():
+                sample_req_id = next(iter(req_ids))
+                load_state = self._pending_loads.get(sample_req_id)
+                if load_state is None:
+                    continue
+
+                if load_state.is_ready():
+                    state = load_state.get_state()
+                    if state < 0:
+                        logger.error(
+                            "[PegaKVConnector] async load failed with state=%d for reqs=%s",
+                            state,
+                            req_ids,
+                        )
+                    else:
+                        logger.debug(
+                            "[PegaKVConnector] async load completed for %d reqs, shm=%s",
+                            len(req_ids),
+                            shm_name,
+                        )
+                    completed_reqs.update(req_ids)
+                    completed_shms.append(shm_name)
+
+            for shm_name in completed_shms:
+                req_ids = self._pending_load_reqs.pop(shm_name, set())
+                for req_id in req_ids:
+                    self._pending_loads.pop(req_id, None)
+
+            if completed_reqs:
+                finished_recving = completed_reqs
+
+        if finished_sending:
+            logger.debug(
+                "[PegaKVConnector] finished saving KV for requests: %s",
+                finished_sending,
+            )
+        if finished_recving:
+            logger.debug(
+                "[PegaKVConnector] finished loading KV for requests: %s",
+                finished_recving,
+            )
+        return (finished_sending, finished_recving)
+
+    @timing_wrapper
+    def start_load_kv(self, metadata: PegaConnectorMetadata,
+                      forward_context: "ForwardContext",
+                      **kwargs: Any) -> None:
+        self._current_save_intents = set(metadata.save_intents.keys())
+
+        if not metadata.load_intents:
+            return
+
+        total_requests = len(metadata.load_intents)
+        load_start = time.perf_counter()
+
+        all_block_ids: list[int] = []
+        all_block_hashes: list[bytes] = []
+        request_ids: list[str] = []
+
+        for req_id, load_intent in metadata.load_intents.items():
+            all_block_ids.extend(load_intent.block_ids)
+            all_block_hashes.extend(load_intent.block_hashes)
+            request_ids.append(req_id)
+
+        if not all_block_ids:
+            return
+
+        target_layers: list[str] = []
+        for layer_name, layer in forward_context.no_compile_layers.items():
+            if hasattr(layer, 'kv_cache'):
+                target_layers.append(layer_name)
+
+        if not target_layers:
+            return
+
+        load_state = PyLoadState()
+        shm_name = load_state.shm_name()
+
+        ok, message = self._ctx.engine_client.load(
+            self._ctx.instance_id,
+            self._ctx.tp_rank,
+            self._ctx.device_id,
+            shm_name,
+            target_layers,
+            all_block_ids,
+            all_block_hashes,
+        )
+
+        if not ok:
+            raise RuntimeError(f"Load request failed: {message}")
+
+        num_layers = len(target_layers)
+        num_blocks = len(all_block_ids)
+
+        schedule_end = time.perf_counter()
+        schedule_time_us = (schedule_end - load_start) * 1e6
+
+        with self._load_completion_lock:
+            for req_id in request_ids:
+                self._pending_loads[req_id] = load_state
+            self._pending_load_reqs[shm_name] = set(request_ids)
+
+        logger.debug(
+            "[PegaKVConnector] started async load: %d blocks across %d layers for %d reqs, "
+            "schedule %.0f us, shm=%s",
+            num_blocks,
+            num_layers,
+            total_requests,
+            schedule_time_us,
+            shm_name,
+        )
+
+    def wait_for_layer_load(self, layer_name: str) -> None:
+        pass
+
+    def save_kv_layer(
+        self,
+        metadata: PegaConnectorMetadata,
+        layer_name: str,
+        kv_layer: "torch.Tensor",
+        attn_metadata: "AttentionMetadata",
+        **kwargs: Any,
+    ) -> None:
+        request_ids = list(metadata.save_intents.keys())
+        if not request_ids:
+            return
+
+        with self._save_completion_lock:
+            for req_id in request_ids:
+                if req_id not in self._req_pending_layers:
+                    self._req_pending_layers[req_id] = len(
+                        self._registered_layers)
+
+        self._save_queue.put(SaveTask(
+            layer_name=layer_name,
+            attn_metadata=attn_metadata,
+            metadata=metadata,
+            request_ids=request_ids,
+        ))
+
+    @timing_wrapper
+    def wait_for_save(self) -> None:
+        skipped_requests: set[str] = set()
+
+        with self._save_completion_lock:
+            pending_layers = set(self._req_pending_layers.keys())
+            skipped_requests = self._current_save_intents - pending_layers
+            if skipped_requests:
+                self._completed_saves.update(skipped_requests)
+
+        self._current_save_intents = set()
+
+        pending_reqs = len(self._req_pending_layers)
+        if pending_reqs > 0:
+            logger.debug(
+                "[PegaKVConnector] %d requests still have pending layer saves",
+                pending_reqs,
+            )
+
+        if skipped_requests:
+            logger.debug(
+                "[PegaKVConnector] Detected %d skipped saves (CUDA graph): %s",
+                len(skipped_requests),
+                skipped_requests,
+            )
+            self._handle_save_completion(skipped_requests,
+                                         reason="CUDA graph skip")
+
+    def _save_worker(self) -> None:
+        logger.info("[PegaKVConnector] Save worker thread started")
+
+        while True:
+            task = self._save_queue.get()
+            if task is None:
+                self._save_queue.task_done()
+                break
+
+            batch: list[SaveTask] = [task]
+            while True:
+                try:
+                    t = self._save_queue.get_nowait()
+                    if t is None:
+                        self._process_save_batch(batch)
+                        self._save_queue.task_done()
+                        logger.info("[PegaKVConnector] Save worker thread stopped")
+                        return
+                    batch.append(t)
+                except queue.Empty:
+                    break
+
+            self._process_save_batch(batch)
+            for _ in batch:
+                self._save_queue.task_done()
+
+        logger.info("[PegaKVConnector] Save worker thread stopped")
+
+    def _process_save_batch(self, batch: list[SaveTask]) -> None:
+        saves_by_layer: dict[str, tuple[list[int], list[bytes]]] = {}
+        all_request_ids: list[str] = []
+
+        for task in batch:
+            all_request_ids.extend(task.request_ids)
+
+            if task.attn_metadata.block_table is None:
+                continue
+
+            for save_intent in task.metadata.save_intents.values():
+                if not save_intent.block_ids:
+                    continue
+
+                if task.layer_name not in saves_by_layer:
+                    saves_by_layer[task.layer_name] = ([], [])
+
+                saves_by_layer[task.layer_name][0].extend(save_intent.block_ids)
+                saves_by_layer[task.layer_name][1].extend(save_intent.block_hashes)
+
+        if saves_by_layer:
+            # Ensure all GPU kernels have completed before reading KV cache
+            # Otherwise we may copy uninitialized memory (attention kernel is async)
+            torch.cuda.synchronize()
+
+            saves_list = [(name, ids, hashes)
+                          for name, (ids, hashes) in saves_by_layer.items()]
+
+            try:
+                ok, message = self._ctx.engine_client.save(
+                    self._ctx.instance_id,
+                    self._ctx.tp_rank,
+                    self._ctx.device_id,
+                    saves_list,
+                )
+
+                if not ok:
+                    logger.error(
+                        "[PegaKVConnector] Save batch failed: %s (continuing without save)",
+                        message,
+                    )
+                else:
+                    logger.debug(
+                        "[PegaKVConnector] Batch saved %d layers, %d total blocks",
+                        len(saves_list),
+                        sum(len(ids) for _, ids, _ in saves_list),
+                    )
+            except Exception as e:
+                logger.error(
+                    "[PegaKVConnector] Save RPC exception: %s (continuing without save)",
+                    e,
+                )
+
+        # Always decrement layer counter to release blocks, even if save failed
+        self._decrement_layer_counter(all_request_ids)
+
+    def _decrement_layer_counter(self, request_ids: list[str]) -> None:
+        completed_reqs: list[str] = []
+
+        with self._save_completion_lock:
+            for req_id in request_ids:
+                if req_id in self._req_pending_layers:
+                    self._req_pending_layers[req_id] -= 1
+                    assert self._req_pending_layers[req_id] >= 0, \
+                        f"Layer count mismatch for request {req_id}: counter went negative"
+
+                    if self._req_pending_layers[req_id] == 0:
+                        self._completed_saves.add(req_id)
+                        del self._req_pending_layers[req_id]
+                        completed_reqs.append(req_id)
+
+        self._handle_save_completion(completed_reqs)
+
+    def _handle_save_completion(self,
+                                request_ids: Iterable[str],
+                                reason: str | None = None) -> None:
+        req_list = list(request_ids)
+        if not req_list:
+            return
+
+        suffix = "" if not reason else f" ({reason})"
+        layer_count = len(self._registered_layers) or self._ctx.num_layers
+        for req_id in req_list:
+            logger.debug(
+                "[PegaKVConnector] Request %s all %d layers saved%s",
+                req_id,
+                layer_count,
+                suffix,
+            )
+
+
+__all__ = ["WorkerConnector"]
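The worker-side connector above consumes the PegaConnectorMetadata produced by the scheduler side. The sketch below is likewise not part of the wheel; it only shows the per-forward-pass lifecycle the methods imply. The run_worker_steps driver and its ctx, kv_caches, steps, metadata, forward_context, attn_metadata, and finished_ids values are hypothetical host-engine inputs; only WorkerConnector and its methods come from the package.

# Hypothetical sketch (assumptions noted above); not shipped in the wheel.
from pegaflow.connector.worker import WorkerConnector

def run_worker_steps(ctx, kv_caches, steps):
    conn = WorkerConnector(ctx)
    conn.register_kv_caches(kv_caches)  # once, after the engine allocates KV caches

    try:
        for metadata, forward_context, attn_metadata, finished_ids in steps:
            # Batch all LoadIntents into one asynchronous load request.
            conn.start_load_kv(metadata, forward_context)

            # Each attention layer enqueues its save; the background
            # PegaSaveWorker thread batches and issues them off the hot path.
            for layer_name, kv_layer in kv_caches.items():
                conn.save_kv_layer(metadata, layer_name, kv_layer, attn_metadata)

            # Reconcile saves skipped during CUDA graph replay, then report
            # which requests finished sending (saves) or receiving (loads).
            conn.wait_for_save()
            finished_sending, finished_recving = conn.get_finished(finished_ids)
    finally:
        conn.shutdown()  # unregister layers and stop the save worker thread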