sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
@@ -26,24 +26,15 @@ KVCache actually holds the physical kv cache.
 
 import abc
 import logging
-import threading
-from enum import IntEnum
-from functools import wraps
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
-import psutil
 import torch
 import triton
 import triton.language as tl
 
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import (
-    debug_timing,
-    get_compiler_backend,
-    is_cuda,
-    next_power_of_2,
-)
+from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
@@ -150,15 +141,12 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()
 
-    @abc.abstractmethod
     def get_flat_data(self, indices):
         raise NotImplementedError()
 
-    @abc.abstractmethod
     def transfer(self, indices, flat_data):
         raise NotImplementedError()
 
-    @abc.abstractmethod
     def transfer_per_layer(self, indices, flat_data, layer_id):
         raise NotImplementedError()
 
@@ -191,6 +179,9 @@ class TokenToKVPoolAllocator:
     def available_size(self):
         return len(self.free_slots)
 
+    def debug_print(self) -> str:
+        return ""
+
     def get_kvcache(self):
         return self._kvcache
 
@@ -234,6 +225,12 @@ class TokenToKVPoolAllocator:
         self.is_not_in_free_group = True
         self.free_group = []
 
+    def get_cpu_copy(self, indices):
+        return self._kvcache.get_cpu_copy(indices)
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+
 
 class MHATokenToKVPool(KVCache):
 
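For context, a minimal usage sketch of the new host-copy hooks on the allocator; the `allocator` instance and the `indices` tensor are hypothetical placeholders, not part of this diff:

    import torch

    # Hypothetical objects: `allocator` is an already-constructed TokenToKVPoolAllocator.
    indices = torch.arange(0, 4096, dtype=torch.int64, device="cuda")

    # Snapshot the KV entries of these token slots into host memory
    # (the allocator simply forwards both calls to the underlying KVCache) ...
    kv_cpu = allocator.get_cpu_copy(indices)

    # ... and restore them later into the same slots.
    allocator.load_cpu_copy(kv_cpu, indices)
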
@@ -265,9 +262,11 @@ class MHATokenToKVPool(KVCache):
         self.head_dim = head_dim
         self._create_buffers()
 
+        # used for chunked cpu-offloading
+        self.chunk_size = 8192
         self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = self.device_module.Stream() if is_cuda else None
+        self.alt_stream = self.device_module.Stream() if _is_cuda else None
 
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
@@ -295,6 +294,19 @@ class MHATokenToKVPool(KVCache):
                 for _ in range(self.layer_num)
             ]
 
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_buffer + self.v_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
+        self.data_strides = torch.tensor(
+            [
+                np.prod(x.shape[1:]) * x.dtype.itemsize
+                for x in self.k_buffer + self.v_buffer
+            ],
+            device=self.device,
+        )
+
     def _clear_buffers(self):
         del self.k_buffer
         del self.v_buffer
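data_ptrs and data_strides flatten the per-layer K and V buffers into a table of raw device pointers plus per-token-slot byte sizes, which the Triton kernel added later in this diff indexes by program id. A small worked example of the stride arithmetic, with a hypothetical buffer shape and dtype:

    import numpy as np
    import torch

    # Hypothetical k_buffer slice: [size, head_num, head_dim] = [1024, 8, 128], bfloat16.
    x = torch.zeros((1024, 8, 128), dtype=torch.bfloat16)

    # Bytes occupied by one token slot in this buffer: 8 * 128 * 2 = 2048.
    stride_bytes = np.prod(x.shape[1:]) * x.dtype.itemsize
    print(stride_bytes)  # 2048
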
@@ -315,20 +327,61 @@ class MHATokenToKVPool(KVCache):
         # layer_num x [seq_len, head_num, head_dim]
         # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
-            self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
+            self.get_key_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
         kv_data_lens = [
-            self.get_key_buffer(i).nbytes for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
+            self.get_key_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
         kv_item_lens = [
             self.get_key_buffer(i)[0].nbytes * self.page_size
-            for i in range(self.layer_num)
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
             self.get_value_buffer(i)[0].nbytes * self.page_size
-            for i in range(self.layer_num)
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), self.chunk_size):
+                chunk_indices = indices[i : i + self.chunk_size]
+                k_cpu = self.k_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                v_cpu = self.v_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append([k_cpu, v_cpu])
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), self.chunk_size):
+                chunk_indices = indices[i : i + self.chunk_size]
+                k_cpu, v_cpu = (
+                    kv_cache_cpu[layer_id][i // self.chunk_size][0],
+                    kv_cache_cpu[layer_id][i // self.chunk_size][1],
+                )
+                assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
+                k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
+                v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
+                self.k_buffer[layer_id][chunk_indices] = k_chunk
+                self.v_buffer[layer_id][chunk_indices] = v_chunk
+        torch.cuda.synchronize()
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
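get_cpu_copy returns a nested list indexed first by layer and then by chunk, where each entry is a [k_cpu, v_cpu] pair covering up to chunk_size (8192) indices. A brief sketch, assuming a hypothetical MHATokenToKVPool instance named `pool`, that reassembles the K tensor of one layer from the chunked copy:

    import torch

    # Hypothetical objects: `pool` is an MHATokenToKVPool; `indices` is a 1-D index tensor.
    indices = torch.arange(0, 20000, dtype=torch.int64, device="cuda")
    kv_cache_cpu = pool.get_cpu_copy(indices)

    # kv_cache_cpu[layer_id][chunk_id] == [k_cpu, v_cpu]; each chunk covers up to 8192 indices.
    k_layer0 = torch.cat([pair[0] for pair in kv_cache_cpu[0]], dim=0)
    assert k_layer0.shape[0] == len(indices)
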
@@ -411,35 +464,15 @@ class MHATokenToKVPool(KVCache):
             self.k_buffer[layer_id - self.start_layer][loc] = cache_k
             self.v_buffer[layer_id - self.start_layer][loc] = cache_v
 
-
-@torch.compile
-def fused_downcast(
-    cache_k: torch.Tensor,
-    cache_v: torch.Tensor,
-    k_scale: torch.Tensor,
-    v_scale: torch.Tensor,
-    dtype: torch.dtype,
-    store_dtype: torch.dtype,
-    max_fp8: float,
-    min_fp8: float,
-):
-    cache_k = cache_k / k_scale
-    cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
-    cache_v = cache_v / v_scale
-    cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
-    cache_k = cache_k.to(dtype)
-    cache_v = cache_v.to(dtype)
-    cache_k = cache_k.view(store_dtype)
-    cache_v = cache_v.view(store_dtype)
-    return cache_k, cache_v
-
-
-# This compiled version is slower in the unit test
-# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
-    dst_1[loc] = src_1.to(dtype).view(store_dtype)
-    dst_2[loc] = src_2.to(dtype).view(store_dtype)
+    def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
+        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
+            self.data_ptrs,
+            self.data_strides,
+            tgt_loc,
+            src_loc,
+            len(tgt_loc),
+            next_power_of_2(len(tgt_loc)),
+        )
 
 
 @triton.jit
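move_kv_cache launches one Triton program per entry of data_ptrs, i.e. 2 * layer_num programs, one per K or V buffer. A hedged reference sketch of the same semantics in plain PyTorch, with `pool`, `tgt_loc`, and `src_loc` as hypothetical placeholders:

    import torch

    def move_kv_cache_reference(pool, tgt_loc: torch.Tensor, src_loc: torch.Tensor) -> None:
        # For every layer's K and V buffer, copy the rows at src_loc into the rows at tgt_loc.
        # The fused kernel (copy_all_layer_kv_cache, defined in this diff) does the same
        # thing in a single launch over the raw pointers collected in data_ptrs.
        for buf in pool.k_buffer + pool.v_buffer:
            buf[tgt_loc] = buf[src_loc]
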
@@ -733,368 +766,39 @@ class DoubleSparseTokenToKVPool(KVCache):
         pass
 
 
-class MemoryStateInt(IntEnum):
-    IDLE = 0
-    RESERVED = 1
-    PROTECTED = 2
-    SYNCED = 3
-    BACKUP = 4
-
-
-def synchronized(debug_only=False):
-    def _decorator(func):
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            if (not debug_only) or self.debug:
-                return func(self, *args, **kwargs)
-                with self.lock:
-                    return func(self, *args, **kwargs)
-            else:
-                return True
-
-        return wrapper
-
-    return _decorator
-
-
-class HostKVCache(abc.ABC):
-
-    def __init__(
-        self,
-        device_pool: KVCache,
-        host_to_device_ratio: float,
-        host_size: int,
-        pin_memory: bool,
-        device: str,
-        page_size: int,
-    ):
-        self.device_pool = device_pool
-        self.dtype = device_pool.store_dtype
-        self.pin_memory = pin_memory
-        self.device = device
-        self.page_size = page_size
-        self.size_per_token = self.get_size_per_token()
-        if host_size > 0:
-            self.size = int(host_size * 1e9 // self.size_per_token)
-        else:
-            self.size = int(device_pool.size * host_to_device_ratio)
-        # Align the host memory pool size to the page size
-        self.size = self.size - (self.size % self.page_size)
-        self.start_layer = device_pool.start_layer
-        self.end_layer = device_pool.end_layer
-
-        assert (
-            self.size > device_pool.size
-        ), "The host memory should be larger than the device memory with the current protocol"
-
-        # Verify there is enough available host memory.
-        host_mem = psutil.virtual_memory()
-        requested_bytes = self.size * self.size_per_token
-        # preserve at least 10GB for other usage
-        ten_gb = 10 * (1024**3)
-        if requested_bytes > host_mem.available - ten_gb:
-            raise ValueError(
-                f"Not enough host memory available. Requesting "
-                f"{requested_bytes / 1e9:.2f} GB but only have "
-                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
-                f"size of the hierarchical cache."
-            )
-        else:
-            logger.info(
-                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
-            )
-
-        self.kv_buffer = self.init_kv_buffer()
-
-        # A lock for synchronized operations on memory allocation and state transitions.
-        self.lock = threading.RLock()
-        self.debug = logger.isEnabledFor(logging.DEBUG)
-        self.clear()
-
-    @abc.abstractmethod
-    def get_size_per_token(self):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def init_kv_buffer(self):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def transfer(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data(self, indices):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data_by_layer(self, indices, layer_id):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def assign_flat_data(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @synchronized()
-    def clear(self):
-        # Initialize memory states and tracking structures.
-        self.mem_state = torch.zeros(
-            (self.size,), dtype=torch.uint8, device=self.device
-        )
-        self.free_slots = torch.arange(self.size, dtype=torch.int64)
-
-    def available_size(self):
-        return len(self.free_slots)
-
-    @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
-        if need_size > self.available_size():
-            return None
-
-        select_index = self.free_slots[:need_size]
-        self.free_slots = self.free_slots[need_size:]
-
-        if self.debug:
-            self.mem_state[select_index] = MemoryStateInt.RESERVED
-
-        return select_index
-
-    @synchronized()
-    def free(self, indices: torch.Tensor) -> int:
-        self.free_slots = torch.cat([self.free_slots, indices])
-        if self.debug:
-            self.mem_state[indices] = MemoryStateInt.IDLE
-        return len(indices)
-
-    @synchronized(debug_only=True)
-    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
-
-    @synchronized(debug_only=True)
-    def is_reserved(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.RESERVED
-
-    @synchronized(debug_only=True)
-    def is_protected(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def is_synced(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def is_backup(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_backup(self, indices: torch.Tensor):
-        if not self.is_synced(indices):
-            raise ValueError(
-                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_synced(self, indices: torch.Tensor):
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def protect_write(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be RESERVED before write operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def protect_load(self, indices: torch.Tensor):
-        if not self.is_backup(indices):
-            raise ValueError(
-                f"The host memory slots should be in BACKUP state before load operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def complete_io(self, indices: torch.Tensor):
-        if not self.is_protected(indices):
-            raise ValueError(
-                f"The host memory slots should be PROTECTED during I/O operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-
-class MHATokenToKVPoolHost(HostKVCache):
-    device_pool: MHATokenToKVPool
-
-    def __init__(
-        self,
-        device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float,
-        host_size: int,
-        page_size: int,
-        pin_memory: bool = True,
-        device: str = "cpu",
-    ):
-        super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
-        )
-
-    def get_size_per_token(self):
-        self.head_num = self.device_pool.head_num
-        self.head_dim = self.device_pool.head_dim
-        self.layer_num = self.device_pool.layer_num
-
-        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
-
-    def init_kv_buffer(self):
-        return torch.empty(
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
-            dtype=self.dtype,
-            device=self.device,
-            pin_memory=self.pin_memory,
-        )
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, :, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, :, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, :, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
-            device_pool.v_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
-
+@triton.jit
+def copy_all_layer_kv_cache(
+    data_ptrs,
+    strides,
+    tgt_loc_ptr,
+    src_loc_ptr,
+    num_locs,
+    num_locs_upper: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 128
 
-class MLATokenToKVPoolHost(HostKVCache):
-    device_pool: MLATokenToKVPool
+    bid = tl.program_id(0)
+    stride = tl.load(strides + bid)
 
-    def __init__(
-        self,
-        device_pool: MLATokenToKVPool,
-        host_to_device_ratio: float,
-        host_size: int,
-        page_size: int,
-        pin_memory: bool = True,
-        device: str = "cpu",
-    ):
-        super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
-        )
+    data_ptr = tl.load(data_ptrs + bid)
+    data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
 
-    def get_size_per_token(self):
-        self.kv_lora_rank = self.device_pool.kv_lora_rank
-        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
-        self.layer_num = self.device_pool.layer_num
+    num_locs_offset = tl.arange(0, num_locs_upper)
+    tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
 
-        return (
-            (self.kv_lora_rank + self.qk_rope_head_dim)
-            * 1
-            * self.dtype.itemsize
-            * self.layer_num
-        )
+    # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
+    # because this copy is an inplace operation.
 
-    def init_kv_buffer(self):
-        return torch.empty(
-            (
-                self.layer_num,
-                self.size,
-                1,
-                self.kv_lora_rank + self.qk_rope_head_dim,
-            ),
-            dtype=self.dtype,
-            device=self.device,
-            pin_memory=self.pin_memory,
+    num_loop = tl.cdiv(stride, BLOCK_SIZE)
+    for i in range(num_loop):
+        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :]
+        value = tl.load(
+            data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
         )
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, indices] = flat_data.to(
-            device=self.device, non_blocking=False
+        tl.store(
+            data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
+            value,
+            mask=mask,
         )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
-                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.kv_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
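For reference, each program instance of copy_all_layer_kv_cache handles one buffer from data_ptrs: it reads that buffer's per-slot stride in bytes and copies every source slot to its target slot in 128-byte tiles. A rough byte-level sketch of what one program does, over a hypothetical flat uint8 view of a single buffer:

    import torch

    def one_program(buffer_u8: torch.Tensor, stride: int,
                    tgt_loc: torch.Tensor, src_loc: torch.Tensor) -> None:
        # buffer_u8: flat uint8 view of one K or V buffer; stride: bytes per token slot.
        # Sources are read before targets are written; the kernel above does this
        # tile by tile, 128 bytes at a time, within a single program.
        values = [buffer_u8[s * stride : (s + 1) * stride].clone() for s in src_loc.tolist()]
        for t, v in zip(tgt_loc.tolist(), values):
            buffer_u8[t * stride : (t + 1) * stride] = v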