sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +39 -674
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
- sglang/srt/layers/quantization/fp8.py +52 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +165 -67
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +90 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +60 -17
- sglang/srt/mem_cache/hiradix_cache.py +36 -8
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +418 -29
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/glm4_moe.py +6 -4
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +991 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +49 -18
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool_host.py (+418 -29):

```diff
@@ -8,6 +8,21 @@ import psutil
 import torch
 
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if not _is_npu:
+    from sgl_kernel.kvcacheio import (
+        transfer_kv_all_layer,
+        transfer_kv_all_layer_lf_pf,
+        transfer_kv_all_layer_mla,
+        transfer_kv_all_layer_mla_lf_pf,
+        transfer_kv_direct,
+        transfer_kv_per_layer,
+        transfer_kv_per_layer_mla,
+        transfer_kv_per_layer_mla_pf_lf,
+        transfer_kv_per_layer_pf_lf,
+    )
 
 logger = logging.getLogger(__name__)
 
```
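The new transfer kernels are imported only when `is_npu()` is false, since `sgl_kernel`'s fused copy kernels ship only in CUDA builds. A standalone sketch of the same guard pattern; the names `have_kvcacheio` and `_copy_pages_fallback` are illustrative, not sglang API:

```python
import torch

# Guarded import: use the fused CUDA kernel when available, otherwise fall
# back to a plain PyTorch indexed copy (e.g. on NPU or CPU-only installs).
try:
    from sgl_kernel.kvcacheio import transfer_kv_per_layer  # CUDA-only extension
    have_kvcacheio = True
except ImportError:
    have_kvcacheio = False

def _copy_pages_fallback(src: torch.Tensor, dst: torch.Tensor,
                         src_indices: torch.Tensor, dst_indices: torch.Tensor) -> None:
    # Portable indexed copy; correct but misses the fused kernel's batching.
    dst[dst_indices] = src[src_indices].to(dst.device, non_blocking=True)
```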
```diff
@@ -25,7 +40,6 @@ def synchronized(debug_only=False):
         @wraps(func)
         def wrapper(self, *args, **kwargs):
             if (not debug_only) or self.debug:
-                return func(self, *args, **kwargs)
                 with self.lock:
                     return func(self, *args, **kwargs)
             else:
```
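This hunk is a genuine bug fix: the stray early `return` made the wrapper exit before ever reaching `with self.lock:`, so `@synchronized` methods ran without holding the lock. A minimal standalone version of the corrected decorator (the `else` branch body lies outside this hunk; returning `True` there is an assumption):

```python
import threading
from functools import wraps

def synchronized(debug_only=False):
    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if (not debug_only) or self.debug:
                # With the stray early return removed, the call now actually
                # executes under self.lock.
                with self.lock:
                    return func(self, *args, **kwargs)
            else:
                return True  # assumed no-op skip; not shown in the hunk
        return wrapper
    return _decorator

class Counter:
    def __init__(self):
        self.lock = threading.Lock()
        self.debug = True
        self.n = 0

    @synchronized(debug_only=True)
    def bump(self):
        self.n += 1

c = Counter()
c.bump()
assert c.n == 1
```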
```diff
@@ -43,15 +57,18 @@ class HostKVCache(abc.ABC):
         device_pool: KVCache,
         host_to_device_ratio: float,
         host_size: int,
+        page_size: int,
+        layout: str,
         pin_memory: bool,
         device: str,
-        page_size: int,
     ):
         self.device_pool = device_pool
-        self.dtype = device_pool.store_dtype
+        self.page_size = page_size
+        self.layout = layout
         self.pin_memory = pin_memory
         self.device = device
-
+
+        self.dtype = device_pool.store_dtype
         self.size_per_token = self.get_size_per_token()
         if host_size > 0:
             self.size = int(host_size * 1e9 // self.size_per_token)
```
```diff
@@ -99,6 +116,24 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ) -> None:
+        """
+        Load KV data from the host memory pool to the device memory pool for a specific layer.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ) -> None:
+        """
+        Backup KV data from the device memory pool to the host memory pool for all layers.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def get_flat_data_page(self, index) -> torch.Tensor:
         """
```
```diff
@@ -106,6 +141,14 @@ class HostKVCache(abc.ABC):
         """
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        """
+        Get a dummy flat data page from the host memory pool.
+        This is used for prefetching or initializing empty pages.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
         """
```
```diff
@@ -181,6 +224,15 @@ class HostKVCache(abc.ABC):
         )
         self.mem_state[indices] = MemoryStateInt.BACKUP
 
+    @synchronized(debug_only=True)
+    def update_prefetch(self, indices: torch.Tensor):
+        if not self.is_reserved(indices):
+            raise ValueError(
+                f"The host memory slots should be in RESERVED state before turning into BACKUP. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
     @synchronized(debug_only=True)
     def update_synced(self, indices: torch.Tensor):
         self.mem_state[indices] = MemoryStateInt.SYNCED
```
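`update_prefetch` gives prefetch completions their own RESERVED → BACKUP transition, with the same guard used by `update_backup`. A toy model of the slot lifecycle this enforces; the `MemoryStateInt` member set below is an assumption inferred from the states referenced in this diff, not copied from the file:

```python
from enum import IntEnum, auto

class MemoryStateInt(IntEnum):  # assumed members, for illustration only
    IDLE = auto()
    RESERVED = auto()
    BACKUP = auto()
    SYNCED = auto()

# Transitions guarded in the hunk above: prefetch/backup completion requires
# the slot to already be RESERVED.
ALLOWED = {
    MemoryStateInt.RESERVED: {MemoryStateInt.BACKUP},
    MemoryStateInt.BACKUP: {MemoryStateInt.SYNCED},
}

def transition(cur: MemoryStateInt, new: MemoryStateInt) -> MemoryStateInt:
    if new not in ALLOWED.get(cur, set()):
        raise ValueError(f"Illegal transition: {cur.name} -> {new.name}")
    return new
```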
```diff
@@ -222,11 +274,30 @@ class MHATokenToKVPoolHost(HostKVCache):
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)]
+        self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)]
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
+        )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )
 
     def get_size_per_token(self):
```
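Note that the base-class signature change is breaking for positional callers: `page_size` moves ahead of `pin_memory`, and `layout` is a new required argument. A hypothetical caller update; the argument values here are illustrative, and the `host_size`/`host_to_device_ratio` interplay beyond the `host_size > 0` branch shown above is assumed:

```python
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost

def make_host_pool(device_kv_pool):
    # device_kv_pool: an existing device-side MHATokenToKVPool instance.
    return MHATokenToKVPoolHost(
        device_kv_pool,
        host_to_device_ratio=2.0,  # sizing fallback when host_size is 0 (assumed)
        host_size=0,               # >0 sizes the pool directly (GB, per the 1e9 factor)
        page_size=64,
        layout="page_first",       # new in 0.4.10; "layer_first" matches old behavior
        pin_memory=True,
        device="cpu",
    )
```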
```diff
@@ -237,26 +308,21 @@ class MHATokenToKVPoolHost(HostKVCache):
         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
 
     def init_kv_buffer(self):
+        if self.layout == "layer_first":
+            dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
+        elif self.layout == "page_first":
+            dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
         return torch.empty(
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dims,
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.pin_memory,
         )
 
-    # todo, page first memory layout
-    def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
-
-    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
-            2,
-            self.layer_num,
-            self.page_size,
-            self.head_num,
-            self.head_dim,
-        )
-
     @property
     def k_buffer(self):
         return self.kv_buffer[0]
```
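The point of `page_first` is locality: every layer's slice of a page sits in one contiguous byte range, so a whole page can move with a single copy, while `layer_first` keeps each layer's tokens contiguous for per-layer kernels. A toy demonstration of the contiguity trade-off with plain PyTorch and made-up sizes:

```python
import torch

# Toy shapes: 2 = K/V planes, L layers, N token slots, H heads, D head_dim.
L, N, H, D, page = 2, 8, 4, 16, 4

layer_first = torch.zeros(2, L, N, H, D)  # dims as in init_kv_buffer above
page_first = torch.zeros(2, N, L, H, D)

i = 4  # first token slot of some page
k_pf = page_first[0, i : i + page]      # all layers of this page: one contiguous run
k_lf = layer_first[0][:, i : i + page]  # same tokens in layer_first: strided per layer
assert k_pf.is_contiguous() and not k_lf.is_contiguous()

# Per-layer access flips the trade-off: layer_first is contiguous per layer.
assert layer_first[0, 1].is_contiguous()
```

This is why the kernel import list above carries `_pf_lf` / `_lf_pf` variants: host and device use different layouts, so transfers must re-stride in flight.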
```diff
@@ -265,6 +331,171 @@ class MHATokenToKVPoolHost(HostKVCache):
     def v_buffer(self):
         return self.kv_buffer[1]
 
+    def load_to_device_per_layer(
+        self,
+        device_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer(
+                    src_k=self.k_buffer[layer_id],
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer[layer_id],
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_pf_lf(
+                    src_k=self.k_buffer,
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer,
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
+                dst_layers=[
+                    device_pool.k_buffer[layer_id],
+                    device_pool.v_buffer[layer_id],
+                ],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k_layers=self.k_data_ptrs,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v_layers=self.v_data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_lf_pf(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k=self.k_buffer,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v=self.v_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.k_buffer + device_pool.v_buffer,
+                dst_layers=self.k_data_refs + self.v_data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (2, self.layer_num, self.page_size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        if self.layout == "layer_first":
+            self.kv_buffer[:, :, index : index + self.page_size, :, :] = (
+                data_page.reshape(
+                    2,
+                    self.layer_num,
+                    self.page_size,
+                    self.head_num,
+                    self.head_dim,
+                )
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :, :] = (
+                data_page.reshape(
+                    2, self.page_size, self.layer_num, self.head_num, self.head_dim
+                )
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_buffer_meta(self, keys, indices):
+        ptr_list = []
+        key_list = []
+        kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        v_offset = (
+            self.layer_num
+            * self.size
+            * self.head_num
+            * self.head_dim
+            * self.dtype.itemsize
+        )
+        for index in range(0, len(indices), self.page_size):
+            for layer_id in range(self.layer_num):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                    + layer_id
+                    * self.size
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                )
+                v_ptr = k_ptr + v_offset
+                ptr_list.append(k_ptr)
+                ptr_list.append(v_ptr)
+                key_ = keys[index // self.page_size]
+                key_list.append(f"{key_}_{layer_id}_k")
+                key_list.append(f"{key_}_{layer_id}_v")
+        element_size = (
+            self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
+        )
+        element_size_list = [element_size] * len(key_list)
+        return key_list, ptr_list, element_size_list
+
 
 class MLATokenToKVPoolHost(HostKVCache):
     device_pool: MLATokenToKVPool
```
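`get_buffer_meta` hands per-page, per-layer `(key, pointer, size)` triples to external storage backends (the mooncake_store and nixl modules in the file list above) so they can register host memory for zero-copy IO; its offsets encode the `(2, layer_num, size, head_num, head_dim)` addressing. A standalone check of that arithmetic against torch's own addresses, with toy sizes:

```python
import torch

layer_num, size, head_num, head_dim = 2, 16, 4, 8
kv = torch.zeros(2, layer_num, size, head_num, head_dim, dtype=torch.float16)
itemsize = kv.element_size()

token, layer = 5, 1
# Same formula as get_buffer_meta's k_ptr / v_offset above:
k_ptr = (
    kv.data_ptr()
    + token * head_num * head_dim * itemsize
    + layer * size * head_num * head_dim * itemsize
)
v_offset = layer_num * size * head_num * head_dim * itemsize  # K plane -> V plane

assert k_ptr == kv[0, layer, token].data_ptr()             # K entry address
assert k_ptr + v_offset == kv[1, layer, token].data_ptr()  # matching V entry
```

Each registered element spans `page_size * head_num * head_dim` entries, i.e. one layer's slice of one page.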
```diff
@@ -275,11 +506,24 @@ class MLATokenToKVPoolHost(HostKVCache):
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)]
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )
 
     def get_size_per_token(self):
```
```diff
@@ -295,25 +539,170 @@ class MLATokenToKVPoolHost(HostKVCache):
         )
 
     def init_kv_buffer(self):
-        return torch.empty(
-            (
+        if self.layout == "layer_first":
+            dims = (
                 self.layer_num,
                 self.size,
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
-            ),
+            )
+        elif self.layout == "page_first":
+            dims = (
+                self.size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = (
+            self.kv_lora_rank + self.qk_rope_head_dim
+        ) * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
+
+        return torch.empty(
+            dims,
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.pin_memory,
         )
 
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer_mla(
+                    src=self.kv_buffer[layer_id],
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_mla_pf_lf(
+                    src=self.kv_buffer,
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.kv_buffer[layer_id]],
+                dst_layers=[device_pool.kv_buffer[layer_id]],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer_mla(
+                    src_layers=device_pool.data_ptrs,
+                    dst_layers=self.data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_mla_lf_pf(
+                    src_layers=device_pool.data_ptrs,
+                    dst_k=self.kv_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.kv_buffer,
+                dst_layers=self.data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
     def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()
 
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
-            self.layer_num,
-            self.page_size,
-            1,
-            self.kv_lora_rank + self.qk_rope_head_dim,
+        if self.layout == "layer_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
+                self.page_size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_buffer_meta(self, keys, indices):
+        ptr_list = []
+        key_list = []
+        kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        for index in range(0, len(indices), self.page_size):
+            for layer_id in range(self.layer_num):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                    + layer_id
+                    * self.size
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                )
+                ptr_list.append(k_ptr)
+                key_ = keys[index // self.page_size]
+                key_list.append(f"{key_}_{layer_id}_k")
+        element_size = (
+            self.dtype.itemsize
+            * self.page_size
+            * (self.kv_lora_rank + self.qk_rope_head_dim)
         )
+        element_size_list = [element_size] * len(key_list)
+        return key_list, ptr_list, element_size_list
```
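Both host pools now dispatch on `io_backend`: "kernel" uses the fused `sgl_kernel.kvcacheio` transfers (with layout-converting `_pf_lf`/`_lf_pf` variants), while "direct" falls back to plain per-layer copies and is asserted to be layer_first-only. A rough sketch of what a direct-style paged copy amounts to; this is illustrative only, not the `transfer_kv_direct` implementation:

```python
import torch

def paged_copy_direct(src_layers, dst_layers, src_indices, dst_indices, page_size):
    # Illustrative page-granular copy: indices hold one slot per token, and
    # slots within a page are contiguous, so copy page_size tokens at a time.
    for src, dst in zip(src_layers, dst_layers):
        for s, d in zip(src_indices[::page_size].tolist(),
                        dst_indices[::page_size].tolist()):
            dst[d : d + page_size].copy_(src[s : s + page_size], non_blocking=True)

# Example: back up one toy "layer" from device-shaped to host-shaped storage.
src = [torch.arange(32.0).reshape(8, 4)]   # 8 token slots, hidden size 4
dst = [torch.zeros(8, 4)]
paged_copy_direct(src, dst, torch.tensor([4, 5, 6, 7]), torch.tensor([0, 1, 2, 3]), 4)
assert torch.equal(dst[0][0:4], src[0][4:8])
```

Note the asymmetry in the MLA pool: there is no separate V plane, so `get_buffer_meta` registers a single pointer per page per layer spanning `kv_lora_rank + qk_rope_head_dim` entries.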