sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +8 -1
- sglang/srt/layers/attention/flashinfer_backend.py +4 -2
- sglang/srt/layers/linear.py +159 -55
- sglang/srt/layers/logits_processor.py +6 -6
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -3
- sglang/srt/layers/parameter.py +431 -0
- sglang/srt/layers/quantization/__init__.py +3 -2
- sglang/srt/layers/quantization/fp8.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -1
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/schedule_batch.py +7 -1
- sglang/srt/managers/scheduler.py +10 -6
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +6 -2
- sglang/srt/mem_cache/memory_pool.py +206 -1
- sglang/srt/metrics/collector.py +22 -30
- sglang/srt/model_executor/cuda_graph_runner.py +14 -7
- sglang/srt/model_executor/forward_batch_info.py +20 -15
- sglang/srt/model_executor/model_runner.py +10 -4
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/grok.py +25 -16
- sglang/srt/models/llama.py +9 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -0
- sglang/srt/server.py +11 -8
- sglang/srt/server_args.py +12 -1
- sglang/srt/speculative/eagle_utils.py +93 -85
- sglang/srt/speculative/eagle_worker.py +47 -33
- sglang/srt/utils.py +32 -5
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +6 -7
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +48 -43
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
ADDED
@@ -0,0 +1,307 @@
+from __future__ import annotations
+
+"""
+Copyright 2023-2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+import threading
+from queue import PriorityQueue, Queue
+from typing import Optional
+
+import torch
+
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
+
+logger = logging.getLogger(__name__)
+
+
+class CacheOperation:
+
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        device_indices: torch.Tensor,
+        node_id: int,
+        priority: Optional[int] = None,
+    ):
+        self.host_indices = host_indices
+        self.device_indices = device_indices
+        self.node_ids = [node_id]
+        self.data = None
+
+        self.id = CacheOperation.counter
+        CacheOperation.counter += 1
+        # default priority is the order of creation
+        self.priority = priority if priority is not None else self.id
+
+    def merge(self, other: "CacheOperation") -> None:
+        # multiple operations can be merged into a single operation for batch processing
+        self.host_indices = torch.cat([self.host_indices, other.host_indices])
+        self.device_indices = torch.cat([self.device_indices, other.device_indices])
+        self.priority = min(self.priority, other.priority)
+        self.node_ids.extend(other.node_ids)
+
+    def __lt__(self, other: "CacheOperation"):
+        return self.priority < other.priority
+
+
+class TransferBuffer:
+    """
+    Overlapping buffer preparation and transfer operations to improve throughput.
+    """
+
+    def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None:
+        self.buffers = Queue(maxsize=buffer_count)
+        # todo: adjust the buffer size based on throughput profile of the system
+        self.max_buffer_size = max_buffer_size
+
+    def full(self) -> bool:
+        return self.buffers.full()
+
+    def empty(self) -> bool:
+        return self.buffers.empty()
+
+    def put(self, item, block=True) -> None:
+        self.buffers.put(item, block=block)
+
+    def get(self, block=True) -> Optional[CacheOperation]:
+        try:
+            return self.buffers.get(block=block)
+        except Exception as e:
+            logger.error(e)
+
+
+class HiCacheController:
+
+    def __init__(
+        self,
+        mem_pool_device: MHATokenToKVPool,
+        mem_pool_host: MLATokenToKVPoolHost,
+        write_policy: str = "write_through_selective",
+    ):
+
+        self.mem_pool_device = mem_pool_device
+        self.mem_pool_host = mem_pool_host
+        self.write_policy = write_policy
+
+        if write_policy not in [
+            "write_through",
+            "write_through_selective",
+            "write_back",
+        ]:
+            raise ValueError(f"Invalid write policy: {write_policy}")
+
+        self.write_queue = PriorityQueue()
+        self.load_queue = PriorityQueue()
+
+        self.ack_write_queue = Queue()
+        self.ack_load_queue = Queue()
+
+        self.write_buffer = TransferBuffer()
+        self.load_buffer = TransferBuffer()
+
+        self.write_stream = torch.cuda.Stream()
+        self.load_stream = torch.cuda.Stream()
+
+        self.write_thread = threading.Thread(
+            target=self.write_thread_func_buffer, daemon=True
+        )
+        self.load_thread = threading.Thread(
+            target=self.load_thread_func_buffer, daemon=True
+        )
+        self.write_thread.start()
+        self.load_thread.start()
+
+    def write(
+        self,
+        device_indices: torch.Tensor,
+        priority: Optional[int] = None,
+        node_id: int = 0,
+    ) -> Optional[torch.Tensor]:
+        """
+        Back up KV caches from device memory to host memory.
+        """
+        host_indices = self.mem_pool_host.alloc(len(device_indices))
+        if host_indices is None:
+            return None
+        self.write_queue.put(
+            CacheOperation(host_indices, device_indices, node_id, priority)
+        )
+        self.mem_pool_host.protect_write(host_indices)
+        return host_indices
+
+    def load(
+        self,
+        host_indices: torch.Tensor,
+        priority: Optional[int] = None,
+        node_id: int = 0,
+    ) -> Optional[torch.Tensor]:
+        """
+        Load KV caches from host memory to device memory.
+        """
+        device_indices = self.mem_pool_device.alloc(len(host_indices))
+        if device_indices is None:
+            return None
+        self.load_queue.put(
+            CacheOperation(host_indices, device_indices, node_id, priority)
+        )
+        self.mem_pool_host.protect_load(host_indices)
+        return device_indices
+
+    def write_thread_func_direct(self):
+        """
+        Directly write through KV caches to host memory without buffering.
+        """
+        with torch.cuda.stream(self.write_stream):
+            while True:
+                try:
+                    operation = self.write_queue.get(block=True)
+                    operation.data = self.mem_pool_device.get_flat_data(
+                        operation.device_indices
+                    )
+                    self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                    self.mem_pool_host.complete_io(operation.host_indices)
+                    for node_id in operation.node_ids:
+                        self.ack_write_queue.put(node_id)
+                except Exception as e:
+                    logger.error(e)
+
+    def load_thread_func_direct(self):
+        """
+        Directly load KV caches from host memory to device memory without buffering.
+        """
+        with torch.cuda.stream(self.load_stream):
+            while True:
+                try:
+                    operation = self.load_queue.get(block=True)
+                    operation.data = self.mem_pool_host.get_flat_data(
+                        operation.host_indices
+                    )
+                    self.mem_pool_device.transfer(
+                        operation.device_indices, operation.data
+                    )
+                    self.mem_pool_host.complete_io(operation.host_indices)
+                    for node_id in operation.node_ids:
+                        self.ack_load_queue.put(node_id)
+                except Exception as e:
+                    logger.error(e)
+
+    def write_aux_func(self, no_wait=False):
+        """
+        Auxiliary function to prepare the buffer for write operations.
+        """
+        buffer = None
+        while True:
+            try:
+                operation = self.write_queue.get(block=True)
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    assert (
+                        buffer.device_indices.is_cuda
+                    ), "Device indices should be on GPU"
+                    buffer.data = self.mem_pool_device.get_flat_data(
+                        buffer.device_indices
+                    ).contiguous()
+                    self.write_buffer.put(buffer, block=True)
+                    buffer = None
+            except Exception as e:
+                logger.error(e)
+
+    def load_aux_func(self):
+        """
+        Auxiliary function to prepare the buffer for load operations.
+        """
+        buffer = None
+        while True:
+            try:
+                operation = self.load_queue.get(block=True)
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    len(buffer.host_indices) >= self.load_buffer.max_buffer_size
+                    or self.load_queue.empty()
+                    or self.load_buffer.empty()
+                ):
+                    buffer.data = (
+                        self.mem_pool_host.get_flat_data(buffer.host_indices)
+                        .contiguous()
+                        .pin_memory()
+                    )
+                    self.load_buffer.put(buffer, block=True)
+                    buffer = None
+            except Exception as e:
+                logger.error(e)
+
+    def write_thread_func_buffer(self):
+        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
+        aux_thread.start()
+        with torch.cuda.stream(self.write_stream):
+            while True:
+                operation = self.write_buffer.get()
+                if operation is None:
+                    continue
+                self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    self.ack_write_queue.put(node_id)
+
+    def load_thread_func_buffer(self):
+        aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
+        aux_thread.start()
+        with torch.cuda.stream(self.load_stream):
+            while True:
+                operation = self.load_buffer.get()
+                if operation is None:
+                    continue
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    self.ack_load_queue.put(node_id)
+
+    def evict_device(
+        self, device_indices: torch.Tensor, host_indices: torch.Tensor
+    ) -> int:
+        if self.mem_pool_host.is_synced(host_indices):
+            self.mem_pool_device.free(device_indices)
+            self.mem_pool_host.update_backup(host_indices)
+            return len(device_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )
+
+    def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
+        if not backup_only:
+            raise ValueError("Other eviction policies are not supported yet.")
+
+        if self.mem_pool_host.is_backup(host_indices):
+            self.mem_pool_host.free(host_indices)
+            return len(host_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )
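Note: the buffered write/load paths above form a two-stage producer/consumer pipeline: an auxiliary thread drains the priority queue and merges adjacent CacheOperations into large batches (write_aux_func / load_aux_func), while a transfer thread performs the actual copies on a dedicated CUDA stream (write_thread_func_buffer / load_thread_func_buffer). Below is a minimal, self-contained sketch of that pattern; it uses integer payloads instead of KV tensors, and the names prepare/transfer are illustrative only, not part of sglang.

import threading
import time
from queue import PriorityQueue, Queue

work_queue = PriorityQueue()  # (priority, payload), like the controller's write_queue
staging = Queue(maxsize=3)    # bounded hand-off, like TransferBuffer
MAX_BATCH = 8

def prepare():
    # Merge queued operations into larger batches (cf. write_aux_func).
    batch = []
    while True:
        _, payload = work_queue.get()
        batch.extend(payload)
        # Flush heuristically, as the diff does: when the batch is large
        # enough, the queue momentarily looks empty, or the staging buffer
        # has run dry.
        if len(batch) >= MAX_BATCH or work_queue.empty() or staging.empty():
            staging.put(batch)  # blocks when three batches are already staged
            batch = []

def transfer():
    # Drain staged batches and perform the (simulated) copy
    # (cf. write_thread_func_buffer).
    while True:
        batch = staging.get()
        print(f"transferring a batch of {len(batch)} items")

threading.Thread(target=prepare, daemon=True).start()
threading.Thread(target=transfer, daemon=True).start()

for i in range(20):
    work_queue.put((i, [i]))  # priority defaults to creation order
time.sleep(0.5)               # let the daemon threads drain the queues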
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -20,6 +20,7 @@ import threading
 from enum import Enum, auto

 import psutil
+import setproctitle
 import zmq

 from sglang.srt.managers.io_struct import (
@@ -230,6 +231,7 @@ def run_data_parallel_controller_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    setproctitle.setproctitle("sglang::data_parallel_controller")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()

sglang/srt/managers/schedule_batch.py
CHANGED
@@ -44,7 +44,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -1163,6 +1163,11 @@ class ScheduleBatch:
             input_embeds=self.input_embeds,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
+            capture_hidden_mode=(
+                getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
+                if self.spec_info
+                else CaptureHiddenMode.NULL
+            ),
         )

     def copy(self):
@@ -1237,6 +1242,7 @@ class ModelWorkerBatch:
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
+    capture_hidden_mode: CaptureHiddenMode = None


 @triton.jit
sglang/srt/managers/scheduler.py
CHANGED
@@ -962,10 +962,13 @@ class Scheduler:
                     self.tp_worker.forward_batch_generation(model_worker_batch)
                 )
             else:
-
-
-
-
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.num_generated_tokens += num_accepted_tokens
         elif batch.forward_mode.is_idle():
             model_worker_batch = batch.get_model_worker_batch()
             self.tp_worker.forward_batch_idle(model_worker_batch)
@@ -1513,8 +1516,9 @@
         return success, message

     def update_weights_from_distributed(
-        self,
-
+        self,
+        recv_req: UpdateWeightsFromDistributedReqInput,
+    ) -> Tuple[bool, str]:
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
sglang/srt/managers/session_controller.py
CHANGED
@@ -99,7 +99,7 @@ class Session:

         if last_req is not None:
             # trim bos token if it is an append
-            if req.input_ids[0] == tokenizer.bos_token_id:
+            if tokenizer is not None and req.input_ids[0] == tokenizer.bos_token_id:
                 req.input_ids = req.input_ids[1:]

             input_ids = (
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -688,7 +688,7 @@ class TokenizerManager:
            if self.enable_metrics:
                completion_tokens = (
                    recv_obj.completion_tokens[i]
-                   if recv_obj
+                   if getattr(recv_obj, "completion_tokens", None)
                    else 0
                )

@@ -716,7 +716,11 @@
                    time.time() - state.created_time
                )
                # Compute time_per_output_token for the non-streaming case
-               if
+               if (
+                   hasattr(state.obj, "stream")
+                   and not state.obj.stream
+                   and completion_tokens >= 1
+               ):
                    self.metrics_collector.observe_time_per_output_token(
                        (time.time() - state.created_time)
                        / completion_tokens
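Note: the guard above computes time-per-output-token (TPOT) for non-streaming requests only, as total request latency divided by the number of completion tokens, skipping requests that produced no tokens. A standalone illustration of the same arithmetic (function and variable names here are illustrative, not sglang's):

import time

def time_per_output_token(created_time: float, completion_tokens: int, stream: bool):
    # Mirrors the guard in the diff: only meaningful for non-streaming
    # requests that produced at least one token.
    if not stream and completion_tokens >= 1:
        return (time.time() - created_time) / completion_tokens
    return None

start = time.time()
time.sleep(0.2)  # pretend generation took 200 ms
print(time_per_output_token(start, completion_tokens=10, stream=False))  # ~0.02 s/token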
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -22,12 +22,16 @@ BaseTokenToKVPool maps a token location to its KV cache data.
 """

 import logging
+import threading
+from enum import IntEnum
+from functools import wraps
 from typing import List, Tuple, Union

+import psutil
 import torch

 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import debug_timing, get_compiler_backend

 logger = logging.getLogger(__name__)

|
@@ -213,6 +217,26 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
|
213
217
|
del self.k_buffer
|
214
218
|
del self.v_buffer
|
215
219
|
|
220
|
+
# Todo: different memory layout
|
221
|
+
def get_flat_data(self, indices):
|
222
|
+
# prepare a large chunk of contiguous data for efficient transfer
|
223
|
+
flatten = torch.stack(
|
224
|
+
[
|
225
|
+
torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
|
226
|
+
torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
|
227
|
+
]
|
228
|
+
)
|
229
|
+
return flatten
|
230
|
+
|
231
|
+
@debug_timing
|
232
|
+
def transfer(self, indices, flat_data):
|
233
|
+
# transfer prepared data from host to device
|
234
|
+
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
235
|
+
k_data, v_data = flat_data[0], flat_data[1]
|
236
|
+
for i in range(self.layer_num):
|
237
|
+
self.k_buffer[i][indices] = k_data[i]
|
238
|
+
self.v_buffer[i][indices] = v_data[i]
|
239
|
+
|
216
240
|
def get_key_buffer(self, layer_id: int):
|
217
241
|
if self.store_dtype != self.dtype:
|
218
242
|
return self.k_buffer[layer_id].view(self.dtype)
|
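Note: get_flat_data packs the selected KV entries of every layer into one contiguous tensor of shape (2, layer_num, len(indices), head_num, head_dim) so the host/device copy becomes a single transfer, and transfer scatters such a tensor back into the per-layer buffers. A CPU-only round-trip sketch with toy sizes (nothing here imports sglang; buffer names just mirror the diff):

import torch

layer_num, pool_size, head_num, head_dim = 4, 16, 2, 8
k_buffer = [torch.randn(pool_size, head_num, head_dim) for _ in range(layer_num)]
v_buffer = [torch.randn(pool_size, head_num, head_dim) for _ in range(layer_num)]

indices = torch.tensor([3, 7, 11])

# Pack: one contiguous chunk of shape (2, layer_num, len(indices), head_num, head_dim)
flat = torch.stack(
    [
        torch.stack([k_buffer[i][indices] for i in range(layer_num)]),
        torch.stack([v_buffer[i][indices] for i in range(layer_num)]),
    ]
)
assert flat.shape == (2, layer_num, len(indices), head_num, head_dim)

# Unpack: scatter the K and V planes back into the per-layer buffers
k_data, v_data = flat[0], flat[1]
for i in range(layer_num):
    k_buffer[i][indices] = k_data[i]
    v_buffer[i][indices] = v_data[i]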
@@ -361,3 +385,184 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         self.k_buffer[layer_id][loc] = cache_k
         self.v_buffer[layer_id][loc] = cache_v
         self.label_buffer[layer_id][loc] = cache_label
+
+
+class MemoryStateInt(IntEnum):
+    IDLE = 0
+    RESERVED = 1
+    PROTECTED = 2
+    SYNCED = 3
+    BACKUP = 4
+
+
+def synchronized(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.lock:
+            return func(self, *args, **kwargs)
+
+    return wrapper
+
+
+class MLATokenToKVPoolHost:
+
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float = 2.0,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        assert (
+            host_to_device_ratio >= 1
+        ), "The host memory should be larger than the device memory with the current protocol"
+        # todo, other ways of configuring the size
+
+        self.device_pool = device_pool
+        self.host_to_device_ratio = host_to_device_ratio
+        self.pin_memory = pin_memory
+        self.device = device
+
+        self.size = int(device_pool.size * host_to_device_ratio)
+        self.dtype = device_pool.store_dtype
+        self.head_num = device_pool.head_num
+        self.head_dim = device_pool.head_dim
+        self.layer_num = device_pool.layer_num
+        self.size_per_token = (
+            self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+        )
+
+        # Verify there is enough available host memory.
+        host_mem = psutil.virtual_memory()
+        requested_bytes = self.size * self.size_per_token
+        # preserve at least 10GB for other usage
+        ten_gb = 10 * (1024**3)
+        if requested_bytes > host_mem.available - ten_gb:
+            raise ValueError(
+                f"Not enough host memory available. Requesting "
+                f"{requested_bytes / 1e9:.2f} GB but only have "
+                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
+                f"size of the hierarchical cache."
+            )
+        else:
+            logger.info(
+                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
+            )
+
+        self.kv_buffer = torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
+        self.free_slots = torch.arange(self.size, dtype=torch.int32)
+        self.can_use_mem_size = self.size
+
+        # A lock for synchronized operations on memory allocation and state transitions.
+        self.lock = threading.RLock()
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    @synchronized
+    def clear(self):
+        self.mem_state.fill_(0)
+        self.can_use_mem_size = self.size
+        self.free_slots = torch.arange(self.size, dtype=torch.int32)
+
+    @synchronized
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized
+    def alloc(self, need_size: int) -> torch.Tensor:
+        if need_size > self.can_use_mem_size:
+            return None
+
+        # todo: de-fragmentation
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        self.mem_state[select_index] = MemoryStateInt.RESERVED
+        self.can_use_mem_size -= need_size
+
+        return select_index
+
+    @synchronized
+    def is_reserved(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.RESERVED
+
+    @synchronized
+    def is_protected(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.PROTECTED
+
+    @synchronized
+    def is_synced(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.SYNCED
+
+    @synchronized
+    def is_backup(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.BACKUP
+
+    @synchronized
+    def update_backup(self, indices: torch.Tensor):
+        assert self.is_synced(indices), (
+            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
+    @synchronized
+    def update_synced(self, indices: torch.Tensor):
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    @synchronized
+    def protect_write(self, indices: torch.Tensor):
+        assert self.is_reserved(indices), (
+            f"The host memory slots should be RESERVED before write operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized
+    def protect_load(self, indices: torch.Tensor):
+        assert self.is_backup(indices), (
+            f"The host memory slots should be in BACKUP state before load operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized
+    def complete_io(self, indices: torch.Tensor):
+        assert self.is_protected(indices), (
+            f"The host memory slots should be PROTECTED during I/O operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    @synchronized
+    def free(self, indices: torch.Tensor) -> int:
+        self.mem_state[indices] = MemoryStateInt.IDLE
+        self.free_slots = torch.concat([self.free_slots, indices])
+        self.can_use_mem_size += len(indices)
+        return len(indices)
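Note: MLATokenToKVPoolHost tracks each host slot through a small state machine: IDLE -> RESERVED (alloc) -> PROTECTED (protect_write) -> SYNCED (complete_io) -> BACKUP (update_backup), and for reloads BACKUP -> PROTECTED (protect_load) -> SYNCED (complete_io), with free() returning slots to IDLE. PROTECTED guards slots during in-flight I/O, which is why evict_device only frees device copies whose host slots have reached SYNCED. The miniature below reproduces that protocol, including the same synchronized-decorator pattern; it is an illustration of the flow, not sglang's API.

import threading
from enum import IntEnum
from functools import wraps

class State(IntEnum):
    IDLE = 0
    RESERVED = 1
    PROTECTED = 2
    SYNCED = 3
    BACKUP = 4

def synchronized(func):
    # Serialize state transitions, as the diff does for the host pool.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            return func(self, *args, **kwargs)
    return wrapper

class Slot:
    def __init__(self):
        self.lock = threading.RLock()
        self.state = State.IDLE

    @synchronized
    def step(self, expected: State, new: State):
        # Every transition asserts the current state, like protect_write/complete_io.
        assert self.state == expected, f"expected {expected!r}, got {self.state!r}"
        self.state = new

slot = Slot()
slot.step(State.IDLE, State.RESERVED)       # alloc
slot.step(State.RESERVED, State.PROTECTED)  # protect_write
slot.step(State.PROTECTED, State.SYNCED)    # complete_io (write finished)
slot.step(State.SYNCED, State.BACKUP)       # update_backup after device eviction
slot.step(State.BACKUP, State.PROTECTED)    # protect_load
slot.step(State.PROTECTED, State.SYNCED)    # complete_io (load finished)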