sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (83)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/bench_one_batch.py +3 -0
  3. sglang/srt/configs/__init__.py +8 -0
  4. sglang/srt/configs/model_config.py +4 -0
  5. sglang/srt/configs/step3_vl.py +172 -0
  6. sglang/srt/conversation.py +23 -0
  7. sglang/srt/disaggregation/decode.py +2 -8
  8. sglang/srt/disaggregation/launch_lb.py +5 -20
  9. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  10. sglang/srt/disaggregation/prefill.py +2 -6
  11. sglang/srt/distributed/parallel_state.py +86 -1
  12. sglang/srt/entrypoints/engine.py +14 -18
  13. sglang/srt/entrypoints/http_server.py +10 -2
  14. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  15. sglang/srt/eplb/expert_distribution.py +5 -0
  16. sglang/srt/eplb/expert_location.py +17 -6
  17. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  18. sglang/srt/eplb/expert_location_updater.py +2 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/step3_detector.py +436 -0
  21. sglang/srt/hf_transformers_utils.py +2 -0
  22. sglang/srt/jinja_template_utils.py +4 -1
  23. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  24. sglang/srt/layers/attention/utils.py +6 -1
  25. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +39 -674
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
  29. sglang/srt/layers/quantization/fp8.py +52 -18
  30. sglang/srt/layers/quantization/unquant.py +0 -8
  31. sglang/srt/layers/quantization/w4afp8.py +1 -0
  32. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  33. sglang/srt/managers/cache_controller.py +165 -67
  34. sglang/srt/managers/data_parallel_controller.py +2 -0
  35. sglang/srt/managers/io_struct.py +0 -2
  36. sglang/srt/managers/scheduler.py +90 -671
  37. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  38. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  39. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  40. sglang/srt/managers/template_manager.py +62 -19
  41. sglang/srt/managers/tokenizer_manager.py +123 -74
  42. sglang/srt/managers/tp_worker.py +4 -0
  43. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  44. sglang/srt/mem_cache/hicache_storage.py +60 -17
  45. sglang/srt/mem_cache/hiradix_cache.py +36 -8
  46. sglang/srt/mem_cache/memory_pool.py +15 -118
  47. sglang/srt/mem_cache/memory_pool_host.py +418 -29
  48. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  49. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  50. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  51. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  52. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  53. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
  54. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  55. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  57. sglang/srt/model_executor/model_runner.py +13 -1
  58. sglang/srt/model_loader/weight_utils.py +2 -0
  59. sglang/srt/models/arcee.py +532 -0
  60. sglang/srt/models/deepseek_v2.py +7 -6
  61. sglang/srt/models/glm4_moe.py +6 -4
  62. sglang/srt/models/granitemoe.py +3 -0
  63. sglang/srt/models/grok.py +3 -0
  64. sglang/srt/models/hunyuan.py +1 -0
  65. sglang/srt/models/llama4.py +3 -0
  66. sglang/srt/models/mixtral.py +3 -0
  67. sglang/srt/models/olmoe.py +3 -0
  68. sglang/srt/models/phimoe.py +1 -0
  69. sglang/srt/models/step3_vl.py +991 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/reasoning_parser.py +2 -1
  73. sglang/srt/server_args.py +49 -18
  74. sglang/srt/speculative/eagle_worker.py +2 -0
  75. sglang/srt/utils.py +1 -0
  76. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  77. sglang/utils.py +0 -11
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
  80. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
  81. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  82. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool_host.py +418 -29

@@ -8,6 +8,21 @@ import psutil
 import torch
 
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if not _is_npu:
+    from sgl_kernel.kvcacheio import (
+        transfer_kv_all_layer,
+        transfer_kv_all_layer_lf_pf,
+        transfer_kv_all_layer_mla,
+        transfer_kv_all_layer_mla_lf_pf,
+        transfer_kv_direct,
+        transfer_kv_per_layer,
+        transfer_kv_per_layer_mla,
+        transfer_kv_per_layer_mla_pf_lf,
+        transfer_kv_per_layer_pf_lf,
+    )
 
 logger = logging.getLogger(__name__)
 
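The guarded import above loads the fused KV-transfer kernels from sgl_kernel.kvcacheio only on non-NPU builds. Semantically, each kernel moves token rows between host and device pools by index; a rough illustration with a hypothetical pure-torch helper (the real kernels operate on raw pointers and byte strides and fuse the gather and scatter into one launch):

    import torch

    def naive_transfer_kv(
        src: torch.Tensor,
        dst: torch.Tensor,
        src_indices: torch.Tensor,
        dst_indices: torch.Tensor,
    ) -> None:
        # Gather the selected token rows from src and scatter them into dst.
        # transfer_kv_per_layer and friends do this in a single fused kernel.
        dst[dst_indices] = src[src_indices].to(dst.device, non_blocking=True)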
@@ -25,7 +40,6 @@ def synchronized(debug_only=False):
         @wraps(func)
         def wrapper(self, *args, **kwargs):
             if (not debug_only) or self.debug:
-                return func(self, *args, **kwargs)
                 with self.lock:
                     return func(self, *args, **kwargs)
             else:
@@ -43,15 +57,18 @@ class HostKVCache(abc.ABC):
         device_pool: KVCache,
         host_to_device_ratio: float,
         host_size: int,
+        page_size: int,
+        layout: str,
         pin_memory: bool,
         device: str,
-        page_size: int,
     ):
         self.device_pool = device_pool
-        self.dtype = device_pool.store_dtype
+        self.page_size = page_size
+        self.layout = layout
         self.pin_memory = pin_memory
         self.device = device
-        self.page_size = page_size
+
+        self.dtype = device_pool.store_dtype
         self.size_per_token = self.get_size_per_token()
         if host_size > 0:
             self.size = int(host_size * 1e9 // self.size_per_token)
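The reordered HostKVCache constructor still sizes the pool from host_size in GB: size = int(host_size * 1e9 // size_per_token). A worked example with assumed model dimensions, using the MHA formula from get_size_per_token below:

    # Assumed shapes for illustration (bf16, K and V across all layers).
    layer_num, head_num, head_dim, itemsize = 32, 8, 128, 2
    size_per_token = head_dim * head_num * layer_num * itemsize * 2  # 131072 bytes
    host_size_gb = 16
    capacity_tokens = int(host_size_gb * 1e9 // size_per_token)      # 122070 tokens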
@@ -99,6 +116,24 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ) -> None:
+        """
+        Load KV data from the host memory pool to the device memory pool for a specific layer.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ) -> None:
+        """
+        Backup KV data from the device memory pool to the host memory pool for all layers.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def get_flat_data_page(self, index) -> torch.Tensor:
         """
@@ -106,6 +141,14 @@ class HostKVCache(abc.ABC):
         """
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        """
+        Get a dummy flat data page from the host memory pool.
+        This is used for prefetching or initializing empty pages.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
         """
@@ -181,6 +224,15 @@ class HostKVCache(abc.ABC):
         )
         self.mem_state[indices] = MemoryStateInt.BACKUP
 
+    @synchronized(debug_only=True)
+    def update_prefetch(self, indices: torch.Tensor):
+        if not self.is_reserved(indices):
+            raise ValueError(
+                f"The host memory slots should be in RESERVED state before turning into BACKUP. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
     @synchronized(debug_only=True)
     def update_synced(self, indices: torch.Tensor):
         self.mem_state[indices] = MemoryStateInt.SYNCED
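The new update_prefetch carries the same guard as update_backup: only RESERVED slots may be flipped to BACKUP. A toy model of that transition rule (list-based, with an assumed member set, for illustration only):

    from enum import IntEnum, auto

    class MemoryStateInt(IntEnum):  # member set assumed for this sketch
        IDLE = auto()
        RESERVED = auto()
        SYNCED = auto()
        BACKUP = auto()

    def update_prefetch(mem_state: list, indices: list) -> None:
        # Mirrors the guard above: RESERVED -> BACKUP, anything else is an error.
        if any(mem_state[i] is not MemoryStateInt.RESERVED for i in indices):
            raise ValueError("host slots must be RESERVED before turning into BACKUP")
        for i in indices:
            mem_state[i] = MemoryStateInt.BACKUP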
@@ -222,11 +274,30 @@ class MHATokenToKVPoolHost(HostKVCache):
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)]
+        self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)]
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
+        )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )
 
     def get_size_per_token(self):
@@ -237,26 +308,21 @@ class MHATokenToKVPoolHost(HostKVCache):
         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
 
     def init_kv_buffer(self):
+        if self.layout == "layer_first":
+            dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
+        elif self.layout == "page_first":
+            dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
         return torch.empty(
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dims,
             dtype=self.dtype,
            device=self.device,
             pin_memory=self.pin_memory,
         )
 
-    # todo, page first memory layout
-    def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
-
-    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
-            2,
-            self.layer_num,
-            self.page_size,
-            self.head_num,
-            self.head_dim,
-        )
-
     @property
     def k_buffer(self):
         return self.kv_buffer[0]
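init_kv_buffer now supports two host layouts: layer_first keeps all tokens of one layer contiguous, while page_first keeps all layers of one token range contiguous, so a whole page can be read or written as a single run. A toy-sized sketch of the shapes and of the byte strides handed to the transfer kernels:

    import torch

    # Toy dimensions for illustration.
    layer_num, size, head_num, head_dim = 4, 16, 2, 8
    dtype = torch.float16

    layer_first = torch.empty(2, layer_num, size, head_num, head_dim, dtype=dtype)
    page_first = torch.empty(2, size, layer_num, head_num, head_dim, dtype=dtype)

    token_stride_size = head_num * head_dim * dtype.itemsize  # 32 bytes per token per layer
    layout_dim = token_stride_size * layer_num                # 128 bytes per token, all layers

    # In page_first, one token's K across all layers is a single contiguous block,
    # which is what the *_pf_lf / *_lf_pf kernel variants exploit.
    assert page_first[0, 0].is_contiguous()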
@@ -265,6 +331,171 @@ class MHATokenToKVPoolHost(HostKVCache):
     def v_buffer(self):
         return self.kv_buffer[1]
 
+    def load_to_device_per_layer(
+        self,
+        device_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer(
+                    src_k=self.k_buffer[layer_id],
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer[layer_id],
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_pf_lf(
+                    src_k=self.k_buffer,
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer,
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
+                dst_layers=[
+                    device_pool.k_buffer[layer_id],
+                    device_pool.v_buffer[layer_id],
+                ],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k_layers=self.k_data_ptrs,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v_layers=self.v_data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_lf_pf(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k=self.k_buffer,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v=self.v_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.k_buffer + device_pool.v_buffer,
+                dst_layers=self.k_data_refs + self.v_data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (2, self.layer_num, self.page_size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        if self.layout == "layer_first":
+            self.kv_buffer[:, :, index : index + self.page_size, :, :] = (
+                data_page.reshape(
+                    2,
+                    self.layer_num,
+                    self.page_size,
+                    self.head_num,
+                    self.head_dim,
+                )
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :, :] = (
+                data_page.reshape(
+                    2, self.page_size, self.layer_num, self.head_num, self.head_dim
+                )
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_buffer_meta(self, keys, indices):
+        ptr_list = []
+        key_list = []
+        kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        v_offset = (
+            self.layer_num
+            * self.size
+            * self.head_num
+            * self.head_dim
+            * self.dtype.itemsize
+        )
+        for index in range(0, len(indices), self.page_size):
+            for layer_id in range(self.layer_num):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                    + layer_id
+                    * self.size
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                )
+                v_ptr = k_ptr + v_offset
+                ptr_list.append(k_ptr)
+                ptr_list.append(v_ptr)
+                key_ = keys[index // self.page_size]
+                key_list.append(f"{key_}_{layer_id}_k")
+                key_list.append(f"{key_}_{layer_id}_v")
+        element_size = (
+            self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
+        )
+        element_size_list = [element_size] * len(key_list)
+        return key_list, ptr_list, element_size_list
+
 
 class MLATokenToKVPoolHost(HostKVCache):
     device_pool: MLATokenToKVPool
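get_buffer_meta turns the pinned MHA buffer into (key, raw pointer, byte size) triples, one K and one V entry per (page, layer); since kv_buffer packs all K planes before all V planes, each V pointer sits a fixed v_offset past its K pointer. A pure-arithmetic sketch with toy sizes and base address 0:

    # Toy sizes; base pointer taken as 0 for readability.
    layer_num, size, head_num, head_dim, itemsize, page_size = 2, 8, 2, 4, 2, 4
    token_bytes = head_num * head_dim * itemsize     # 16
    v_offset = layer_num * size * token_bytes        # 256: all K planes come first

    indices = [0, 1, 2, 3]                           # one page of token slots
    for layer_id in range(layer_num):
        k_ptr = indices[0] * token_bytes + layer_id * size * token_bytes
        v_ptr = k_ptr + v_offset
        print(layer_id, k_ptr, v_ptr)                # (0, 0, 256), (1, 128, 384)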
@@ -275,11 +506,24 @@ class MLATokenToKVPoolHost(HostKVCache):
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)]
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )
 
     def get_size_per_token(self):
@@ -295,25 +539,170 @@ class MLATokenToKVPoolHost(HostKVCache):
         )
 
     def init_kv_buffer(self):
-        return torch.empty(
-            (
+        if self.layout == "layer_first":
+            dims = (
                 self.layer_num,
                 self.size,
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
-            ),
+            )
+        elif self.layout == "page_first":
+            dims = (
+                self.size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = (
+            self.kv_lora_rank + self.qk_rope_head_dim
+        ) * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
+
+        return torch.empty(
+            dims,
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.pin_memory,
         )
 
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer_mla(
+                    src=self.kv_buffer[layer_id],
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_mla_pf_lf(
+                    src=self.kv_buffer,
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.kv_buffer[layer_id]],
+                dst_layers=[device_pool.kv_buffer[layer_id]],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer_mla(
+                    src_layers=device_pool.data_ptrs,
+                    dst_layers=self.data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_mla_lf_pf(
+                    src_layers=device_pool.data_ptrs,
+                    dst_k=self.kv_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.kv_buffer,
+                dst_layers=self.data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
     def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()
 
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
-            self.layer_num,
-            self.page_size,
-            1,
-            self.kv_lora_rank + self.qk_rope_head_dim,
+        if self.layout == "layer_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
+                self.page_size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_buffer_meta(self, keys, indices):
+        ptr_list = []
+        key_list = []
+        kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        for index in range(0, len(indices), self.page_size):
+            for layer_id in range(self.layer_num):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                    + layer_id
+                    * self.size
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                )
+                ptr_list.append(k_ptr)
+                key_ = keys[index // self.page_size]
+                key_list.append(f"{key_}_{layer_id}_k")
+        element_size = (
+            self.dtype.itemsize
+            * self.page_size
+            * (self.kv_lora_rank + self.qk_rope_head_dim)
         )
+        element_size_list = [element_size] * len(key_list)
+        return key_list, ptr_list, element_size_list
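The MLA variant emits a single pointer per (page, layer), since each token stores one compressed vector of width kv_lora_rank + qk_rope_head_dim rather than separate K/V heads. A worked stride example with DeepSeek-style shapes (values assumed here for illustration):

    # Assumed DeepSeek-style MLA shapes, bf16.
    kv_lora_rank, qk_rope_head_dim, itemsize = 512, 64, 2
    token_stride_size = (kv_lora_rank + qk_rope_head_dim) * itemsize  # 1152 bytes
    layer_num = 61
    layout_dim = token_stride_size * layer_num       # 70272 bytes: one token, all layers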