sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +20 -0
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +29 -68
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +57 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +43 -39
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +145 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -12
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/model_executor/cuda_graph_runner.py +42 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +28 -23
- sglang/srt/models/glm4_moe.py +85 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -4
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +115 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +4 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

```diff
@@ -31,21 +31,17 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
-import torch.distributed as dist
 import triton
 import triton.language as tl
 
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_bool_env_var, is_cuda,
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
-_is_npu = is_npu()
-if not _is_npu:
-    from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
 
 
 class ReqToTokenPool:
```
```diff
@@ -153,18 +149,6 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()
 
-    @abc.abstractmethod
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError()
-
     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter
 
```
```diff
@@ -253,12 +237,18 @@ class MHATokenToKVPool(KVCache):
             )
             for _ in range(self.layer_num)
         ]
-
-        self.
-            [x.data_ptr() for x in self.k_buffer
+
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_buffer],
             dtype=torch.uint64,
             device=self.device,
         )
+        self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
         self.data_strides = torch.tensor(
             [
                 np.prod(x.shape[1:]) * x.dtype.itemsize
```
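Note: the hunk above splits the device pool's pointer table into separate `k_data_ptrs` and `v_data_ptrs` and concatenates them into `data_ptrs`. A minimal standalone sketch of the idea (buffer shapes are made up; the real buffers live in `MHATokenToKVPool`): gathering each layer's base address into a `torch.uint64` tensor lets a batched copy kernel pick the right buffer by indexing the table on-device, instead of receiving one Python tensor argument per layer.

```python
import torch

# Made-up stand-ins for the per-layer K/V buffers of a KV cache pool.
layer_num, tokens, head_num, head_dim = 4, 128, 8, 64
k_buffer = [torch.zeros(tokens, head_num, head_dim) for _ in range(layer_num)]
v_buffer = [torch.zeros(tokens, head_num, head_dim) for _ in range(layer_num)]

# One uint64 entry per layer: the raw base address of that layer's buffer.
k_data_ptrs = torch.tensor([x.data_ptr() for x in k_buffer], dtype=torch.uint64)
v_data_ptrs = torch.tensor([x.data_ptr() for x in v_buffer], dtype=torch.uint64)

# The first layer_num entries address K, the rest address V.
data_ptrs = torch.cat([k_data_ptrs, v_data_ptrs], dim=0)
assert data_ptrs.shape == (2 * layer_num,)
```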
```diff
@@ -347,47 +337,6 @@ class MHATokenToKVPool(KVCache):
             self.v_buffer[layer_id][chunk_indices] = v_chunk
         torch.cuda.synchronize()
 
-    def load_from_host_per_layer(
-        self,
-        host_pool,
-        host_indices,
-        device_indices,
-        layer_id,
-        io_backend,
-    ):
-        transfer_kv_per_layer(
-            src_k=host_pool.k_buffer[layer_id],
-            dst_k=self.k_buffer[layer_id],
-            src_v=host_pool.v_buffer[layer_id],
-            dst_v=self.v_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.k_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer(
-                src_k=self.k_buffer[layer_id],
-                dst_k=host_pool.k_buffer[layer_id],
-                src_v=self.v_buffer[layer_id],
-                dst_v=host_pool.v_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
     def _get_key_buffer(self, layer_id: int):
         # for internal use of referencing
         if self.store_dtype != self.dtype:
```
```diff
@@ -602,16 +551,6 @@ class SWAKVPool(KVCache):
             layer_id_override=layer_id_pool,
         )
 
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-
 
 class AscendTokenToKVPool(MHATokenToKVPool):
 
```
```diff
@@ -823,7 +762,11 @@ class MLATokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]
 
-        self.
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.kv_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
         self.layer_transfer_counter = None
 
         kv_size = self.get_kv_size_bytes()
```
```diff
@@ -909,38 +852,6 @@ class MLATokenToKVPool(KVCache):
             self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
         )
 
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        transfer_kv_per_layer_mla(
-            src=host_pool.kv_buffer[layer_id],
-            dst=self.kv_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer_mla(
-                src=self.kv_buffer[layer_id],
-                dst=host_pool.kv_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
```
```diff
@@ -1131,20 +1042,6 @@ class DoubleSparseTokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
         self.label_buffer[layer_id - self.start_layer][loc] = cache_label
 
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-
 
 @triton.jit
 def copy_all_layer_kv_cache(
```
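Taken together, the memory_pool.py hunks above delete the host-transfer methods (`load_from_host_per_layer` / `backup_to_host_all_layer`) from every device pool; the memory_pool_host.py hunks below re-home that logic on the host pools as `load_to_device_per_layer` / `backup_from_device_all_layer`. A hedged sketch of the resulting call direction; only the two method signatures come from this diff, everything else is illustrative:

```python
def swap_example(host_pool, device_pool, host_indices, device_indices, layer_id):
    # Host -> device, one layer at a time (e.g. overlapped with per-layer compute).
    host_pool.load_to_device_per_layer(
        device_pool, host_indices, device_indices, layer_id, io_backend="kernel"
    )
    # Device -> host, all layers in one call (e.g. when writing back to CPU).
    host_pool.backup_from_device_all_layer(
        device_pool, host_indices, device_indices, io_backend="kernel"
    )
```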
sglang/srt/mem_cache/memory_pool_host.py

```diff
@@ -8,6 +8,21 @@ import psutil
 import torch
 
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if not _is_npu:
+    from sgl_kernel.kvcacheio import (
+        transfer_kv_all_layer,
+        transfer_kv_all_layer_lf_pf,
+        transfer_kv_all_layer_mla,
+        transfer_kv_all_layer_mla_lf_pf,
+        transfer_kv_direct,
+        transfer_kv_per_layer,
+        transfer_kv_per_layer_mla,
+        transfer_kv_per_layer_mla_pf_lf,
+        transfer_kv_per_layer_pf_lf,
+    )
 
 logger = logging.getLogger(__name__)
 
```
```diff
@@ -42,15 +57,18 @@ class HostKVCache(abc.ABC):
         device_pool: KVCache,
         host_to_device_ratio: float,
         host_size: int,
+        page_size: int,
+        layout: str,
         pin_memory: bool,
         device: str,
-        page_size: int,
     ):
         self.device_pool = device_pool
-        self.
+        self.page_size = page_size
+        self.layout = layout
         self.pin_memory = pin_memory
         self.device = device
-
+
+        self.dtype = device_pool.store_dtype
         self.size_per_token = self.get_size_per_token()
         if host_size > 0:
             self.size = int(host_size * 1e9 // self.size_per_token)
```
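The base constructor now threads `page_size` and a `layout` string through `HostKVCache`. Based on the subclass signatures later in this diff, instantiation looks roughly like the sketch below; the sizing values are placeholders, and the `host_size` fallback comment is an inference (only the `host_size > 0` branch appears in the hunk above):

```python
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost

def make_host_pool(device_pool):
    # Argument order follows the updated signature in this diff.
    return MHATokenToKVPoolHost(
        device_pool,                # an existing MHATokenToKVPool on the GPU
        host_to_device_ratio=2.0,   # placeholder ratio
        host_size=0,                # >0 means "size the pool to this many GB";
                                    # otherwise the ratio presumably applies
        page_size=64,               # tokens per page (placeholder)
        layout="page_first",        # or "layer_first"
    )
```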
```diff
@@ -98,6 +116,24 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ) -> None:
+        """
+        Load KV data from the host memory pool to the device memory pool for a specific layer.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ) -> None:
+        """
+        Backup KV data from the device memory pool to the host memory pool for all layers.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def get_flat_data_page(self, index) -> torch.Tensor:
         """
```
```diff
@@ -105,6 +141,14 @@ class HostKVCache(abc.ABC):
         """
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        """
+        Get a dummy flat data page from the host memory pool.
+        This is used for prefetching or initializing empty pages.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
         """
```
```diff
@@ -230,11 +274,30 @@ class MHATokenToKVPoolHost(HostKVCache):
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool,
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)]
+        self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)]
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
+        )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )
 
     def get_size_per_token(self):
```
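Worth noting in the hunk above: `k_data_refs` / `v_data_refs` hold Python references to each host layer, while the `uint64` pointer tables are allocated on `device_pool.device`, i.e. on the accelerator, presumably so the all-layer kernels can read the address table from device code. A standalone sketch of that split, with made-up sizes and CUDA assumed available:

```python
import torch

# Pinned host buffers, one per layer (sizes made up).
layer_num = 4
host_layers = [torch.zeros(1024, pin_memory=True) for _ in range(layer_num)]

# Keep references so the host tensors stay alive for the lifetime of the table...
data_refs = host_layers
# ...and publish their raw addresses where the copy kernel runs: on the GPU.
data_ptrs = torch.tensor(
    [t.data_ptr() for t in data_refs], dtype=torch.uint64, device="cuda"
)
```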
```diff
@@ -245,25 +308,156 @@ class MHATokenToKVPoolHost(HostKVCache):
         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
 
     def init_kv_buffer(self):
+        if self.layout == "layer_first":
+            dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
+        elif self.layout == "page_first":
+            dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
         return torch.empty(
-
+            dims,
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.pin_memory,
         )
 
-
+    @property
+    def k_buffer(self):
+        return self.kv_buffer[0]
+
+    @property
+    def v_buffer(self):
+        return self.kv_buffer[1]
+
+    def load_to_device_per_layer(
+        self,
+        device_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer(
+                    src_k=self.k_buffer[layer_id],
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer[layer_id],
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_pf_lf(
+                    src_k=self.k_buffer,
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer,
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
+                dst_layers=[
+                    device_pool.k_buffer[layer_id],
+                    device_pool.v_buffer[layer_id],
+                ],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k_layers=self.k_data_ptrs,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v_layers=self.v_data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_lf_pf(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k=self.k_buffer,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v=self.v_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.k_buffer + device_pool.v_buffer,
+                dst_layers=self.k_data_refs + self.v_data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
     def get_flat_data_page(self, index) -> torch.Tensor:
-
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (2, self.layer_num, self.page_size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()
 
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-
-
-
-
-
-
-
+        if self.layout == "layer_first":
+            self.kv_buffer[:, :, index : index + self.page_size, :, :] = (
+                data_page.reshape(
+                    2,
+                    self.layer_num,
+                    self.page_size,
+                    self.head_num,
+                    self.head_dim,
+                )
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :, :] = (
+                data_page.reshape(
+                    2, self.page_size, self.layer_num, self.head_num, self.head_dim
+                )
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
 
     def get_buffer_meta(self, keys, indices):
         ptr_list = []
```
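The two layouts above only reorder dimensions; the byte math is identical. A standalone sketch of the geometry with made-up sizes, mirroring the `token_stride_size` / `layout_dim` computation and the page_first flat-page round trip (a plain tensor stands in for the pinned host buffer):

```python
import torch

# Made-up sizes; the dims mirror init_kv_buffer above.
layer_num, size, page_size, head_num, head_dim = 2, 32, 4, 8, 16
dtype = torch.float16

token_stride_size = head_num * head_dim * dtype.itemsize  # bytes per token per layer
layout_dim = token_stride_size * layer_num                # bytes per token, all layers

# page_first: (2, tokens, layers, heads, head_dim) groups all layers of a
# token page together within each K/V half.
kv_buffer = torch.zeros(2, size, layer_num, head_num, head_dim, dtype=dtype)

index = 8  # page-aligned token index
page = kv_buffer[:, index : index + page_size].flatten()   # get_flat_data_page
kv_buffer[:, index : index + page_size] = page.reshape(    # set_from_flat_data_page
    2, page_size, layer_num, head_num, head_dim
)
```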
```diff
@@ -302,14 +496,6 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
 
-    @property
-    def k_buffer(self):
-        return self.kv_buffer[0]
-
-    @property
-    def v_buffer(self):
-        return self.kv_buffer[1]
-
 
 class MLATokenToKVPoolHost(HostKVCache):
     device_pool: MLATokenToKVPool
```
```diff
@@ -320,11 +506,24 @@ class MLATokenToKVPoolHost(HostKVCache):
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool,
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)]
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )
 
     def get_size_per_token(self):
```
```diff
@@ -340,28 +539,146 @@ class MLATokenToKVPoolHost(HostKVCache):
         )
 
     def init_kv_buffer(self):
-
-            (
+        if self.layout == "layer_first":
+            dims = (
                 self.layer_num,
                 self.size,
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
-            )
+            )
+        elif self.layout == "page_first":
+            dims = (
+                self.size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = (
+            self.kv_lora_rank + self.qk_rope_head_dim
+        ) * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
+
+        return torch.empty(
+            dims,
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.pin_memory,
         )
 
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer_mla(
+                    src=self.kv_buffer[layer_id],
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_mla_pf_lf(
+                    src=self.kv_buffer,
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.kv_buffer[layer_id]],
+                dst_layers=[device_pool.kv_buffer[layer_id]],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer_mla(
+                    src_layers=device_pool.data_ptrs,
+                    dst_layers=self.data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_mla_lf_pf(
+                    src_layers=device_pool.data_ptrs,
+                    dst_k=self.kv_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.kv_buffer,
+                dst_layers=self.data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
     def get_flat_data_page(self, index) -> torch.Tensor:
-
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()
 
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-
-        self.
-
-
-
-
+        if self.layout == "layer_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
+                self.page_size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
 
     def get_buffer_meta(self, keys, indices):
         ptr_list = []
```
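For the MLA host pool, each token stores a single fused vector of width `kv_lora_rank + qk_rope_head_dim` per layer (note the head dimension of 1 in the shapes above). A quick worked example of the stride math, using DeepSeek-style model values as assumptions, not values taken from this diff:

```python
import torch

# Assumed example values (DeepSeek-style MLA).
kv_lora_rank, qk_rope_head_dim = 512, 64
layer_num = 61
dtype = torch.bfloat16

token_stride_size = (kv_lora_rank + qk_rope_head_dim) * dtype.itemsize
layout_dim = token_stride_size * layer_num

print(token_stride_size)  # (512 + 64) * 2 = 1152 bytes per token per layer
print(layout_dim)         # 1152 * 61 = 70272 bytes per token across all layers
```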