sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -30,6 +30,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
30
30
  from sglang.srt.managers.io_struct import (
31
31
  GetWeightsByNameReqInput,
32
32
  InitWeightsUpdateGroupReqInput,
33
+ LoadLoRAAdapterReqInput,
34
+ UnloadLoRAAdapterReqInput,
33
35
  UpdateWeightFromDiskReqInput,
34
36
  UpdateWeightsFromDistributedReqInput,
35
37
  UpdateWeightsFromTensorReqInput,
@@ -257,7 +259,7 @@ class TpModelWorker:
257
259
  self, recv_req: UpdateWeightsFromDistributedReqInput
258
260
  ):
259
261
  success, message = self.model_runner.update_weights_from_distributed(
260
- recv_req.name, recv_req.dtype, recv_req.shape
262
+ recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
261
263
  )
262
264
  return success, message
263
265
 
@@ -275,3 +277,13 @@ class TpModelWorker:
275
277
  recv_req.name, recv_req.truncate_size
276
278
  )
277
279
  return parameter
280
+
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
    """Ask the model runner to load the LoRA adapter named in ``recv_req``.

    ``recv_req.lora_name`` and ``recv_req.lora_path`` are forwarded as-is;
    the model runner's result is returned unchanged.
    """
    return self.model_runner.load_lora_adapter(
        recv_req.lora_name, recv_req.lora_path
    )
286
+
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
    """Ask the model runner to unload the LoRA adapter named in ``recv_req``.

    Returns the model runner's result unchanged.
    """
    adapter_name = recv_req.lora_name
    return self.model_runner.unload_lora_adapter(adapter_name)
@@ -26,6 +26,8 @@ import torch
26
26
  from sglang.srt.managers.io_struct import (
27
27
  GetWeightsByNameReqInput,
28
28
  InitWeightsUpdateGroupReqInput,
29
+ LoadLoRAAdapterReqInput,
30
+ UnloadLoRAAdapterReqInput,
29
31
  UpdateWeightFromDiskReqInput,
30
32
  UpdateWeightsFromDistributedReqInput,
31
33
  UpdateWeightsFromTensorReqInput,
@@ -268,6 +270,12 @@ class TpModelWorkerClient:
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Delegate the get-weights-by-name request to the wrapped worker."""
    wrapped_worker = self.worker
    return wrapped_worker.get_weights_by_name(recv_req)
270
272
 
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
    """Delegate the LoRA load request to the wrapped worker and return its result."""
    wrapped_worker = self.worker
    return wrapped_worker.load_lora_adapter(recv_req)
275
+
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
    """Delegate the LoRA unload request to the wrapped worker and return its result."""
    wrapped_worker = self.worker
    return wrapped_worker.unload_lora_adapter(recv_req)
278
+
271
279
  def __delete__(self):
272
280
  self.input_queue.put((None, None))
273
281
  self.copy_queue.put((None, None, None))
@@ -20,12 +20,14 @@ Page-aligned memory pool.
20
20
  """
21
21
 
22
22
  import abc
23
+ import weakref
23
24
  from typing import TYPE_CHECKING
24
25
 
25
26
  import torch
26
27
  import triton
27
28
  import triton.language as tl
28
29
 
30
+ from sglang.srt.mem_cache.memory_pool import SWAKVPool
29
31
  from sglang.srt.utils import get_bool_env_var, next_power_of_2
30
32
 
31
33
  if TYPE_CHECKING:
@@ -55,6 +57,11 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
def debug_print(self) -> str:
    """Return extra allocator state for debug logs (the base class adds nothing)."""
    no_extra_info = ""
    return no_extra_info
57
59
 
def log_usage(self, evictable_size: int = 0):
    """Summarize KV-cache token usage.

    Tokens that are free or evictable are not counted as used.

    Args:
        evictable_size: number of tokens that could be evicted on demand.

    Returns:
        A ``(message, used_tokens)`` tuple. The message is a log fragment
        ending in ", ".
    """
    not_in_use = self.available_size() + evictable_size
    used = self.size - not_in_use
    usage_ratio = used / self.size
    return f"#token: {used}, token usage: {usage_ratio:.2f}, ", used
64
+
def available_size(self):
    """Return the number of free token slots: free pages times tokens per page."""
    free_page_count = len(self.free_pages)
    return free_page_count * self.page_size
60
67
 
@@ -146,6 +153,128 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
146
153
  return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
147
154
 
148
155
 
class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """Allocator for SWA hybrid KV cache.

    Each allocated token gets one slot in the full-attention pool and one slot
    in the sliding-window (SWA) pool. Callers only ever see full-pool indices.
    The paired SWA slot is recorded in ``full_to_swa_index_mapping``, where 0
    means "no SWA slot held". SWA slots can be released early through
    :meth:`free_swa` while the full-attention slot stays allocated.
    """

    def __init__(
        self,
        size: int,
        size_swa: int,
        dtype: torch.dtype,
        device: str,
        kvcache: SWAKVPool,
    ):
        """
        Args:
            size: number of token slots in the full-attention pool.
            size_swa: number of token slots in the SWA pool.
            dtype: forwarded to the base class and both sub-allocators.
            device: device for the sub-allocators and the index mapping.
            kvcache: the hybrid pool; its ``full_kv_pool`` / ``swa_kv_pool``
                back the two sub-allocators.
        """
        super().__init__(size, 1, dtype, device, kvcache)
        assert isinstance(kvcache, SWAKVPool)
        self._size_full = size
        self._size_swa = size_swa
        self.full_attn_allocator = TokenToKVPoolAllocator(
            size,
            dtype,
            device,
            kvcache.full_kv_pool,
        )
        self.swa_attn_allocator = TokenToKVPoolAllocator(
            size_swa,
            dtype,
            device,
            kvcache.swa_kv_pool,
        )
        # Indexed by full-pool slot; holds the paired SWA slot (0 = none).
        self.full_to_swa_index_mapping = torch.empty(
            size + size_swa + 1,
            dtype=torch.int64,
            device=device,
        )
        self.clear()

        # Share the mapping with the KV pool so it can translate indices too.
        self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping

    def available_size(self):
        """Tokens allocatable right now: limited by the scarcer of the two pools."""
        return min(self.full_available_size(), self.swa_available_size())

    def full_available_size(self):
        """Free token slots in the full-attention pool."""
        return self.full_attn_allocator.available_size()

    def swa_available_size(self):
        """Free token slots in the SWA pool."""
        return self.swa_attn_allocator.available_size()

    @property
    def size_full(self):
        return self._size_full

    @property
    def size_swa(self):
        return self._size_swa

    def debug_print(self) -> str:
        """Return free-slot counts of both pools as a log fragment."""
        msg = ""
        msg += f"#swa-available-size: {self.swa_available_size()}, "
        msg += f"#full-attn-available-size: {self.full_available_size()}, "
        return msg

    def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0):
        """Summarize per-pool token usage.

        Returns:
            ``(message, used_full_tokens)``. Evictable tokens are not counted
            as used.
        """
        used_full = self.size_full - (self.full_available_size() + full_evictable_size)
        used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size)
        msg = (
            f"#token: full={used_full}, swa={used_swa}, "
            f"token usage: full={used_full / self.size_full:.2f}, "
            f"swa={used_swa / self.size_swa:.2f}, "
        )
        return msg, used_full

    def get_kvcache(self):
        return self._kvcache

    def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
        """Map full-pool slot indices to their paired SWA slots (as int32)."""
        assert self.full_to_swa_index_mapping is not None
        return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)

    def alloc(self, need_size: int):
        """Allocate ``need_size`` tokens in both pools.

        Returns:
            The full-pool indices, or None if either pool lacks space. In the
            None case nothing is allocated.
        """
        # available_size() is the min of both pools, so this covers both checks.
        if need_size > self.available_size():
            return None

        alloc_full_indices = self.full_attn_allocator.alloc(need_size)
        alloc_swa_indices = self.swa_attn_allocator.alloc(need_size)
        self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices
        return alloc_full_indices

    def free(self, free_index: torch.Tensor):
        """Free full-pool slots and their paired SWA slots.

        Inside a free group the indices are only queued in ``free_group``.
        """
        if free_index.numel() == 0:
            return
        if self.is_not_in_free_group:
            self.full_attn_allocator.free(free_index)
            self.free_swa(free_index)
        else:
            self.free_group.append(free_index)
        assert (
            self.full_attn_allocator.available_size() <= self.full_attn_allocator.size
        )
        assert self.swa_attn_allocator.available_size() <= self.swa_attn_allocator.size

    def free_swa(self, free_index: torch.Tensor):
        """Release only the SWA slots paired with the given full-pool slots.

        Slots whose mapping is already 0 are skipped. Afterwards their mapping
        entries are reset to 0.
        """
        swa_indices = self.full_to_swa_index_mapping[free_index]
        swa_indices = swa_indices[swa_indices > 0]
        self.swa_attn_allocator.free(swa_indices)
        self.full_to_swa_index_mapping[free_index] = 0

    def backup_state(self):
        raise NotImplementedError

    def restore_state(self, state):
        raise NotImplementedError

    def clear(self):
        """Reset both pools, the index mapping, and the free-group state."""
        self.swa_attn_allocator.clear()
        self.full_attn_allocator.clear()
        self.full_to_swa_index_mapping.fill_(0)
        # ``free`` consults ``is_not_in_free_group``. Reset that flag (not a
        # differently named one) so a clear() leaves free-group batching off.
        self.is_not_in_free_group = True
        self.free_group = []
+
277
+
149
278
  @triton.jit
150
279
  def alloc_extend_kernel(
151
280
  pre_lens_ptr,
@@ -411,3 +540,164 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
411
540
  )
412
541
  self.is_not_in_free_group = True
413
542
  self.free_group = []
543
+
544
+
def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    """Fill ``out_indices`` with KV-cache slot indices for an extend step.

    This is the pure-PyTorch counterpart of the Triton ``alloc_extend_kernel``.
    For request ``i``, its ``seq_lens[i] - prefix_lens[i]`` new tokens are written
    in order to ``out_indices[start_pos[i]:end_pos[i]]``, in up to three parts:

    1. The rest of the request's current partial page, i.e. slots
       ``last_loc[i] + 1, last_loc[i] + 2, ...``.
    2. Whole new pages taken from ``free_pages``.
    3. A trailing, partially filled new page taken from ``free_pages``.

    New pages are taken from the front of ``free_pages`` in request order.
    ``free_pages`` itself is not modified; the caller drops the consumed pages.

    Args:
        prefix_lens: per-request number of tokens already cached.
        seq_lens: per-request total length after the extend.
        last_loc: per-request slot index; new tokens in the partial page go to
            ``last_loc[i] + 1`` onwards.
        free_pages: page ids available for allocation.
        out_indices: preallocated output of length ``sum(seq_lens - prefix_lens)``.
        page_size: tokens per page.
        device: device for the helper ``arange`` tensor.

    Returns:
        Per-request number of new pages used, elementwise over ``seq_lens``.
    """
    extend_lens = seq_lens - prefix_lens
    end_pos = torch.cumsum(extend_lens, 0)
    start_pos = end_pos - extend_lens
    # Pages held after the extend minus pages held before (ceil division).
    num_new_pages = (seq_lens + page_size - 1) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    end_new_pages = torch.cumsum(num_new_pages, 0)
    start_new_pages = end_new_pages - num_new_pages
    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)

    # Move per-request scalars to Python ints once, instead of doing a tensor
    # -> bool conversion per condition.
    per_request = zip(
        prefix_lens.tolist(),
        seq_lens.tolist(),
        start_pos.tolist(),
        end_pos.tolist(),
        start_new_pages.tolist(),
        end_new_pages.tolist(),
    )
    for i, (pre_len, seq_len, out_start, out_end, page_start, page_end) in enumerate(
        per_request
    ):
        # First page boundary at or after the prefix end.
        prefix_page_end = (pre_len + page_size - 1) // page_size * page_size

        # Part 1: fill the remainder of the current partial page.
        num1 = min(seq_len, prefix_page_end) - pre_len
        if num1 > 0:
            out_indices[out_start : out_start + num1] = (
                last_loc[i] + 1 + pos_in_page[:num1]
            )
        if pre_len + num1 == seq_len:
            # The extend fits inside the existing page, so no new page is used.
            # Parts 2/3 must not run: part 3 would read a page that belongs to
            # another request (or free_pages[-1]) and overwrite part 1's slots.
            continue

        # Part 2: whole new pages. The count is non-negative here because
        # seq_len > prefix_page_end.
        num_full_pages = seq_len // page_size - prefix_page_end // page_size
        if num_full_pages > 0:
            pages = free_pages[page_start : page_start + num_full_pages] * page_size
            out_indices[
                out_start + num1 : out_start + num1 + num_full_pages * page_size
            ] = (pages.view(-1, 1) + pos_in_page.view(1, -1)).view(-1)

        # Part 3: trailing partial new page (the last page taken by this request).
        num3 = seq_len % page_size
        if num3 > 0:
            out_indices[out_end - num3 : out_end] = (
                free_pages[page_end - 1] * page_size + pos_in_page[:num3]
            )
    return num_new_pages
598
+
599
+
def alloc_decode_kernel_ascend(
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
):
    """Fill ``out_indices`` with one KV-cache slot per request for a decode step.

    ``seq_lens`` is compared with ``seq_lens - 1`` to decide whether the new
    token opens a fresh page. If it does, the request takes the next page id
    from ``free_pages`` (in request order) and gets that page's first slot.
    Otherwise the token goes to ``last_loc[i] + 1``. ``free_pages`` is not
    modified.

    Returns:
        Per-request number of new pages used (0 or 1 for positive lengths).
    """
    pages_now = (seq_lens + page_size - 1) // page_size
    pages_before = (seq_lens - 1 + page_size - 1) // page_size
    took_page = pages_now - pages_before
    # Exclusive prefix sum: offset into free_pages for each request's new page.
    page_cursor = torch.cumsum(took_page, 0) - took_page
    for req_idx in range(len(seq_lens)):
        if not took_page[req_idx]:
            out_indices[req_idx] = last_loc[req_idx] + 1
            continue
        out_indices[req_idx] = free_pages[page_cursor[req_idx]] * page_size
    return took_page
618
+
619
+
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
    """Paged allocator that uses the pure-PyTorch ``*_kernel_ascend`` helpers.

    It has the same interface as ``PagedTokenToKVPoolAllocator``. Extend and
    decode allocation go through ``alloc_extend_kernel_ascend`` /
    ``alloc_decode_kernel_ascend`` instead of Triton kernels, and
    ``free_pages`` is kept as int32 (see ``clear``).
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        super().__init__(size, page_size, dtype, device, kvcache)
        # Per-request new-page counts from the most recent successful allocation.
        self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate slots for an extend step.

        Args:
            prefix_lens: per-request cached-token counts.
            seq_lens: per-request lengths after the extend.
            last_loc: per-request slot that precedes the first new token.
            extend_num_tokens: total new tokens, i.e. the length of the result.

        Returns:
            An int32 tensor of slot indices, or None if there are not enough
            free pages. In the None case no page is consumed.
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        # Check capacity before running the allocation loop. The loop indexes
        # free_pages directly and can raise, instead of returning None, when
        # too few pages remain.
        num_new_pages = int(
            (
                (seq_lens + self.page_size - 1) // self.page_size
                - (prefix_lens + self.page_size - 1) // self.page_size
            ).sum()
        )
        if num_new_pages > len(self.free_pages):
            return None

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )
        self.ret_values = alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        # The loop consumed pages from the front of free_pages.
        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one slot per request for a decode step.

        Args:
            seq_lens: per-request lengths including the token being decoded.
            last_loc: per-request slot of the previous token.

        Returns:
            An int32 tensor with one slot per request, or None if there are not
            enough free pages. In the None case no page is consumed.
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        # Same pre-check as alloc_extend: decide capacity before indexing
        # free_pages.
        num_new_pages = int(
            (
                (seq_lens + self.page_size - 1) // self.page_size
                - (seq_lens - 1 + self.page_size - 1) // self.page_size
            ).sum()
        )
        if num_new_pages > len(self.free_pages):
            return None

        out_indices = torch.empty(
            (len(seq_lens),), dtype=torch.int32, device=self.device
        )
        self.ret_values = alloc_decode_kernel_ascend(
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def clear(self):
        """Reset via the parent, then store free page ids as int32."""
        super().clear()
        self.free_pages = self.free_pages.to(torch.int32)
@@ -2,11 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  """Cache for chunked prefill, used when RadixCache is disabled."""
4
4
 
5
- from typing import TYPE_CHECKING, Any
5
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
6
6
 
7
7
  import torch
8
8
 
9
- from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
9
+ from sglang.srt.mem_cache.allocator import (
10
+ BaseTokenToKVPoolAllocator,
11
+ SWATokenToKVPoolAllocator,
12
+ )
10
13
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
11
14
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
12
15
 
@@ -63,3 +66,32 @@ class ChunkCache(BasePrefixCache):
63
66
 
def pretty_print(self):
    """Return a printable view of the cache; a chunk cache has no tree to show."""
    nothing_to_show = ""
    return nothing_to_show
69
+
70
+
class SWAChunkCache(ChunkCache):
    """ChunkCache with support for hybrid KV cache operations.

    It requires a ``SWATokenToKVPoolAllocator``, so that ``evict`` can release a
    request's sliding-window (SWA) slots through ``free_swa`` while its
    full-attention slots stay allocated.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: SWATokenToKVPoolAllocator,
        page_size: int,
    ):
        super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
        assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)

    def evict(
        self,
        req: Req,
        prelen: int,
        attention_chunk_size: int,
    ):
        """Release SWA slots of ``req`` for whole chunks that lie before ``prelen``.

        ``req.evicted_seqlen_local`` marks how far SWA slots have already been
        released. Nothing happens until ``prelen`` reaches at least one full
        chunk past that mark. Then the SWA slots of positions
        ``[evicted_seqlen_local, prelen rounded down to a chunk boundary)``
        are freed, and the mark advances to that boundary.
        """
        released_upto = req.evicted_seqlen_local
        if prelen < released_upto + attention_chunk_size:
            return
        chunk_boundary = (prelen // attention_chunk_size) * attention_chunk_size
        stale_slots = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, released_upto:chunk_boundary
        ]
        self.token_to_kv_pool_allocator.free_swa(stale_slots)
        req.evicted_seqlen_local = chunk_boundary