sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
@@ -76,7 +76,6 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
76
76
  This function will replace the data-tokens in between with pad_values accordingly
77
77
  """
78
78
  pad_values = [item.pad_value for item in mm_inputs.mm_items]
79
- print(f"{mm_inputs.mm_items=}")
80
79
  data_token_pairs = self.data_token_id_pairs
81
80
  mm_inputs.data_offsets = []
82
81
  if data_token_pairs is None:
@@ -222,17 +221,17 @@ def _get_precomputed_embedding(
222
221
  items: List[MultimodalDataItem],
223
222
  ) -> Optional[torch.Tensor]:
224
223
  """
225
- If all items have precomputed_features, return their concatenation.
226
- If some but not all have precomputed_features, raise NotImplementedError.
227
- If none have precomputed_features, return None.
224
+ If all items have precomputed_embeddings, return their concatenation.
225
+ If some but not all have precomputed_embeddings, raise NotImplementedError.
226
+ If none have precomputed_embeddings, return None.
228
227
  """
229
- precomputed_features = [item.precomputed_features for item in items]
230
- if any(feature is not None for feature in precomputed_features):
231
- if not all(feature is not None for feature in precomputed_features):
228
+ precomputed_embeddings = [item.precomputed_embeddings for item in items]
229
+ if any(feature is not None for feature in precomputed_embeddings):
230
+ if not all(feature is not None for feature in precomputed_embeddings):
232
231
  raise NotImplementedError(
233
232
  "MM inputs where only some items are precomputed."
234
233
  )
235
- result = torch.concat(precomputed_features)
234
+ result = torch.concat(precomputed_embeddings)
236
235
  # some models embedding is 3-dim, reshape it to 2-dim (similar to get_embedding_chunk)
237
236
  result = result.reshape(-1, result.shape[-1])
238
237
  return result
@@ -52,10 +52,14 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
52
52
  ScheduleBatchDisaggregationDecodeMixin,
53
53
  )
54
54
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
55
- from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
55
+ from sglang.srt.mem_cache.allocator import (
56
+ BaseTokenToKVPoolAllocator,
57
+ SWATokenToKVPoolAllocator,
58
+ )
56
59
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
57
60
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
58
61
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
62
+ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
59
63
  from sglang.srt.metrics.collector import TimeStats
60
64
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
61
65
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -197,45 +201,41 @@ class MultimodalDataItem:
197
201
  For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
198
202
  One for images and one for audio.
199
203
 
200
- We put the common fields first and the model-specific fields last.
204
+ We put the common fields first and the model-specific fields in model_specific_data.
201
205
  """
202
206
 
203
207
  modality: Modality
204
208
  hash: int = None
205
209
  pad_value: int = None
206
- image_sizes: Tuple[int, int] = None
207
210
  offsets: Optional[list] = None
211
+ # the raw features returned by processor, e.g. pixel_values or audio_features
212
+ feature: Union[torch.Tensor, np.ndarray] = None
208
213
 
209
- # the real data, pixel_values or audio_features
210
- # data: Union[List[torch.Tensor], List[np.ndarray]]
211
- pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
212
- audio_features: Union[torch.Tensor, np.ndarray] = None
213
- audio_feature_lens: Optional[List[torch.Tensor]] = None
214
- audio_offsets: Optional[List[Tuple[int, int]]] = None
215
- precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
216
-
217
- # For qwen-vl
218
- image_grid_thw: Union[torch.Tensor, np.ndarray] = None
219
- second_per_grid_ts: Optional[List[torch.Tensor]] = None
220
-
221
- # For deepseek-vl
222
- image_emb_mask: Optional[torch.Tensor] = None
223
- image_spatial_crop: Optional[torch.Tensor] = None
214
+ # the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
215
+ precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
224
216
 
225
- # For minicpmv
226
- # [num_images, (n, w, h)]
227
- tgt_size: Tuple[int, int] = None
217
+ # Model-specific data stored in a dictionary
218
+ model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)
228
219
 
229
- # For mllama
230
- aspect_ratio_id: Optional[List[torch.Tensor]] = None
231
- aspect_ratio_mask: Optional[List[torch.Tensor]] = None
220
+ def __getattr__(self, name: str):
221
+ if (
222
+ "model_specific_data" in self.__dict__
223
+ and name in self.__dict__["model_specific_data"]
224
+ ):
225
+ return self.__dict__["model_specific_data"][name]
226
+ else:
227
+ raise AttributeError(
228
+ f"'{self.__class__.__name__}' object has no attribute '{name}'"
229
+ )
232
230
 
233
- # For kimi-vl
234
- image_grid_hws: Optional[List[torch.Tensor]] = None
231
+ def __setitem__(self, key: str, value: Any):
232
+ if key in self.__dict__:
233
+ self.__dict__[key] = value
234
+ else:
235
+ self.model_specific_data[key] = value
235
236
 
236
- # For gemma3n
237
- input_features: Optional[torch.Tensor] = None
238
- input_features_mask: Optional[torch.Tensor] = None
237
+ def set(self, key: str, value: Any):
238
+ self.__setitem__(key, value)
239
239
 
240
240
  @staticmethod
241
241
  def is_empty_list(l):
@@ -250,18 +250,11 @@ class MultimodalDataItem:
250
250
  from sglang.srt.managers.mm_utils import hash_feature
251
251
 
252
252
  if self.hash is None:
253
- if self.precomputed_features is not None:
254
- self.hash = hash_feature(self.precomputed_features)
255
- elif self.is_audio():
256
- if self.audio_features is not None:
257
- self.hash = hash_feature(self.audio_features)
258
- elif self.input_features is not None:
259
- self.hash = hash_feature(self.input_features)
260
- elif self.is_video():
261
- self.hash = hash_feature(self.pixel_values_videos)
253
+ if self.feature is not None:
254
+ hashed_feature = self.feature
262
255
  else:
263
- self.hash = hash_feature(self.pixel_values)
264
-
256
+ hashed_feature = self.precomputed_embeddings
257
+ self.hash = hash_feature(hashed_feature)
265
258
  assert self.hash is not None
266
259
  self.pad_value = self.hash % (1 << 30)
267
260
 
@@ -269,25 +262,13 @@ class MultimodalDataItem:
269
262
  return self.modality == modality
270
263
 
271
264
  def is_audio(self):
272
- return (self.modality == Modality.AUDIO) and (
273
- self.precomputed_features is not None
274
- or not MultimodalDataItem.is_empty_list(self.audio_features)
275
- or not MultimodalDataItem.is_empty_list(self.input_features)
276
- )
265
+ return self.modality == Modality.AUDIO
277
266
 
278
267
  def is_image(self):
279
- return (
280
- self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
281
- ) and (
282
- self.precomputed_features is not None
283
- or not MultimodalDataItem.is_empty_list(self.pixel_values)
284
- )
268
+ return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]
285
269
 
286
270
  def is_video(self):
287
- return (self.modality == Modality.VIDEO) and (
288
- self.precomputed_features is not None
289
- or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
290
- )
271
+ return self.modality == Modality.VIDEO
291
272
 
292
273
  def is_valid(self) -> bool:
293
274
  return self.is_image() or self.is_video() or self.is_audio()
@@ -307,9 +288,8 @@ class MultimodalDataItem:
307
288
  return ret
308
289
 
309
290
  def merge(self, other):
310
- self.pixel_values += other.pixel_values
311
- self.image_sizes += other.image_sizes
312
- self.image_offsets += other.image_offsets
291
+ self.feature += other.feature
292
+ self.offsets += other.offsets
313
293
  self.hash = hash((self.hash, other.hash))
314
294
  self.set_pad_value()
315
295
 
@@ -350,7 +330,6 @@ class MultimodalInputs:
350
330
 
351
331
  assert isinstance(ret.mm_items, list)
352
332
  ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
353
-
354
333
  for item in ret.mm_items:
355
334
  item.set_pad_value()
356
335
 
@@ -527,6 +506,8 @@ class Req:
527
506
  self.last_node: Any = None
528
507
  self.last_host_node: Any = None
529
508
  self.host_hit_length = 0
509
+ # The node to lock until for swa radix tree lock ref
510
+ self.swa_uuid_for_lock: Optional[int] = None
530
511
 
531
512
  # Whether or not if it is chunked. It increments whenever
532
513
  # it is chunked, and decrement whenever chunked request is
@@ -745,6 +726,7 @@ class Req:
745
726
  def reset_for_retract(self):
746
727
  self.prefix_indices = []
747
728
  self.last_node = None
729
+ self.swa_uuid_for_lock = None
748
730
  self.extend_input_len = 0
749
731
  self.is_retracted = True
750
732
  self.input_token_logprobs = None
@@ -813,6 +795,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
813
795
  req_to_token_pool: ReqToTokenPool = None
814
796
  token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
815
797
  tree_cache: BasePrefixCache = None
798
+ is_hybrid: bool = False
816
799
 
817
800
  # Batch configs
818
801
  model_config: ModelConfig = None
@@ -918,11 +901,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
918
901
  ):
919
902
  return_logprob = any(req.return_logprob for req in reqs)
920
903
 
904
+ is_hybrid = False
905
+ if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
906
+ assert isinstance(tree_cache, SWARadixCache) or isinstance(
907
+ tree_cache, SWAChunkCache
908
+ ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
909
+ is_hybrid = True
910
+
921
911
  return cls(
922
912
  reqs=reqs,
923
913
  req_to_token_pool=req_to_token_pool,
924
914
  token_to_kv_pool_allocator=token_to_kv_pool_allocator,
925
915
  tree_cache=tree_cache,
916
+ is_hybrid=is_hybrid,
926
917
  model_config=model_config,
927
918
  enable_overlap=enable_overlap,
928
919
  return_logprob=return_logprob,
@@ -953,9 +944,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
953
944
  return req_pool_indices
954
945
 
955
946
  def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
956
- if self.token_to_kv_pool_allocator.available_size() < num_tokens:
957
- if self.tree_cache is not None:
958
- self.tree_cache.evict(num_tokens)
947
+ self._evict_tree_cache_if_needed(num_tokens)
959
948
 
960
949
  if backup_state:
961
950
  state = self.token_to_kv_pool_allocator.backup_state()
@@ -966,7 +955,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
966
955
  error_msg = (
967
956
  f"{phase_str} out of memory. Try to lower your batch size.\n"
968
957
  f"Try to allocate {num_tokens} tokens.\n"
969
- f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
958
+ f"{self._available_and_evictable_str()}"
970
959
  )
971
960
  logger.error(error_msg)
972
961
  if self.tree_cache is not None:
@@ -986,16 +975,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
986
975
  extend_num_tokens: int,
987
976
  backup_state: bool = False,
988
977
  ):
989
- if (
990
- self.token_to_kv_pool_allocator.available_size()
991
- < extend_num_tokens
978
+ num_tokens = (
979
+ extend_num_tokens
992
980
  + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
993
- ):
994
- if self.tree_cache is not None:
995
- self.tree_cache.evict(
996
- extend_num_tokens
997
- + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
998
- )
981
+ )
982
+ self._evict_tree_cache_if_needed(num_tokens)
999
983
 
1000
984
  if backup_state:
1001
985
  state = self.token_to_kv_pool_allocator.backup_state()
@@ -1007,9 +991,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1007
991
  error_msg = (
1008
992
  f"Prefill out of memory. Try to lower your batch size.\n"
1009
993
  f"Try to allocate {extend_num_tokens} tokens.\n"
1010
- f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
1011
- f"{self.token_to_kv_pool_allocator.available_size()=}\n"
1012
- f"{self.tree_cache.evictable_size()=}\n"
994
+ f"{self._available_and_evictable_str()}"
1013
995
  )
1014
996
  logger.error(error_msg)
1015
997
  raise RuntimeError(error_msg)
@@ -1025,14 +1007,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1025
1007
  last_loc: torch.Tensor,
1026
1008
  backup_state: bool = False,
1027
1009
  ):
1028
- if self.tree_cache is not None:
1029
- if (
1030
- self.token_to_kv_pool_allocator.available_size()
1031
- < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
1032
- ):
1033
- self.tree_cache.evict(
1034
- len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
1035
- )
1010
+ num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
1011
+
1012
+ self._evict_tree_cache_if_needed(num_tokens)
1036
1013
 
1037
1014
  if backup_state:
1038
1015
  state = self.token_to_kv_pool_allocator.backup_state()
@@ -1042,9 +1019,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1042
1019
  error_msg = (
1043
1020
  f"Decode out of memory. Try to lower your batch size.\n"
1044
1021
  f"Try to allocate {len(seq_lens)} tokens.\n"
1045
- f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
1046
- f"{self.token_to_kv_pool_allocator.available_size()=}\n"
1047
- f"{self.tree_cache.evictable_size()=}\n"
1022
+ f"{self._available_and_evictable_str()}"
1048
1023
  )
1049
1024
  logger.error(error_msg)
1050
1025
  raise RuntimeError(error_msg)
@@ -1181,7 +1156,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1181
1156
  (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
1182
1157
  )
1183
1158
  if isinstance(self.tree_cache, SWAChunkCache):
1184
- self.tree_cache.evict(
1159
+ self.tree_cache.evict_swa(
1185
1160
  req, pre_len, self.model_config.attention_chunk_size
1186
1161
  )
1187
1162
 
@@ -1278,11 +1253,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1278
1253
  if mm_input is None:
1279
1254
  continue
1280
1255
  for mm_item in mm_input.mm_items:
1281
- pixel_values = getattr(mm_item, "pixel_values", None)
1256
+ pixel_values = getattr(mm_item, "feature", None)
1282
1257
  if isinstance(pixel_values, torch.Tensor):
1283
- mm_item.pixel_values = pixel_values.to(
1284
- self.device, non_blocking=True
1285
- )
1258
+ mm_item.feature = pixel_values.to(self.device, non_blocking=True)
1286
1259
  self.multimodal_inputs = multimodal_inputs
1287
1260
  self.token_type_ids = token_type_ids_tensor
1288
1261
  self.seq_lens_sum = sum(seq_lens)
@@ -1328,6 +1301,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1328
1301
  self.model_config.vocab_size,
1329
1302
  )
1330
1303
 
1304
+ def prepare_for_split_prefill(self):
1305
+ self.prepare_for_extend()
1306
+ # For split prefill, we need to set the forward mode to SPLIT_PREFILL
1307
+ self.forward_mode = ForwardMode.SPLIT_PREFILL
1308
+
1331
1309
  def mix_with_running(self, running_batch: "ScheduleBatch"):
1332
1310
  self.forward_mode = ForwardMode.MIXED
1333
1311
  running_bs = running_batch.batch_size()
@@ -1371,17 +1349,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1371
1349
  )
1372
1350
 
1373
1351
  def check_decode_mem(self, buf_multiplier=1):
1374
- tokens_required = (
1352
+ num_tokens = (
1375
1353
  self.new_page_count_next_decode()
1376
1354
  * buf_multiplier
1377
1355
  * self.token_to_kv_pool_allocator.page_size
1378
1356
  )
1379
- if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
1380
- return True
1381
-
1382
- self.tree_cache.evict(tokens_required)
1383
1357
 
1384
- return self.token_to_kv_pool_allocator.available_size() >= tokens_required
1358
+ self._evict_tree_cache_if_needed(num_tokens)
1359
+ return self._is_available_size_sufficient(num_tokens)
1385
1360
 
1386
1361
  def retract_decode(self, server_args: ServerArgs):
1387
1362
  """Retract the decoding requests when there is not enough memory."""
@@ -1414,19 +1389,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1414
1389
  num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
1415
1390
  )
1416
1391
 
1392
+ def _get_available_size():
1393
+ if self.is_hybrid:
1394
+ return min(
1395
+ self.token_to_kv_pool_allocator.full_available_size(),
1396
+ self.token_to_kv_pool_allocator.swa_available_size(),
1397
+ )
1398
+ else:
1399
+ return self.token_to_kv_pool_allocator.available_size()
1400
+
1417
1401
  retracted_reqs = []
1418
1402
  seq_lens_cpu = self.seq_lens.cpu().numpy()
1419
1403
  first_iter = True
1420
1404
  while (
1421
- self.token_to_kv_pool_allocator.available_size()
1422
- < get_required_tokens(len(sorted_indices))
1405
+ _get_available_size() < get_required_tokens(len(sorted_indices))
1423
1406
  or first_iter
1424
1407
  ):
1425
1408
  if len(sorted_indices) == 1:
1426
1409
  # Corner case: only one request left
1427
- assert (
1428
- self.token_to_kv_pool_allocator.available_size() > 0
1429
- ), "No space left for only one request"
1410
+ if self.is_hybrid:
1411
+ full_available_size = (
1412
+ self.token_to_kv_pool_allocator.full_available_size()
1413
+ )
1414
+ swa_available_size = (
1415
+ self.token_to_kv_pool_allocator.swa_available_size()
1416
+ )
1417
+ assert (
1418
+ full_available_size > 0 and swa_available_size > 0
1419
+ ), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
1420
+ else:
1421
+ assert (
1422
+ self.token_to_kv_pool_allocator.available_size() > 0
1423
+ ), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
1430
1424
  break
1431
1425
 
1432
1426
  first_iter = False
@@ -1458,15 +1452,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1458
1452
  self.req_to_token_pool.free(req.req_pool_idx)
1459
1453
 
1460
1454
  # release the last node
1461
- self.tree_cache.dec_lock_ref(req.last_node)
1455
+ if self.is_hybrid:
1456
+ self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
1457
+ else:
1458
+ self.tree_cache.dec_lock_ref(req.last_node)
1462
1459
 
1463
1460
  # NOTE(lsyin): we should use the newly evictable memory instantly.
1464
- residual_size = (
1465
- len(sorted_indices) * global_config.retract_decode_steps
1466
- - self.token_to_kv_pool_allocator.available_size()
1467
- )
1468
- residual_size = max(0, residual_size)
1469
- self.tree_cache.evict(residual_size)
1461
+ num_tokens = len(sorted_indices) * global_config.retract_decode_steps
1462
+ self._evict_tree_cache_if_needed(num_tokens)
1470
1463
 
1471
1464
  req.reset_for_retract()
1472
1465
 
@@ -1559,7 +1552,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1559
1552
  # free memory
1560
1553
  if isinstance(self.tree_cache, SWAChunkCache):
1561
1554
  for req in self.reqs:
1562
- self.tree_cache.evict(
1555
+ self.tree_cache.evict_swa(
1563
1556
  req, req.seqlen - 1, self.model_config.attention_chunk_size
1564
1557
  )
1565
1558
 
@@ -1778,6 +1771,53 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1778
1771
  is_extend_in_batch=self.is_extend_in_batch,
1779
1772
  )
1780
1773
 
1774
+ def _evict_tree_cache_if_needed(
1775
+ self,
1776
+ num_tokens: int,
1777
+ ) -> None:
1778
+ if isinstance(self.tree_cache, SWAChunkCache):
1779
+ return
1780
+
1781
+ if self.is_hybrid:
1782
+ full_available_size = self.token_to_kv_pool_allocator.full_available_size()
1783
+ swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
1784
+
1785
+ if full_available_size < num_tokens or swa_available_size < num_tokens:
1786
+ if self.tree_cache is not None:
1787
+ full_num_tokens = max(0, num_tokens - full_available_size)
1788
+ swa_num_tokens = max(0, num_tokens - swa_available_size)
1789
+ self.tree_cache.evict(full_num_tokens, swa_num_tokens)
1790
+ else:
1791
+ if self.token_to_kv_pool_allocator.available_size() < num_tokens:
1792
+ if self.tree_cache is not None:
1793
+ self.tree_cache.evict(num_tokens)
1794
+
1795
+ def _is_available_size_sufficient(self, num_tokens: int) -> bool:
1796
+ if self.is_hybrid:
1797
+ return (
1798
+ self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
1799
+ and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
1800
+ )
1801
+ else:
1802
+ return self.token_to_kv_pool_allocator.available_size() >= num_tokens
1803
+
1804
+ def _available_and_evictable_str(self) -> str:
1805
+ if self.is_hybrid:
1806
+ full_available_size = self.token_to_kv_pool_allocator.full_available_size()
1807
+ swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
1808
+ full_evictable_size = self.tree_cache.full_evictable_size()
1809
+ swa_evictable_size = self.tree_cache.swa_evictable_size()
1810
+ return (
1811
+ f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
1812
+ f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
1813
+ f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
1814
+ f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
1815
+ )
1816
+ else:
1817
+ available_size = self.token_to_kv_pool_allocator.available_size()
1818
+ evictable_size = self.tree_cache.evictable_size()
1819
+ return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
1820
+
1781
1821
  def __str__(self):
1782
1822
  return (
1783
1823
  f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
25
25
  import torch
26
26
 
27
27
  from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
28
+ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
28
29
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
29
30
  from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
30
31
 
@@ -311,21 +312,43 @@ class PrefillAdder:
311
312
  ]
312
313
  )
313
314
 
315
+ self.is_hybrid = isinstance(
316
+ self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
317
+ )
318
+
314
319
  @property
315
320
  def rem_total_tokens(self):
316
- return (
317
- self.token_to_kv_pool_allocator.available_size()
318
- + self.tree_cache.evictable_size()
319
- - self.rem_total_token_offset
320
- )
321
+ if self.is_hybrid:
322
+ available_and_evictable = min(
323
+ self.token_to_kv_pool_allocator.full_available_size()
324
+ + self.tree_cache.full_evictable_size(),
325
+ self.token_to_kv_pool_allocator.swa_available_size()
326
+ + self.tree_cache.swa_evictable_size(),
327
+ )
328
+ else:
329
+ available_and_evictable = (
330
+ self.token_to_kv_pool_allocator.available_size()
331
+ + self.tree_cache.evictable_size()
332
+ )
333
+
334
+ return available_and_evictable - self.rem_total_token_offset
321
335
 
322
336
  @property
323
337
  def cur_rem_tokens(self):
324
- return (
325
- self.token_to_kv_pool_allocator.available_size()
326
- + self.tree_cache.evictable_size()
327
- - self.cur_rem_token_offset
328
- )
338
+ if self.is_hybrid:
339
+ available_and_evictable = min(
340
+ self.token_to_kv_pool_allocator.full_available_size()
341
+ + self.tree_cache.full_evictable_size(),
342
+ self.token_to_kv_pool_allocator.swa_available_size()
343
+ + self.tree_cache.swa_evictable_size(),
344
+ )
345
+ else:
346
+ available_and_evictable = (
347
+ self.token_to_kv_pool_allocator.available_size()
348
+ + self.tree_cache.evictable_size()
349
+ )
350
+
351
+ return available_and_evictable - self.cur_rem_token_offset
329
352
 
330
353
  def ceil_paged_tokens(self, tokens: int) -> int:
331
354
  return -(-tokens // self.page_size) * self.page_size
@@ -376,11 +399,18 @@ class PrefillAdder:
376
399
 
377
400
  @contextmanager
378
401
  def _lock_node(self, last_node: TreeNode):
379
- try:
380
- self.tree_cache.inc_lock_ref(last_node)
381
- yield None
382
- finally:
383
- self.tree_cache.dec_lock_ref(last_node)
402
+ if self.is_hybrid:
403
+ try:
404
+ swa_uuid_for_lock = self.tree_cache.inc_lock_ref(last_node)
405
+ yield None
406
+ finally:
407
+ self.tree_cache.dec_lock_ref(last_node, swa_uuid_for_lock)
408
+ else:
409
+ try:
410
+ self.tree_cache.inc_lock_ref(last_node)
411
+ yield None
412
+ finally:
413
+ self.tree_cache.dec_lock_ref(last_node)
384
414
 
385
415
  def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
386
416
  # Early exit if no enough tokens for the input tokens
@@ -422,16 +452,19 @@ class PrefillAdder:
422
452
  else:
423
453
  add_req_state(req, insert_sort=True)
424
454
 
425
- cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
426
- tokens_freed = 0
427
- for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
428
- # tokens_left gives a reservative calculation as the last token is not stored
429
- bs = len(self.req_states) - i
430
- min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
431
- # reserve tokens for corner cases
432
- if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
433
- return AddReqResult.NO_TOKEN
434
- tokens_freed += tokens_occupied
455
+ if not self.is_hybrid:
456
+ # Skip this logic for swa. The SWA has different memory management, and
457
+ # this mechanism is underestimating the memory usage.
458
+ cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
459
+ tokens_freed = 0
460
+ for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
461
+ # tokens_left gives a reservative calculation as the last token is not stored
462
+ bs = len(self.req_states) - i
463
+ min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
464
+ # reserve tokens for corner cases
465
+ if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
466
+ return AddReqResult.NO_TOKEN
467
+ tokens_freed += tokens_occupied
435
468
 
436
469
  if (
437
470
  self.rem_chunk_tokens is None # chunked prefill is disabled
@@ -499,7 +532,11 @@ class PrefillAdder:
499
532
  if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
500
533
  # Non-chunked prefill
501
534
  self.can_run_list.append(req)
502
- self.tree_cache.inc_lock_ref(req.last_node)
535
+ if self.is_hybrid:
536
+ swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
537
+ req.swa_uuid_for_lock = swa_uuid_for_lock
538
+ else:
539
+ self.tree_cache.inc_lock_ref(req.last_node)
503
540
  self._update_prefill_budget(
504
541
  prefix_len,
505
542
  input_tokens,
@@ -520,7 +557,11 @@ class PrefillAdder:
520
557
 
521
558
  self.can_run_list.append(req)
522
559
  self.new_chunked_req = req
523
- self.tree_cache.inc_lock_ref(req.last_node)
560
+ if self.is_hybrid:
561
+ swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
562
+ req.swa_uuid_for_lock = swa_uuid_for_lock
563
+ else:
564
+ self.tree_cache.inc_lock_ref(req.last_node)
524
565
  self._update_prefill_budget(prefix_len, trunc_len, 0)
525
566
 
526
567
  return self.budget_state()