sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/conversation.py +0 -112
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  7. sglang/srt/disaggregation/launch_lb.py +5 -20
  8. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  9. sglang/srt/disaggregation/prefill.py +1 -0
  10. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  11. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  12. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  13. sglang/srt/distributed/parallel_state.py +11 -0
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +35 -15
  16. sglang/srt/eplb/expert_distribution.py +4 -2
  17. sglang/srt/hf_transformers_utils.py +25 -10
  18. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  19. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  20. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  21. sglang/srt/layers/attention/utils.py +6 -1
  22. sglang/srt/layers/attention/vision.py +27 -10
  23. sglang/srt/layers/communicator.py +14 -4
  24. sglang/srt/layers/linear.py +7 -1
  25. sglang/srt/layers/logits_processor.py +9 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +29 -68
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
  29. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  30. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  31. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  32. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  33. sglang/srt/layers/moe/utils.py +43 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  35. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  36. sglang/srt/layers/quantization/fp8.py +57 -1
  37. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  38. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  39. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  40. sglang/srt/lora/lora_registry.py +7 -0
  41. sglang/srt/managers/cache_controller.py +43 -39
  42. sglang/srt/managers/data_parallel_controller.py +52 -2
  43. sglang/srt/managers/io_struct.py +6 -1
  44. sglang/srt/managers/schedule_batch.py +3 -2
  45. sglang/srt/managers/schedule_policy.py +3 -1
  46. sglang/srt/managers/scheduler.py +145 -6
  47. sglang/srt/managers/template_manager.py +25 -22
  48. sglang/srt/managers/tokenizer_manager.py +114 -62
  49. sglang/srt/managers/utils.py +45 -1
  50. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  51. sglang/srt/mem_cache/hicache_storage.py +13 -12
  52. sglang/srt/mem_cache/hiradix_cache.py +21 -4
  53. sglang/srt/mem_cache/memory_pool.py +15 -118
  54. sglang/srt/mem_cache/memory_pool_host.py +350 -33
  55. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  56. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
  57. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  58. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
  59. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
  60. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
  61. sglang/srt/model_executor/cuda_graph_runner.py +42 -4
  62. sglang/srt/model_executor/forward_batch_info.py +13 -3
  63. sglang/srt/model_executor/model_runner.py +13 -1
  64. sglang/srt/model_loader/weight_utils.py +2 -0
  65. sglang/srt/models/deepseek_v2.py +28 -23
  66. sglang/srt/models/glm4_moe.py +85 -22
  67. sglang/srt/models/grok.py +3 -3
  68. sglang/srt/models/llama4.py +13 -2
  69. sglang/srt/models/mixtral.py +3 -3
  70. sglang/srt/models/mllama4.py +428 -19
  71. sglang/srt/models/qwen2_moe.py +1 -4
  72. sglang/srt/models/qwen3_moe.py +7 -8
  73. sglang/srt/models/step3_vl.py +1 -4
  74. sglang/srt/multimodal/processors/base_processor.py +4 -3
  75. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  76. sglang/srt/operations_strategy.py +1 -1
  77. sglang/srt/server_args.py +115 -21
  78. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  79. sglang/srt/two_batch_overlap.py +6 -4
  80. sglang/srt/utils.py +4 -24
  81. sglang/srt/weight_sync/utils.py +1 -1
  82. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  83. sglang/test/runners.py +2 -2
  84. sglang/test/test_utils.py +3 -3
  85. sglang/version.py +1 -1
  86. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  87. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
  88. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  89. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  90. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  91. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  92. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
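Among the larger changes in this release are the hunks reproduced below from sglang/srt/mem_cache/memory_pool.py and sglang/srt/mem_cache/memory_pool_host.py: host-to-device KV transfer logic moves out of the device-side KVCache classes and into the HostKVCache classes, and the host pools gain a `layout` argument that selects between a "layer_first" and a "page_first" buffer arrangement. As a rough sketch of what the two layouts mean for the MHA host buffer (the shapes follow the init_kv_buffer hunk further below; the concrete sizes here are invented for illustration):

import torch

# Invented sizes for illustration; real values come from the model and server args.
layer_num, size, head_num, head_dim = 4, 1024, 8, 128

# layer_first: all tokens of one layer are contiguous.
layer_first = torch.empty((2, layer_num, size, head_num, head_dim), dtype=torch.bfloat16)

# page_first: all layers of one token are adjacent, so one page spans a
# contiguous region across all layers.
page_first = torch.empty((2, size, layer_num, head_num, head_dim), dtype=torch.bfloat16)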
sglang/srt/mem_cache/memory_pool.py
@@ -31,21 +31,17 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
-import torch.distributed as dist
 import triton
 import triton.language as tl
 
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
-_is_npu = is_npu()
-if not _is_npu:
-    from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
 
 
 class ReqToTokenPool:
@@ -153,18 +149,6 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()
 
-    @abc.abstractmethod
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError()
-
     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter
 
@@ -253,12 +237,18 @@ class MHATokenToKVPool(KVCache):
             )
             for _ in range(self.layer_num)
         ]
-        self.token_stride = self.head_num * self.head_dim
-        self.data_ptrs = torch.tensor(
-            [x.data_ptr() for x in self.k_buffer + self.v_buffer],
+
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_buffer],
             dtype=torch.uint64,
             device=self.device,
         )
+        self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
        self.data_strides = torch.tensor(
             [
                 np.prod(x.shape[1:]) * x.dtype.itemsize
@@ -347,47 +337,6 @@ class MHATokenToKVPool(KVCache):
             self.v_buffer[layer_id][chunk_indices] = v_chunk
         torch.cuda.synchronize()
 
-    def load_from_host_per_layer(
-        self,
-        host_pool,
-        host_indices,
-        device_indices,
-        layer_id,
-        io_backend,
-    ):
-        transfer_kv_per_layer(
-            src_k=host_pool.k_buffer[layer_id],
-            dst_k=self.k_buffer[layer_id],
-            src_v=host_pool.v_buffer[layer_id],
-            dst_v=self.v_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.k_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer(
-                src_k=self.k_buffer[layer_id],
-                dst_k=host_pool.k_buffer[layer_id],
-                src_v=self.v_buffer[layer_id],
-                dst_v=host_pool.v_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
     def _get_key_buffer(self, layer_id: int):
         # for internal use of referencing
         if self.store_dtype != self.dtype:
@@ -602,16 +551,6 @@ class SWAKVPool(KVCache):
             layer_id_override=layer_id_pool,
         )
 
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-
 
 class AscendTokenToKVPool(MHATokenToKVPool):
 
@@ -823,7 +762,11 @@ class MLATokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]
 
-        self.token_stride = kv_lora_rank + qk_rope_head_dim
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.kv_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
         self.layer_transfer_counter = None
 
         kv_size = self.get_kv_size_bytes()
@@ -909,38 +852,6 @@ class MLATokenToKVPool(KVCache):
             self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
         )
 
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        transfer_kv_per_layer_mla(
-            src=host_pool.kv_buffer[layer_id],
-            dst=self.kv_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer_mla(
-                src=self.kv_buffer[layer_id],
-                dst=host_pool.kv_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
@@ -1131,20 +1042,6 @@ class DoubleSparseTokenToKVPool(KVCache):
             self.v_buffer[layer_id - self.start_layer][loc] = cache_v
             self.label_buffer[layer_id - self.start_layer][loc] = cache_label
 
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-
 
 @triton.jit
 def copy_all_layer_kv_cache(
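The memory_pool.py hunks above also replace the single flat data_ptrs table in MHATokenToKVPool with separate k_data_ptrs / v_data_ptrs tensors (the concatenated data_ptrs is kept), and MLATokenToKVPool gains a data_ptrs table of its own; the new all-layer transfer kernels used in memory_pool_host.py index per-layer base addresses from such tables. A minimal, self-contained sketch of the pointer-table idea, using plain torch and no sgl_kernel calls:

import torch

# Toy per-layer K buffers standing in for MHATokenToKVPool.k_buffer.
layer_num, tokens, head_num, head_dim = 3, 16, 2, 4
k_buffer = [torch.randn(tokens, head_num, head_dim) for _ in range(layer_num)]

# uint64 table of per-layer base addresses, as built in the hunks above.
k_data_ptrs = torch.tensor([x.data_ptr() for x in k_buffer], dtype=torch.uint64)

# A batched kernel receives this table and resolves layer i's buffer from it;
# plain Python would simply index the list instead.
assert k_data_ptrs[2].item() == k_buffer[2].data_ptr()

The remaining hunks below are from sglang/srt/mem_cache/memory_pool_host.py, where the load/backup logic now lives.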
sglang/srt/mem_cache/memory_pool_host.py
@@ -8,6 +8,21 @@ import psutil
 import torch
 
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if not _is_npu:
+    from sgl_kernel.kvcacheio import (
+        transfer_kv_all_layer,
+        transfer_kv_all_layer_lf_pf,
+        transfer_kv_all_layer_mla,
+        transfer_kv_all_layer_mla_lf_pf,
+        transfer_kv_direct,
+        transfer_kv_per_layer,
+        transfer_kv_per_layer_mla,
+        transfer_kv_per_layer_mla_pf_lf,
+        transfer_kv_per_layer_pf_lf,
+    )
 
 logger = logging.getLogger(__name__)
 
@@ -42,15 +57,18 @@ class HostKVCache(abc.ABC):
         device_pool: KVCache,
         host_to_device_ratio: float,
         host_size: int,
+        page_size: int,
+        layout: str,
         pin_memory: bool,
         device: str,
-        page_size: int,
     ):
         self.device_pool = device_pool
-        self.dtype = device_pool.store_dtype
+        self.page_size = page_size
+        self.layout = layout
         self.pin_memory = pin_memory
         self.device = device
-        self.page_size = page_size
+
+        self.dtype = device_pool.store_dtype
         self.size_per_token = self.get_size_per_token()
         if host_size > 0:
             self.size = int(host_size * 1e9 // self.size_per_token)
@@ -98,6 +116,24 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ) -> None:
+        """
+        Load KV data from the host memory pool to the device memory pool for a specific layer.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ) -> None:
+        """
+        Backup KV data from the device memory pool to the host memory pool for all layers.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def get_flat_data_page(self, index) -> torch.Tensor:
         """
@@ -105,6 +141,14 @@ class HostKVCache(abc.ABC):
         """
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        """
+        Get a dummy flat data page from the host memory pool.
+        This is used for prefetching or initializing empty pages.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
         """
@@ -230,11 +274,30 @@ class MHATokenToKVPoolHost(HostKVCache):
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)]
+        self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)]
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
+        )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )
 
     def get_size_per_token(self):
@@ -245,25 +308,156 @@ class MHATokenToKVPoolHost(HostKVCache):
         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
 
     def init_kv_buffer(self):
+        if self.layout == "layer_first":
+            dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
+        elif self.layout == "page_first":
+            dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
         return torch.empty(
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dims,
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.pin_memory,
         )
 
-    # todo, page first memory layout
+    @property
+    def k_buffer(self):
+        return self.kv_buffer[0]
+
+    @property
+    def v_buffer(self):
+        return self.kv_buffer[1]
+
+    def load_to_device_per_layer(
+        self,
+        device_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer(
+                    src_k=self.k_buffer[layer_id],
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer[layer_id],
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_pf_lf(
+                    src_k=self.k_buffer,
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer,
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
+                dst_layers=[
+                    device_pool.k_buffer[layer_id],
+                    device_pool.v_buffer[layer_id],
+                ],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k_layers=self.k_data_ptrs,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v_layers=self.v_data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_lf_pf(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k=self.k_buffer,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v=self.v_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.k_buffer + device_pool.v_buffer,
+                dst_layers=self.k_data_refs + self.v_data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
     def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (2, self.layer_num, self.page_size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()
 
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
-            2,
-            self.layer_num,
-            self.page_size,
-            self.head_num,
-            self.head_dim,
-        )
+        if self.layout == "layer_first":
+            self.kv_buffer[:, :, index : index + self.page_size, :, :] = (
+                data_page.reshape(
+                    2,
+                    self.layer_num,
+                    self.page_size,
+                    self.head_num,
+                    self.head_dim,
+                )
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :, :] = (
+                data_page.reshape(
+                    2, self.page_size, self.layer_num, self.head_num, self.head_dim
+                )
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
 
     def get_buffer_meta(self, keys, indices):
         ptr_list = []
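In the MHATokenToKVPoolHost hunk above, get_flat_data_page and set_from_flat_data_page differ between the layouts only in which axis the page is sliced from and how the flat page is reshaped back. A small self-contained illustration of that slicing convention with toy shapes (index 0 stands in for a page-aligned index; this is not code from the package):

import torch

layer_num, size, head_num, head_dim, page_size = 2, 8, 2, 4, 4

# layer_first host buffer: pages are sliced along the token axis (dim 2).
lf = torch.randn(2, layer_num, size, head_num, head_dim)
page = lf[:, :, 0:page_size].flatten()                                              # get_flat_data_page
lf[:, :, 0:page_size] = page.reshape(2, layer_num, page_size, head_num, head_dim)   # set_from_flat_data_page

# page_first host buffer: pages are sliced along dim 1 and the layer axis follows.
pf = torch.randn(2, size, layer_num, head_num, head_dim)
page = pf[:, 0:page_size].flatten()
pf[:, 0:page_size] = page.reshape(2, page_size, layer_num, head_num, head_dim)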
@@ -302,14 +496,6 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
 
-    @property
-    def k_buffer(self):
-        return self.kv_buffer[0]
-
-    @property
-    def v_buffer(self):
-        return self.kv_buffer[1]
-
 
 class MLATokenToKVPoolHost(HostKVCache):
     device_pool: MLATokenToKVPool
@@ -320,11 +506,24 @@ class MLATokenToKVPoolHost(HostKVCache):
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)]
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )
 
     def get_size_per_token(self):
@@ -340,28 +539,146 @@ class MLATokenToKVPoolHost(HostKVCache):
         )
 
     def init_kv_buffer(self):
-        return torch.empty(
-            (
+        if self.layout == "layer_first":
+            dims = (
                 self.layer_num,
                 self.size,
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
-            ),
+            )
+        elif self.layout == "page_first":
+            dims = (
+                self.size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = (
+            self.kv_lora_rank + self.qk_rope_head_dim
+        ) * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
+
+        return torch.empty(
+            dims,
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.pin_memory,
         )
 
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer_mla(
+                    src=self.kv_buffer[layer_id],
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_mla_pf_lf(
+                    src=self.kv_buffer,
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.kv_buffer[layer_id]],
+                dst_layers=[device_pool.kv_buffer[layer_id]],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer_mla(
+                    src_layers=device_pool.data_ptrs,
+                    dst_layers=self.data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_mla_lf_pf(
+                    src_layers=device_pool.data_ptrs,
+                    dst_k=self.kv_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.kv_buffer,
+                dst_layers=self.data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
     def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()
 
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
-            self.layer_num,
-            self.page_size,
-            1,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-        )
+        if self.layout == "layer_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
+                self.page_size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
 
     def get_buffer_meta(self, keys, indices):
         ptr_list = []
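For reference, the token_stride_size and layout_dim values set in the init_kv_buffer hunks above are plain byte strides: the bytes one token occupies within a single layer, and the bytes one token occupies across all layers (the per-token stride in the page_first layout). A worked example with invented MLA-style dimensions, not taken from any particular model:

import torch

kv_lora_rank, qk_rope_head_dim, layer_num = 512, 64, 61
dtype = torch.bfloat16

token_stride_size = (kv_lora_rank + qk_rope_head_dim) * dtype.itemsize  # (512 + 64) * 2 = 1152 bytes
layout_dim = token_stride_size * layer_num                              # 1152 * 61 = 70272 bytes
print(token_stride_size, layout_dim)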