sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
@@ -27,16 +27,18 @@ KVCache actually holds the physical kv cache.
  import abc
  import logging
  from contextlib import nullcontext
- from typing import List, Optional, Tuple, Union
+ from typing import Dict, List, Optional, Tuple, Union

  import numpy as np
  import torch
+ import torch.distributed as dist
  import triton
  import triton.language as tl
+ from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla

  from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
+ from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2

  logger = logging.getLogger(__name__)

@@ -66,6 +68,7 @@ class ReqToTokenPool:
  self.req_to_token = torch.zeros(
  (size, max_context_len), dtype=torch.int32, device=device
  )
+
  self.free_slots = list(range(size))

  def write(self, indices, values):
@@ -121,6 +124,7 @@ class KVCache(abc.ABC):
  self.memory_saver_adapter = TorchMemorySaverAdapter.create(
  enable=enable_memory_saver
  )
+ self.mem_usage = 0

  # used for chunked cpu-offloading
  self.cpu_offloading_chunk_size = 8192
@@ -147,13 +151,16 @@ class KVCache(abc.ABC):
  ) -> None:
  raise NotImplementedError()

- def get_flat_data(self, indices):
- raise NotImplementedError()
-
- def transfer(self, indices, flat_data):
+ @abc.abstractmethod
+ def load_from_host_per_layer(
+ self, host_pool, host_indices, device_indices, layer_id, io_backend
+ ):
  raise NotImplementedError()

- def transfer_per_layer(self, indices, flat_data, layer_id):
+ @abc.abstractmethod
+ def backup_to_host_all_layer(
+ self, host_pool, host_indices, device_indices, io_backend
+ ):
  raise NotImplementedError()

  def register_layer_transfer_counter(self, layer_transfer_counter):
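
The two @abc.abstractmethod declarations above replace the old get_flat_data / transfer / transfer_per_layer trio: every device pool now exposes a per-layer load from a host pool and an all-layer backup to it. Below is a minimal sketch of that interface shape on a toy pool pair using plain tensor indexing; ToyHostPool, ToyDevicePool, and the "direct" io_backend string are hypothetical illustrations, not sglang classes.

    import torch

    class ToyHostPool:
        """One K-like buffer per layer, kept in host memory."""

        def __init__(self, layer_num, tokens, dim):
            self.buf = [torch.zeros(tokens, dim) for _ in range(layer_num)]

    class ToyDevicePool:
        """Device-side counterpart implementing the two new hooks."""

        def __init__(self, layer_num, tokens, dim, device="cpu"):
            self.layer_num = layer_num
            self.buf = [torch.zeros(tokens, dim, device=device) for _ in range(layer_num)]

        def load_from_host_per_layer(
            self, host_pool, host_indices, device_indices, layer_id, io_backend
        ):
            # pull the selected host rows of one layer into the chosen device slots
            rows = host_pool.buf[layer_id][host_indices]
            self.buf[layer_id][device_indices] = rows.to(self.buf[layer_id].device)

        def backup_to_host_all_layer(
            self, host_pool, host_indices, device_indices, io_backend
        ):
            # push every layer's selected device rows back into the host pool
            for layer_id in range(self.layer_num):
                rows = self.buf[layer_id][device_indices]
                host_pool.buf[layer_id][host_indices] = rows.cpu()

    host = ToyHostPool(layer_num=2, tokens=16, dim=8)
    dev = ToyDevicePool(layer_num=2, tokens=8, dim=8)
    dev.load_from_host_per_layer(host, torch.tensor([0, 1]), torch.tensor([4, 5]), 0, "direct")
    dev.backup_to_host_all_layer(host, torch.tensor([0, 1]), torch.tensor([4, 5]), "direct")

Pushing the host/device copy into the pool itself (instead of handing flat tensors back and forth) is what lets the concrete pools below call the fused sgl_kernel transfer kernels directly.
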
@@ -191,7 +198,6 @@ class MHATokenToKVPool(KVCache):
  start_layer,
  end_layer,
  )
-
  self.head_num = head_num
  self.head_dim = head_dim

@@ -218,6 +224,7 @@ class MHATokenToKVPool(KVCache):
  logger.info(
  f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
  )
+ self.mem_usage = (k_size + v_size) / GB

  def _create_buffers(self):
  with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
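
The new mem_usage field records, in GB, the same quantity the log line above prints. As a rough back-of-the-envelope check (assuming, for illustration only, that each of the layer_num K buffers holds about `size` tokens of head_num x head_dim elements in the store dtype; the exact buffer shapes are defined elsewhere in this file):

    import torch

    GB = 1024**3

    # hypothetical Llama-style configuration with an fp16 KV cache
    size, head_num, head_dim, layer_num = 100_000, 8, 128, 32
    elem_bytes = torch.finfo(torch.float16).bits // 8  # 2 bytes per element

    k_size = size * head_num * head_dim * layer_num * elem_bytes
    v_size = k_size  # the V buffers mirror the K buffers
    mem_usage = (k_size + v_size) / GB  # the value MHATokenToKVPool now records
    print(f"approx. KV cache: {mem_usage:.2f} GB")  # ~12.21 GB
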
@@ -244,7 +251,7 @@ class MHATokenToKVPool(KVCache):
  )
  for _ in range(self.layer_num)
  ]
-
+ self.token_stride = self.head_num * self.head_dim
  self.data_ptrs = torch.tensor(
  [x.data_ptr() for x in self.k_buffer + self.v_buffer],
  dtype=torch.uint64,
@@ -278,24 +285,24 @@ class MHATokenToKVPool(KVCache):
  # layer_num x [seq_len, head_num, head_dim]
  # layer_num x [page_num, page_size, head_num, head_dim]
  kv_data_ptrs = [
- self.get_key_buffer(i).data_ptr()
+ self._get_key_buffer(i).data_ptr()
  for i in range(self.start_layer, self.start_layer + self.layer_num)
  ] + [
- self.get_value_buffer(i).data_ptr()
+ self._get_value_buffer(i).data_ptr()
  for i in range(self.start_layer, self.start_layer + self.layer_num)
  ]
  kv_data_lens = [
- self.get_key_buffer(i).nbytes
+ self._get_key_buffer(i).nbytes
  for i in range(self.start_layer, self.start_layer + self.layer_num)
  ] + [
- self.get_value_buffer(i).nbytes
+ self._get_value_buffer(i).nbytes
  for i in range(self.start_layer, self.start_layer + self.layer_num)
  ]
  kv_item_lens = [
- self.get_key_buffer(i)[0].nbytes * self.page_size
+ self._get_key_buffer(i)[0].nbytes * self.page_size
  for i in range(self.start_layer, self.start_layer + self.layer_num)
  ] + [
- self.get_value_buffer(i)[0].nbytes * self.page_size
+ self._get_value_buffer(i)[0].nbytes * self.page_size
  for i in range(self.start_layer, self.start_layer + self.layer_num)
  ]
  return kv_data_ptrs, kv_data_lens, kv_item_lens
@@ -338,49 +345,73 @@ class MHATokenToKVPool(KVCache):
  self.v_buffer[layer_id][chunk_indices] = v_chunk
  torch.cuda.synchronize()

- # Todo: different memory layout
- def get_flat_data(self, indices):
- # prepare a large chunk of contiguous data for efficient transfer
- flatten = torch.stack(
- [
- torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
- torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
- ]
+ def load_from_host_per_layer(
+ self,
+ host_pool,
+ host_indices,
+ device_indices,
+ layer_id,
+ io_backend,
+ ):
+ transfer_kv_per_layer(
+ src_k=host_pool.k_buffer[layer_id],
+ dst_k=self.k_buffer[layer_id],
+ src_v=host_pool.v_buffer[layer_id],
+ dst_v=self.v_buffer[layer_id],
+ src_indices=host_indices,
+ dst_indices=device_indices,
+ io_backend=io_backend,
+ page_size=self.page_size,
+ item_size=self.token_stride,
  )
- return flatten
-
- @debug_timing
- def transfer(self, indices, flat_data):
- # transfer prepared data from host to device
- flat_data = flat_data.to(device=self.device, non_blocking=False)
- k_data, v_data = flat_data[0], flat_data[1]
- for i in range(self.layer_num):
- self.k_buffer[i][indices] = k_data[i]
- self.v_buffer[i][indices] = v_data[i]
-
- def transfer_per_layer(self, indices, flat_data, layer_id):
- # transfer prepared data from host to device
- flat_data = flat_data.to(device=self.device, non_blocking=False)
- k_data, v_data = flat_data[0], flat_data[1]
- self.k_buffer[layer_id - self.start_layer][indices] = k_data
- self.v_buffer[layer_id - self.start_layer][indices] = v_data

- def get_key_buffer(self, layer_id: int):
- if self.layer_transfer_counter is not None:
- self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+ def backup_to_host_all_layer(
+ self, host_pool, host_indices, device_indices, io_backend
+ ):
+ # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+ for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+ if layer_id - self.start_layer >= len(host_pool.k_buffer):
+ raise ValueError(
+ f"Layer ID {layer_id} exceeds the number of layers in host pool."
+ )
+ transfer_kv_per_layer(
+ src_k=self.k_buffer[layer_id],
+ dst_k=host_pool.k_buffer[layer_id],
+ src_v=self.v_buffer[layer_id],
+ dst_v=host_pool.v_buffer[layer_id],
+ src_indices=device_indices,
+ dst_indices=host_indices,
+ io_backend=io_backend,
+ page_size=self.page_size,
+ item_size=self.token_stride,
+ )

+ def _get_key_buffer(self, layer_id: int):
+ # for internal use of referencing
  if self.store_dtype != self.dtype:
  return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
  return self.k_buffer[layer_id - self.start_layer]

- def get_value_buffer(self, layer_id: int):
+ def get_key_buffer(self, layer_id: int):
+ # note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading
+ # it is supposed to be used only by attention backend not for information purpose
+ # same applies to get_value_buffer and get_kv_buffer
  if self.layer_transfer_counter is not None:
  self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

+ return self._get_key_buffer(layer_id)
+
+ def _get_value_buffer(self, layer_id: int):
+ # for internal use of referencing
  if self.store_dtype != self.dtype:
  return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
  return self.v_buffer[layer_id - self.start_layer]

+ def get_value_buffer(self, layer_id: int):
+ if self.layer_transfer_counter is not None:
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+ return self._get_value_buffer(layer_id)
+
  def get_kv_buffer(self, layer_id: int):
  return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

@@ -392,10 +423,14 @@ class MHATokenToKVPool(KVCache):
  cache_v: torch.Tensor,
  k_scale: Optional[float] = None,
  v_scale: Optional[float] = None,
+ layer_id_override: Optional[int] = None,
  ):
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

- layer_id = layer.layer_id
+ if layer_id_override is not None:
+ layer_id = layer_id_override
+ else:
+ layer_id = layer.layer_id
  if cache_k.dtype != self.dtype:
  if k_scale is not None:
  cache_k.div_(k_scale)
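
Two of the MHATokenToKVPool changes above are about coordination rather than layout: get_key_buffer/get_value_buffer may now block on layer_transfer_counter.wait_until(...) while the underscore-prefixed variants never do, and set_kv_buffer accepts a layer_id_override so a wrapper pool can write with a pool-local layer id. Below is a minimal sketch of the wait_until contract only, using a hypothetical condition-variable counter; the real counter object is registered from outside this file via register_layer_transfer_counter and may behave differently.

    import threading

    class ToyLayerDoneCounter:
        """Hypothetical stand-in for the counter the public getters wait on."""

        def __init__(self):
            self._done = -1
            self._cond = threading.Condition()

        def mark_done(self, layer_id):
            with self._cond:
                self._done = layer_id
                self._cond.notify_all()

        def wait_until(self, layer_id):
            with self._cond:
                self._cond.wait_for(lambda: self._done >= layer_id)

    counter = ToyLayerDoneCounter()

    def loader():
        for layer_id in range(4):
            # ... load_from_host_per_layer(...) for this layer would run here ...
            counter.mark_done(layer_id)

    threading.Thread(target=loader).start()
    counter.wait_until(2)  # a reader asking for layer 2 blocks until that layer is loaded

This is why the new comment warns that the public getters are meant for the attention backend only: any "informational" caller that grabs them can end up blocked behind layer-wise loading.
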
@@ -431,6 +466,206 @@ class MHATokenToKVPool(KVCache):
  )


+ class SWAKVPool(KVCache):
+ """KV cache with separate pools for full and SWA attention layers."""
+
+ def __init__(
+ self,
+ size: int,
+ size_swa: int,
+ dtype: torch.dtype,
+ head_num: int,
+ head_dim: int,
+ swa_attention_layer_ids: List[int],
+ full_attention_layer_ids: List[int],
+ enable_kvcache_transpose: bool,
+ device: str,
+ ):
+ self.size = size
+ self.size_swa = size_swa
+ self.dtype = dtype
+ self.device = device
+ self.swa_layer_nums = len(swa_attention_layer_ids)
+ self.full_layer_nums = len(full_attention_layer_ids)
+ self.page_size = 1
+ # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
+ assert not enable_kvcache_transpose
+ TokenToKVPoolClass = MHATokenToKVPool
+ self.swa_kv_pool = TokenToKVPoolClass(
+ size=size_swa,
+ page_size=self.page_size,
+ dtype=dtype,
+ head_num=head_num,
+ head_dim=head_dim,
+ layer_num=self.swa_layer_nums,
+ device=device,
+ enable_memory_saver=False,
+ )
+ self.full_kv_pool = TokenToKVPoolClass(
+ size=size,
+ page_size=self.page_size,
+ dtype=dtype,
+ head_num=head_num,
+ head_dim=head_dim,
+ layer_num=self.full_layer_nums,
+ device=device,
+ enable_memory_saver=False,
+ )
+ self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
+ for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
+ self.layers_mapping[global_layer_id] = (full_attn_layer_id, False)
+ for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids):
+ self.layers_mapping[global_layer_id] = (swa_layer_id, True)
+ self.full_to_swa_index_mapping: Optional[torch.Tensor] = None
+
+ def get_kv_size_bytes(self):
+ raise NotImplementedError
+
+ def get_contiguous_buf_infos(self):
+ full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
+ self.full_kv_pool.get_contiguous_buf_infos()
+ )
+ swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
+ self.swa_kv_pool.get_contiguous_buf_infos()
+ )
+
+ kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
+ kv_data_lens = full_kv_data_lens + swa_kv_data_lens
+ kv_item_lens = full_kv_item_lens + swa_kv_item_lens
+
+ return kv_data_ptrs, kv_data_lens, kv_item_lens
+
+ def get_key_buffer(self, layer_id: int):
+ layer_id_pool, is_swa = self.layers_mapping[layer_id]
+ if is_swa:
+ return self.swa_kv_pool.get_key_buffer(layer_id_pool)
+ else:
+ return self.full_kv_pool.get_key_buffer(layer_id_pool)
+
+ def get_value_buffer(self, layer_id: int):
+ layer_id_pool, is_swa = self.layers_mapping[layer_id]
+ if is_swa:
+ return self.swa_kv_pool.get_value_buffer(layer_id_pool)
+ else:
+ return self.full_kv_pool.get_value_buffer(layer_id_pool)
+
+ def get_kv_buffer(self, layer_id: int):
+ layer_id_pool, is_swa = self.layers_mapping[layer_id]
+ if is_swa:
+ return self.swa_kv_pool.get_kv_buffer(layer_id_pool)
+ else:
+ return self.full_kv_pool.get_kv_buffer(layer_id_pool)
+
+ def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
+ assert self.full_to_swa_index_mapping is not None
+ return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)
+
+ def set_kv_buffer(
+ self,
+ layer: RadixAttention,
+ loc: torch.Tensor,
+ cache_k: torch.Tensor,
+ cache_v: torch.Tensor,
+ k_scale: float = 1.0,
+ v_scale: float = 1.0,
+ ):
+
+ layer_id = layer.layer_id
+ layer_id_pool, is_swa = self.layers_mapping[layer_id]
+ if is_swa:
+ if self.full_to_swa_index_mapping is not None:
+ loc = self.translate_loc_from_full_to_swa(loc)
+ self.swa_kv_pool.set_kv_buffer(
+ None,
+ loc,
+ cache_k,
+ cache_v,
+ k_scale,
+ v_scale,
+ layer_id_override=layer_id_pool,
+ )
+ else:
+ self.full_kv_pool.set_kv_buffer(
+ None,
+ loc,
+ cache_k,
+ cache_v,
+ k_scale,
+ v_scale,
+ layer_id_override=layer_id_pool,
+ )
+
+
+ class AscendTokenToKVPool(MHATokenToKVPool):
+
+ def _create_buffers(self):
+ with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+ # [size, head_num, head_dim] for each layer
+ # The padded slot 0 is used for writing dummy outputs from padded tokens.
+ self.k_buffer = [
+ torch.zeros(
+ (
+ self.size // self.page_size + 1,
+ self.page_size,
+ self.head_num,
+ self.head_dim,
+ ),
+ dtype=self.store_dtype,
+ device=self.device,
+ )
+ for _ in range(self.layer_num)
+ ]
+ self.v_buffer = [
+ torch.zeros(
+ (
+ self.size // self.page_size + 1,
+ self.page_size,
+ self.head_num,
+ self.head_dim,
+ ),
+ dtype=self.store_dtype,
+ device=self.device,
+ )
+ for _ in range(self.layer_num)
+ ]
+
+ def set_kv_buffer(
+ self,
+ layer: RadixAttention,
+ loc: torch.Tensor,
+ cache_k: torch.Tensor,
+ cache_v: torch.Tensor,
+ k_scale: Optional[float] = None,
+ v_scale: Optional[float] = None,
+ ):
+ layer_id = layer.layer_id
+ if cache_k.dtype != self.dtype:
+ if k_scale is not None:
+ cache_k.div_(k_scale)
+ if v_scale is not None:
+ cache_v.div_(v_scale)
+ cache_k = cache_k.to(self.dtype)
+ cache_v = cache_v.to(self.dtype)
+
+ if self.store_dtype != self.dtype:
+ cache_k = cache_k.view(self.store_dtype)
+ cache_v = cache_v.view(self.store_dtype)
+
+ import torch_npu
+
+ torch_npu._npu_reshape_and_cache(
+ key=cache_k,
+ value=cache_v,
+ key_cache=self.k_buffer[layer_id].view(
+ -1, self.page_size, self.head_num, self.head_dim
+ ),
+ value_cache=self.v_buffer[layer_id].view(
+ -1, self.page_size, self.head_num, self.head_dim
+ ),
+ slot_indices=loc,
+ )
+
+
  @triton.jit
  def set_mla_kv_buffer_kernel(
  kv_buffer_ptr,
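
SWAKVPool above keeps two independent MHATokenToKVPool instances and routes every global layer id through layers_mapping to a (pool-local id, is_swa) pair; set_kv_buffer then forwards with layer_id_override so each sub-pool indexes with its local id. A small sketch of just that mapping logic, for a hypothetical 6-layer model whose even layers use sliding-window attention:

    from typing import Dict, List, Tuple

    def build_layers_mapping(
        swa_attention_layer_ids: List[int], full_attention_layer_ids: List[int]
    ) -> Dict[int, Tuple[int, bool]]:
        # global layer id -> (pool-local layer id, is_swa), mirroring SWAKVPool.__init__
        mapping: Dict[int, Tuple[int, bool]] = {}
        for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
            mapping[global_layer_id] = (full_attn_layer_id, False)
        for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids):
            mapping[global_layer_id] = (swa_layer_id, True)
        return mapping

    # hypothetical layout: even layers are SWA, odd layers are full attention
    mapping = build_layers_mapping([0, 2, 4], [1, 3, 5])
    assert mapping[4] == (2, True)   # global layer 4 -> slot 2 of the SWA pool
    assert mapping[3] == (1, False)  # global layer 3 -> slot 1 of the full pool

Sizing the SWA pool separately (size_swa vs. size) is what lets sliding-window layers hold far fewer tokens than the full-attention layers without changing the attention backend's view of layer ids.
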
@@ -554,12 +789,14 @@ class MLATokenToKVPool(KVCache):
  for _ in range(layer_num)
  ]

+ self.token_stride = kv_lora_rank + qk_rope_head_dim
  self.layer_transfer_counter = None

  kv_size = self.get_kv_size_bytes()
  logger.info(
  f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
  )
+ self.mem_usage = kv_size / GB

  def get_kv_size_bytes(self):
  assert hasattr(self, "kv_buffer")
@@ -638,21 +875,37 @@ class MLATokenToKVPool(KVCache):
  self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
  )

- def get_flat_data(self, indices):
- # prepare a large chunk of contiguous data for efficient transfer
- return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
-
- @debug_timing
- def transfer(self, indices, flat_data):
- # transfer prepared data from host to device
- flat_data = flat_data.to(device=self.device, non_blocking=False)
- for i in range(self.layer_num):
- self.kv_buffer[i][indices] = flat_data[i]
+ def load_from_host_per_layer(
+ self, host_pool, host_indices, device_indices, layer_id, io_backend
+ ):
+ transfer_kv_per_layer_mla(
+ src=host_pool.kv_buffer[layer_id],
+ dst=self.kv_buffer[layer_id],
+ src_indices=host_indices,
+ dst_indices=device_indices,
+ io_backend=io_backend,
+ page_size=self.page_size,
+ item_size=self.token_stride,
+ )

- def transfer_per_layer(self, indices, flat_data, layer_id):
- # transfer prepared data from host to device
- flat_data = flat_data.to(device=self.device, non_blocking=False)
- self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
+ def backup_to_host_all_layer(
+ self, host_pool, host_indices, device_indices, io_backend
+ ):
+ # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+ for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+ if layer_id - self.start_layer >= len(host_pool.kv_buffer):
+ raise ValueError(
+ f"Layer ID {layer_id} exceeds the number of layers in host pool."
+ )
+ transfer_kv_per_layer_mla(
+ src=self.kv_buffer[layer_id],
+ dst=host_pool.kv_buffer[layer_id],
+ src_indices=device_indices,
+ dst_indices=host_indices,
+ io_backend=io_backend,
+ page_size=self.page_size,
+ item_size=self.token_stride,
+ )

  def get_cpu_copy(self, indices):
  torch.cuda.synchronize()
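
For the MLA pool, the per-token item size handed to the transfer kernels is simply kv_lora_rank + qk_rope_head_dim, since K and V share one compressed buffer per layer. A quick worked example of that stride and of the new mem_usage value, taking kv_lora_rank=512, qk_rope_head_dim=64, fp8 storage (1 byte per element), and 61 layers purely as illustrative numbers, not values read from this diff:

    GB = 1024**3

    kv_lora_rank, qk_rope_head_dim = 512, 64
    token_stride = kv_lora_rank + qk_rope_head_dim  # 576 elements per token per layer

    size, layer_num, elem_bytes = 500_000, 61, 1  # fp8 -> 1 byte per element
    kv_size = size * token_stride * layer_num * elem_bytes
    print(f"approx. MLA KV cache: {kv_size / GB:.2f} GB")  # ~16.36 GB -> the new mem_usage
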
@@ -682,6 +935,84 @@ class MLATokenToKVPool(KVCache):
  torch.cuda.synchronize()


+ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
+ def __init__(
+ self,
+ size: int,
+ page_size: int,
+ dtype: torch.dtype,
+ kv_lora_rank: int,
+ qk_rope_head_dim: int,
+ layer_num: int,
+ device: str,
+ enable_memory_saver: bool,
+ start_layer: Optional[int] = None,
+ end_layer: Optional[int] = None,
+ ):
+ super(MLATokenToKVPool, self).__init__(
+ size,
+ page_size,
+ dtype,
+ layer_num,
+ device,
+ enable_memory_saver,
+ start_layer,
+ end_layer,
+ )
+
+ self.kv_lora_rank = kv_lora_rank
+ self.qk_rope_head_dim = qk_rope_head_dim
+
+ self.custom_mem_pool = None
+
+ with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+ # The padded slot 0 is used for writing dummy outputs from padded tokens.
+ self.kv_buffer = [
+ torch.zeros(
+ (
+ self.size // self.page_size + 1,
+ self.page_size,
+ self.kv_lora_rank + self.qk_rope_head_dim,
+ ),
+ dtype=self.store_dtype,
+ device=self.device,
+ )
+ for _ in range(layer_num)
+ ]
+
+ self.layer_transfer_counter = None
+
+ kv_size = self.get_kv_size_bytes()
+ logger.info(
+ f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
+ )
+ self.mem_usage = kv_size / GB
+
+ def set_kv_buffer(
+ self,
+ layer: RadixAttention,
+ loc: torch.Tensor,
+ cache_k: torch.Tensor,
+ cache_v: torch.Tensor,
+ ):
+ layer_id = layer.layer_id
+ if cache_k.dtype != self.dtype:
+ cache_k = cache_k.to(self.dtype)
+
+ if self.store_dtype != self.dtype:
+ cache_k = cache_k.view(store_dtype)
+
+ import torch_npu
+
+ torch_npu._npu_reshape_and_cache_siso(
+ key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
+ key_cache=self.kv_buffer[layer_id - self.start_layer].view(
+ -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
+ ),
+ slot_indices=loc,
+ )
+
+
  class DoubleSparseTokenToKVPool(KVCache):
  def __init__(
  self,
@@ -760,14 +1091,19 @@ class DoubleSparseTokenToKVPool(KVCache):
  self.v_buffer[layer_id - self.start_layer][loc] = cache_v
  self.label_buffer[layer_id - self.start_layer][loc] = cache_label

- def get_flat_data(self, indices):
- pass
-
- def transfer(self, indices, flat_data):
- pass
+ def load_from_host_per_layer(
+ self, host_pool, host_indices, device_indices, layer_id, io_backend
+ ):
+ raise NotImplementedError(
+ "HiCache not supported for DoubleSparseTokenToKVPool."
+ )

- def transfer_per_layer(self, indices, flat_data, layer_id):
- pass
+ def backup_to_host_all_layer(
+ self, host_pool, host_indices, device_indices, io_backend
+ ):
+ raise NotImplementedError(
+ "HiCache not supported for DoubleSparseTokenToKVPool."
+ )


  @triton.jit