sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -45,17 +45,20 @@ import triton
45
45
  import triton.language as tl
46
46
 
47
47
  from sglang.global_config import global_config
48
- from sglang.srt.configs.model_config import ModelConfig
49
48
  from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
50
49
  from sglang.srt.disaggregation.base import BaseKVSender
51
50
  from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
52
51
  ScheduleBatchDisaggregationDecodeMixin,
53
52
  )
54
53
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
55
- from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
54
+ from sglang.srt.mem_cache.allocator import (
55
+ BaseTokenToKVPoolAllocator,
56
+ SWATokenToKVPoolAllocator,
57
+ )
56
58
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
57
59
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
58
60
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
61
+ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
59
62
  from sglang.srt.metrics.collector import TimeStats
60
63
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
61
64
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -64,6 +67,7 @@ from sglang.srt.server_args import ServerArgs
64
67
  from sglang.srt.utils import flatten_nested_list, support_triton
65
68
 
66
69
  if TYPE_CHECKING:
70
+ from sglang.srt.configs.model_config import ModelConfig
67
71
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
68
72
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
69
73
 
@@ -102,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
102
106
  "num_reserved_decode_tokens",
103
107
  "weight_loader_disable_mmap",
104
108
  "enable_triton_kernel_moe",
109
+ "enable_multimodal",
105
110
  ]
106
111
 
107
112
  # Put some global args for easy access
@@ -197,45 +202,41 @@ class MultimodalDataItem:
197
202
  For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
198
203
  One for images and one for audio.
199
204
 
200
- We put the common fields first and the model-specific fields last.
205
+ We put the common fields first and the model-specific fields in model_specific_data.
201
206
  """
202
207
 
203
208
  modality: Modality
204
209
  hash: int = None
205
210
  pad_value: int = None
206
- image_sizes: Tuple[int, int] = None
207
211
  offsets: Optional[list] = None
212
+ # the raw features returned by processor, e.g. pixel_values or audio_features
213
+ feature: Union[torch.Tensor, np.ndarray] = None
208
214
 
209
- # the real data, pixel_values or audio_features
210
- # data: Union[List[torch.Tensor], List[np.ndarray]]
211
- pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
212
- audio_features: Union[torch.Tensor, np.ndarray] = None
213
- audio_feature_lens: Optional[List[torch.Tensor]] = None
214
- audio_offsets: Optional[List[Tuple[int, int]]] = None
215
- precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
216
-
217
- # For qwen-vl
218
- image_grid_thw: Union[torch.Tensor, np.ndarray] = None
219
- second_per_grid_ts: Optional[List[torch.Tensor]] = None
215
+ # the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
216
+ precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
220
217
 
221
- # For deepseek-vl
222
- image_emb_mask: Optional[torch.Tensor] = None
223
- image_spatial_crop: Optional[torch.Tensor] = None
218
+ # Model-specific data stored in a dictionary
219
+ model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)
224
220
 
225
- # For minicpmv
226
- # [num_images, (n, w, h)]
227
- tgt_size: Tuple[int, int] = None
228
-
229
- # For mllama
230
- aspect_ratio_id: Optional[List[torch.Tensor]] = None
231
- aspect_ratio_mask: Optional[List[torch.Tensor]] = None
221
+ def __getattr__(self, name: str):
222
+ if (
223
+ "model_specific_data" in self.__dict__
224
+ and name in self.__dict__["model_specific_data"]
225
+ ):
226
+ return self.__dict__["model_specific_data"][name]
227
+ else:
228
+ raise AttributeError(
229
+ f"'{self.__class__.__name__}' object has no attribute '{name}'"
230
+ )
232
231
 
233
- # For kimi-vl
234
- image_grid_hws: Optional[List[torch.Tensor]] = None
232
+ def __setitem__(self, key: str, value: Any):
233
+ if key in self.__dict__:
234
+ self.__dict__[key] = value
235
+ else:
236
+ self.model_specific_data[key] = value
235
237
 
236
- # For gemma3n
237
- input_features: Optional[torch.Tensor] = None
238
- input_features_mask: Optional[torch.Tensor] = None
238
+ def set(self, key: str, value: Any):
239
+ self.__setitem__(key, value)
239
240
 
240
241
  @staticmethod
241
242
  def is_empty_list(l):
@@ -250,18 +251,11 @@ class MultimodalDataItem:
250
251
  from sglang.srt.managers.mm_utils import hash_feature
251
252
 
252
253
  if self.hash is None:
253
- if self.precomputed_features is not None:
254
- self.hash = hash_feature(self.precomputed_features)
255
- elif self.is_audio():
256
- if self.audio_features is not None:
257
- self.hash = hash_feature(self.audio_features)
258
- elif self.input_features is not None:
259
- self.hash = hash_feature(self.input_features)
260
- elif self.is_video():
261
- self.hash = hash_feature(self.pixel_values_videos)
254
+ if self.feature is not None:
255
+ hashed_feature = self.feature
262
256
  else:
263
- self.hash = hash_feature(self.pixel_values)
264
-
257
+ hashed_feature = self.precomputed_embeddings
258
+ self.hash = hash_feature(hashed_feature)
265
259
  assert self.hash is not None
266
260
  self.pad_value = self.hash % (1 << 30)
267
261
 
@@ -269,25 +263,13 @@ class MultimodalDataItem:
269
263
  return self.modality == modality
270
264
 
271
265
  def is_audio(self):
272
- return (self.modality == Modality.AUDIO) and (
273
- self.precomputed_features is not None
274
- or not MultimodalDataItem.is_empty_list(self.audio_features)
275
- or not MultimodalDataItem.is_empty_list(self.input_features)
276
- )
266
+ return self.modality == Modality.AUDIO
277
267
 
278
268
  def is_image(self):
279
- return (
280
- self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
281
- ) and (
282
- self.precomputed_features is not None
283
- or not MultimodalDataItem.is_empty_list(self.pixel_values)
284
- )
269
+ return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]
285
270
 
286
271
  def is_video(self):
287
- return (self.modality == Modality.VIDEO) and (
288
- self.precomputed_features is not None
289
- or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
290
- )
272
+ return self.modality == Modality.VIDEO
291
273
 
292
274
  def is_valid(self) -> bool:
293
275
  return self.is_image() or self.is_video() or self.is_audio()
@@ -307,9 +289,8 @@ class MultimodalDataItem:
307
289
  return ret
308
290
 
309
291
  def merge(self, other):
310
- self.pixel_values += other.pixel_values
311
- self.image_sizes += other.image_sizes
312
- self.image_offsets += other.image_offsets
292
+ self.feature += other.feature
293
+ self.offsets += other.offsets
313
294
  self.hash = hash((self.hash, other.hash))
314
295
  self.set_pad_value()
315
296
 
@@ -350,7 +331,6 @@ class MultimodalInputs:
350
331
 
351
332
  assert isinstance(ret.mm_items, list)
352
333
  ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
353
-
354
334
  for item in ret.mm_items:
355
335
  item.set_pad_value()
356
336
 
@@ -451,6 +431,7 @@ class Req:
451
431
  bootstrap_port: Optional[int] = None,
452
432
  bootstrap_room: Optional[int] = None,
453
433
  data_parallel_rank: Optional[int] = None,
434
+ vocab_size: Optional[int] = None,
454
435
  ):
455
436
  # Input and output info
456
437
  self.rid = rid
@@ -500,6 +481,7 @@ class Req:
500
481
  self.to_abort_message: str = None
501
482
  self.stream = stream
502
483
  self.eos_token_ids = eos_token_ids
484
+ self.vocab_size = vocab_size
503
485
 
504
486
  # For incremental decoding
505
487
  # ----- | --------- read_ids -------|
@@ -527,6 +509,8 @@ class Req:
527
509
  self.last_node: Any = None
528
510
  self.last_host_node: Any = None
529
511
  self.host_hit_length = 0
512
+ # The node to lock until for swa radix tree lock ref
513
+ self.swa_uuid_for_lock: Optional[int] = None
530
514
 
531
515
  # Whether or not if it is chunked. It increments whenever
532
516
  # it is chunked, and decrement whenever chunked request is
@@ -731,6 +715,14 @@ class Req:
731
715
  self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
732
716
  return
733
717
 
718
+ if last_token_id > self.vocab_size or last_token_id < 0:
719
+ if self.sampling_params.stop_token_ids:
720
+ self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
721
+ if self.eos_token_ids:
722
+ self.output_ids[-1] = next(iter(self.eos_token_ids))
723
+ self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
724
+ return
725
+
734
726
  # Check stop strings
735
727
  if len(self.sampling_params.stop_strs) > 0:
736
728
  tail_str = self.tokenizer.decode(
@@ -745,6 +737,7 @@ class Req:
745
737
  def reset_for_retract(self):
746
738
  self.prefix_indices = []
747
739
  self.last_node = None
740
+ self.swa_uuid_for_lock = None
748
741
  self.extend_input_len = 0
749
742
  self.is_retracted = True
750
743
  self.input_token_logprobs = None
@@ -813,6 +806,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
813
806
  req_to_token_pool: ReqToTokenPool = None
814
807
  token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
815
808
  tree_cache: BasePrefixCache = None
809
+ is_hybrid: bool = False
816
810
 
817
811
  # Batch configs
818
812
  model_config: ModelConfig = None
@@ -918,11 +912,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
918
912
  ):
919
913
  return_logprob = any(req.return_logprob for req in reqs)
920
914
 
915
+ is_hybrid = False
916
+ if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
917
+ assert isinstance(tree_cache, SWARadixCache) or isinstance(
918
+ tree_cache, SWAChunkCache
919
+ ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
920
+ is_hybrid = True
921
+
921
922
  return cls(
922
923
  reqs=reqs,
923
924
  req_to_token_pool=req_to_token_pool,
924
925
  token_to_kv_pool_allocator=token_to_kv_pool_allocator,
925
926
  tree_cache=tree_cache,
927
+ is_hybrid=is_hybrid,
926
928
  model_config=model_config,
927
929
  enable_overlap=enable_overlap,
928
930
  return_logprob=return_logprob,
@@ -953,9 +955,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
953
955
  return req_pool_indices
954
956
 
955
957
  def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
956
- if self.token_to_kv_pool_allocator.available_size() < num_tokens:
957
- if self.tree_cache is not None:
958
- self.tree_cache.evict(num_tokens)
958
+ self._evict_tree_cache_if_needed(num_tokens)
959
959
 
960
960
  if backup_state:
961
961
  state = self.token_to_kv_pool_allocator.backup_state()
@@ -966,7 +966,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
966
966
  error_msg = (
967
967
  f"{phase_str} out of memory. Try to lower your batch size.\n"
968
968
  f"Try to allocate {num_tokens} tokens.\n"
969
- f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
969
+ f"{self._available_and_evictable_str()}"
970
970
  )
971
971
  logger.error(error_msg)
972
972
  if self.tree_cache is not None:
@@ -986,16 +986,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
986
986
  extend_num_tokens: int,
987
987
  backup_state: bool = False,
988
988
  ):
989
- if (
990
- self.token_to_kv_pool_allocator.available_size()
991
- < extend_num_tokens
989
+ num_tokens = (
990
+ extend_num_tokens
992
991
  + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
993
- ):
994
- if self.tree_cache is not None:
995
- self.tree_cache.evict(
996
- extend_num_tokens
997
- + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
998
- )
992
+ )
993
+ self._evict_tree_cache_if_needed(num_tokens)
999
994
 
1000
995
  if backup_state:
1001
996
  state = self.token_to_kv_pool_allocator.backup_state()
@@ -1007,9 +1002,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1007
1002
  error_msg = (
1008
1003
  f"Prefill out of memory. Try to lower your batch size.\n"
1009
1004
  f"Try to allocate {extend_num_tokens} tokens.\n"
1010
- f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
1011
- f"{self.token_to_kv_pool_allocator.available_size()=}\n"
1012
- f"{self.tree_cache.evictable_size()=}\n"
1005
+ f"{self._available_and_evictable_str()}"
1013
1006
  )
1014
1007
  logger.error(error_msg)
1015
1008
  raise RuntimeError(error_msg)
@@ -1025,14 +1018,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1025
1018
  last_loc: torch.Tensor,
1026
1019
  backup_state: bool = False,
1027
1020
  ):
1028
- if self.tree_cache is not None:
1029
- if (
1030
- self.token_to_kv_pool_allocator.available_size()
1031
- < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
1032
- ):
1033
- self.tree_cache.evict(
1034
- len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
1035
- )
1021
+ num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
1022
+
1023
+ self._evict_tree_cache_if_needed(num_tokens)
1036
1024
 
1037
1025
  if backup_state:
1038
1026
  state = self.token_to_kv_pool_allocator.backup_state()
@@ -1042,9 +1030,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1042
1030
  error_msg = (
1043
1031
  f"Decode out of memory. Try to lower your batch size.\n"
1044
1032
  f"Try to allocate {len(seq_lens)} tokens.\n"
1045
- f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
1046
- f"{self.token_to_kv_pool_allocator.available_size()=}\n"
1047
- f"{self.tree_cache.evictable_size()=}\n"
1033
+ f"{self._available_and_evictable_str()}"
1048
1034
  )
1049
1035
  logger.error(error_msg)
1050
1036
  raise RuntimeError(error_msg)
@@ -1181,7 +1167,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1181
1167
  (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
1182
1168
  )
1183
1169
  if isinstance(self.tree_cache, SWAChunkCache):
1184
- self.tree_cache.evict(
1170
+ self.tree_cache.evict_swa(
1185
1171
  req, pre_len, self.model_config.attention_chunk_size
1186
1172
  )
1187
1173
 
@@ -1278,11 +1264,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1278
1264
  if mm_input is None:
1279
1265
  continue
1280
1266
  for mm_item in mm_input.mm_items:
1281
- pixel_values = getattr(mm_item, "pixel_values", None)
1267
+ pixel_values = getattr(mm_item, "feature", None)
1282
1268
  if isinstance(pixel_values, torch.Tensor):
1283
- mm_item.pixel_values = pixel_values.to(
1284
- self.device, non_blocking=True
1285
- )
1269
+ mm_item.feature = pixel_values.to(self.device, non_blocking=True)
1286
1270
  self.multimodal_inputs = multimodal_inputs
1287
1271
  self.token_type_ids = token_type_ids_tensor
1288
1272
  self.seq_lens_sum = sum(seq_lens)
@@ -1328,6 +1312,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1328
1312
  self.model_config.vocab_size,
1329
1313
  )
1330
1314
 
1315
+ def prepare_for_split_prefill(self):
1316
+ self.prepare_for_extend()
1317
+ # For split prefill, we need to set the forward mode to SPLIT_PREFILL
1318
+ self.forward_mode = ForwardMode.SPLIT_PREFILL
1319
+
1331
1320
  def mix_with_running(self, running_batch: "ScheduleBatch"):
1332
1321
  self.forward_mode = ForwardMode.MIXED
1333
1322
  running_bs = running_batch.batch_size()
@@ -1371,17 +1360,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1371
1360
  )
1372
1361
 
1373
1362
  def check_decode_mem(self, buf_multiplier=1):
1374
- tokens_required = (
1363
+ num_tokens = (
1375
1364
  self.new_page_count_next_decode()
1376
1365
  * buf_multiplier
1377
1366
  * self.token_to_kv_pool_allocator.page_size
1378
1367
  )
1379
- if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
1380
- return True
1381
1368
 
1382
- self.tree_cache.evict(tokens_required)
1383
-
1384
- return self.token_to_kv_pool_allocator.available_size() >= tokens_required
1369
+ self._evict_tree_cache_if_needed(num_tokens)
1370
+ return self._is_available_size_sufficient(num_tokens)
1385
1371
 
1386
1372
  def retract_decode(self, server_args: ServerArgs):
1387
1373
  """Retract the decoding requests when there is not enough memory."""
@@ -1414,19 +1400,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1414
1400
  num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
1415
1401
  )
1416
1402
 
1403
+ def _get_available_size():
1404
+ if self.is_hybrid:
1405
+ return min(
1406
+ self.token_to_kv_pool_allocator.full_available_size(),
1407
+ self.token_to_kv_pool_allocator.swa_available_size(),
1408
+ )
1409
+ else:
1410
+ return self.token_to_kv_pool_allocator.available_size()
1411
+
1417
1412
  retracted_reqs = []
1418
1413
  seq_lens_cpu = self.seq_lens.cpu().numpy()
1419
1414
  first_iter = True
1420
1415
  while (
1421
- self.token_to_kv_pool_allocator.available_size()
1422
- < get_required_tokens(len(sorted_indices))
1416
+ _get_available_size() < get_required_tokens(len(sorted_indices))
1423
1417
  or first_iter
1424
1418
  ):
1425
1419
  if len(sorted_indices) == 1:
1426
1420
  # Corner case: only one request left
1427
- assert (
1428
- self.token_to_kv_pool_allocator.available_size() > 0
1429
- ), "No space left for only one request"
1421
+ if self.is_hybrid:
1422
+ full_available_size = (
1423
+ self.token_to_kv_pool_allocator.full_available_size()
1424
+ )
1425
+ swa_available_size = (
1426
+ self.token_to_kv_pool_allocator.swa_available_size()
1427
+ )
1428
+ assert (
1429
+ full_available_size > 0 and swa_available_size > 0
1430
+ ), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
1431
+ else:
1432
+ assert (
1433
+ self.token_to_kv_pool_allocator.available_size() > 0
1434
+ ), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
1430
1435
  break
1431
1436
 
1432
1437
  first_iter = False
@@ -1458,15 +1463,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1458
1463
  self.req_to_token_pool.free(req.req_pool_idx)
1459
1464
 
1460
1465
  # release the last node
1461
- self.tree_cache.dec_lock_ref(req.last_node)
1466
+ if self.is_hybrid:
1467
+ self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
1468
+ else:
1469
+ self.tree_cache.dec_lock_ref(req.last_node)
1462
1470
 
1463
1471
  # NOTE(lsyin): we should use the newly evictable memory instantly.
1464
- residual_size = (
1465
- len(sorted_indices) * global_config.retract_decode_steps
1466
- - self.token_to_kv_pool_allocator.available_size()
1467
- )
1468
- residual_size = max(0, residual_size)
1469
- self.tree_cache.evict(residual_size)
1472
+ num_tokens = len(sorted_indices) * global_config.retract_decode_steps
1473
+ self._evict_tree_cache_if_needed(num_tokens)
1470
1474
 
1471
1475
  req.reset_for_retract()
1472
1476
 
@@ -1559,7 +1563,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1559
1563
  # free memory
1560
1564
  if isinstance(self.tree_cache, SWAChunkCache):
1561
1565
  for req in self.reqs:
1562
- self.tree_cache.evict(
1566
+ self.tree_cache.evict_swa(
1563
1567
  req, req.seqlen - 1, self.model_config.attention_chunk_size
1564
1568
  )
1565
1569
 
@@ -1778,6 +1782,53 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1778
1782
  is_extend_in_batch=self.is_extend_in_batch,
1779
1783
  )
1780
1784
 
1785
+ def _evict_tree_cache_if_needed(
1786
+ self,
1787
+ num_tokens: int,
1788
+ ) -> None:
1789
+ if isinstance(self.tree_cache, SWAChunkCache):
1790
+ return
1791
+
1792
+ if self.is_hybrid:
1793
+ full_available_size = self.token_to_kv_pool_allocator.full_available_size()
1794
+ swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
1795
+
1796
+ if full_available_size < num_tokens or swa_available_size < num_tokens:
1797
+ if self.tree_cache is not None:
1798
+ full_num_tokens = max(0, num_tokens - full_available_size)
1799
+ swa_num_tokens = max(0, num_tokens - swa_available_size)
1800
+ self.tree_cache.evict(full_num_tokens, swa_num_tokens)
1801
+ else:
1802
+ if self.token_to_kv_pool_allocator.available_size() < num_tokens:
1803
+ if self.tree_cache is not None:
1804
+ self.tree_cache.evict(num_tokens)
1805
+
1806
+ def _is_available_size_sufficient(self, num_tokens: int) -> bool:
1807
+ if self.is_hybrid:
1808
+ return (
1809
+ self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
1810
+ and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
1811
+ )
1812
+ else:
1813
+ return self.token_to_kv_pool_allocator.available_size() >= num_tokens
1814
+
1815
+ def _available_and_evictable_str(self) -> str:
1816
+ if self.is_hybrid:
1817
+ full_available_size = self.token_to_kv_pool_allocator.full_available_size()
1818
+ swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
1819
+ full_evictable_size = self.tree_cache.full_evictable_size()
1820
+ swa_evictable_size = self.tree_cache.swa_evictable_size()
1821
+ return (
1822
+ f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
1823
+ f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
1824
+ f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
1825
+ f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
1826
+ )
1827
+ else:
1828
+ available_size = self.token_to_kv_pool_allocator.available_size()
1829
+ evictable_size = self.tree_cache.evictable_size()
1830
+ return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
1831
+
1781
1832
  def __str__(self):
1782
1833
  return (
1783
1834
  f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
@@ -1839,7 +1890,7 @@ class ModelWorkerBatch:
1839
1890
  sampling_info: SamplingBatchInfo
1840
1891
 
1841
1892
  # The input Embeds
1842
- input_embeds: Optional[torch.tensor] = None
1893
+ input_embeds: Optional[torch.Tensor] = None
1843
1894
 
1844
1895
  # For corss-encoder model
1845
1896
  token_type_ids: Optional[torch.Tensor] = None
@@ -1849,7 +1900,6 @@ class ModelWorkerBatch:
1849
1900
  spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
1850
1901
  # If set, the output of the batch contains the hidden states of the run.
1851
1902
  capture_hidden_mode: CaptureHiddenMode = None
1852
- spec_num_draft_tokens: Optional[int] = None
1853
1903
  hicache_consumer_index: int = 0
1854
1904
 
1855
1905
  # Overlap event
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
25
25
  import torch
26
26
 
27
27
  from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
28
+ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
28
29
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
29
30
  from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
30
31
 
@@ -311,21 +312,43 @@ class PrefillAdder:
311
312
  ]
312
313
  )
313
314
 
315
+ self.is_hybrid = isinstance(
316
+ self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
317
+ )
318
+
314
319
  @property
315
320
  def rem_total_tokens(self):
316
- return (
317
- self.token_to_kv_pool_allocator.available_size()
318
- + self.tree_cache.evictable_size()
319
- - self.rem_total_token_offset
320
- )
321
+ if self.is_hybrid:
322
+ available_and_evictable = min(
323
+ self.token_to_kv_pool_allocator.full_available_size()
324
+ + self.tree_cache.full_evictable_size(),
325
+ self.token_to_kv_pool_allocator.swa_available_size()
326
+ + self.tree_cache.swa_evictable_size(),
327
+ )
328
+ else:
329
+ available_and_evictable = (
330
+ self.token_to_kv_pool_allocator.available_size()
331
+ + self.tree_cache.evictable_size()
332
+ )
333
+
334
+ return available_and_evictable - self.rem_total_token_offset
321
335
 
322
336
  @property
323
337
  def cur_rem_tokens(self):
324
- return (
325
- self.token_to_kv_pool_allocator.available_size()
326
- + self.tree_cache.evictable_size()
327
- - self.cur_rem_token_offset
328
- )
338
+ if self.is_hybrid:
339
+ available_and_evictable = min(
340
+ self.token_to_kv_pool_allocator.full_available_size()
341
+ + self.tree_cache.full_evictable_size(),
342
+ self.token_to_kv_pool_allocator.swa_available_size()
343
+ + self.tree_cache.swa_evictable_size(),
344
+ )
345
+ else:
346
+ available_and_evictable = (
347
+ self.token_to_kv_pool_allocator.available_size()
348
+ + self.tree_cache.evictable_size()
349
+ )
350
+
351
+ return available_and_evictable - self.cur_rem_token_offset
329
352
 
330
353
  def ceil_paged_tokens(self, tokens: int) -> int:
331
354
  return -(-tokens // self.page_size) * self.page_size
@@ -376,11 +399,18 @@ class PrefillAdder:
376
399
 
377
400
  @contextmanager
378
401
  def _lock_node(self, last_node: TreeNode):
379
- try:
380
- self.tree_cache.inc_lock_ref(last_node)
381
- yield None
382
- finally:
383
- self.tree_cache.dec_lock_ref(last_node)
402
+ if self.is_hybrid:
403
+ try:
404
+ swa_uuid_for_lock = self.tree_cache.inc_lock_ref(last_node)
405
+ yield None
406
+ finally:
407
+ self.tree_cache.dec_lock_ref(last_node, swa_uuid_for_lock)
408
+ else:
409
+ try:
410
+ self.tree_cache.inc_lock_ref(last_node)
411
+ yield None
412
+ finally:
413
+ self.tree_cache.dec_lock_ref(last_node)
384
414
 
385
415
  def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
386
416
  # Early exit if no enough tokens for the input tokens
@@ -422,16 +452,19 @@ class PrefillAdder:
422
452
  else:
423
453
  add_req_state(req, insert_sort=True)
424
454
 
425
- cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
426
- tokens_freed = 0
427
- for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
428
- # tokens_left gives a reservative calculation as the last token is not stored
429
- bs = len(self.req_states) - i
430
- min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
431
- # reserve tokens for corner cases
432
- if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
433
- return AddReqResult.NO_TOKEN
434
- tokens_freed += tokens_occupied
455
+ if not self.is_hybrid:
456
+ # Skip this logic for swa. The SWA has different memory management, and
457
+ # this mechanism is underestimating the memory usage.
458
+ cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
459
+ tokens_freed = 0
460
+ for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
461
+ # tokens_left gives a reservative calculation as the last token is not stored
462
+ bs = len(self.req_states) - i
463
+ min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
464
+ # reserve tokens for corner cases
465
+ if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
466
+ return AddReqResult.NO_TOKEN
467
+ tokens_freed += tokens_occupied
435
468
 
436
469
  if (
437
470
  self.rem_chunk_tokens is None # chunked prefill is disabled
@@ -499,7 +532,11 @@ class PrefillAdder:
499
532
  if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
500
533
  # Non-chunked prefill
501
534
  self.can_run_list.append(req)
502
- self.tree_cache.inc_lock_ref(req.last_node)
535
+ if self.is_hybrid:
536
+ swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
537
+ req.swa_uuid_for_lock = swa_uuid_for_lock
538
+ else:
539
+ self.tree_cache.inc_lock_ref(req.last_node)
503
540
  self._update_prefill_budget(
504
541
  prefix_len,
505
542
  input_tokens,
@@ -520,7 +557,11 @@ class PrefillAdder:
520
557
 
521
558
  self.can_run_list.append(req)
522
559
  self.new_chunked_req = req
523
- self.tree_cache.inc_lock_ref(req.last_node)
560
+ if self.is_hybrid:
561
+ swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
562
+ req.swa_uuid_for_lock = swa_uuid_for_lock
563
+ else:
564
+ self.tree_cache.inc_lock_ref(req.last_node)
524
565
  self._update_prefill_budget(prefix_len, trunc_len, 0)
525
566
 
526
567
  return self.budget_state()