sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_serving.py +49 -7
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/model_config.py +1 -0
  4. sglang/srt/constrained/base_grammar_backend.py +5 -1
  5. sglang/srt/custom_op.py +5 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  7. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  8. sglang/srt/entrypoints/engine.py +0 -5
  9. sglang/srt/layers/attention/flashattention_backend.py +394 -76
  10. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  11. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  12. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  13. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  14. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  15. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  20. sglang/srt/layers/moe/topk.py +49 -3
  21. sglang/srt/layers/quantization/__init__.py +4 -1
  22. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  23. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  24. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  25. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  26. sglang/srt/layers/quantization/utils.py +1 -1
  27. sglang/srt/layers/rotary_embedding.py +0 -12
  28. sglang/srt/managers/cache_controller.py +34 -11
  29. sglang/srt/managers/mm_utils.py +202 -156
  30. sglang/srt/managers/multimodal_processor.py +0 -2
  31. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  32. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  33. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  34. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  35. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  36. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  37. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  38. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  40. sglang/srt/managers/schedule_batch.py +185 -128
  41. sglang/srt/managers/scheduler.py +4 -4
  42. sglang/srt/managers/tokenizer_manager.py +1 -1
  43. sglang/srt/managers/utils.py +1 -6
  44. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  45. sglang/srt/mem_cache/memory_pool.py +72 -6
  46. sglang/srt/mem_cache/paged_allocator.py +39 -0
  47. sglang/srt/metrics/collector.py +23 -53
  48. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  49. sglang/srt/model_executor/forward_batch_info.py +10 -10
  50. sglang/srt/model_executor/model_runner.py +59 -57
  51. sglang/srt/model_loader/loader.py +8 -0
  52. sglang/srt/models/clip.py +12 -7
  53. sglang/srt/models/deepseek_janus_pro.py +10 -15
  54. sglang/srt/models/deepseek_v2.py +212 -121
  55. sglang/srt/models/deepseek_vl2.py +105 -104
  56. sglang/srt/models/gemma3_mm.py +14 -80
  57. sglang/srt/models/llama.py +4 -1
  58. sglang/srt/models/llava.py +31 -19
  59. sglang/srt/models/llavavid.py +16 -7
  60. sglang/srt/models/minicpmo.py +63 -147
  61. sglang/srt/models/minicpmv.py +17 -27
  62. sglang/srt/models/mllama.py +29 -14
  63. sglang/srt/models/qwen2.py +9 -6
  64. sglang/srt/models/qwen2_5_vl.py +21 -31
  65. sglang/srt/models/qwen2_vl.py +20 -21
  66. sglang/srt/openai_api/adapter.py +18 -6
  67. sglang/srt/platforms/interface.py +371 -0
  68. sglang/srt/server_args.py +99 -14
  69. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  70. sglang/srt/speculative/eagle_utils.py +140 -28
  71. sglang/srt/speculative/eagle_worker.py +93 -24
  72. sglang/srt/utils.py +104 -51
  73. sglang/test/test_custom_ops.py +55 -0
  74. sglang/test/test_utils.py +13 -26
  75. sglang/utils.py +2 -2
  76. sglang/version.py +1 -1
  77. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
  78. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
  79. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  80. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  81. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
  from __future__ import annotations
 
+ from enum import Enum, auto
+
  # Copyright 2023-2024 SGLang Team
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -51,7 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
  from sglang.srt.sampling.sampling_params import SamplingParams
  from sglang.srt.server_args import ServerArgs
- from sglang.srt.utils import get_compiler_backend
+ from sglang.srt.utils import flatten_nested_list, get_compiler_backend
 
  if TYPE_CHECKING:
      from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
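The newly imported `flatten_nested_list` is used later in this file to flatten nested lists of multimodal features before hashing them and before emptiness checks. The helper itself lives in `sglang/srt/utils.py` and is not shown in this diff; a minimal sketch of the behavior the call sites rely on, assuming plain recursive flattening, could look like this:

```python
# Sketch only: the real flatten_nested_list lives in sglang/srt/utils.py and
# may differ in detail. This mirrors the behavior the call sites assume.
def flatten_nested_list(nested):
    """Recursively flatten nested lists into one flat list."""
    flat = []
    for item in nested:
        if isinstance(item, list):
            flat.extend(flatten_nested_list(item))
        else:
            flat.append(item)
    return flat


# Example: nested structure collapses to a single level.
assert flatten_nested_list([[1, [2]], 3]) == [1, 2, 3]
```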
@@ -70,14 +72,16 @@ global_server_args_dict = {
      "enable_dp_attention": ServerArgs.enable_dp_attention,
      "enable_ep_moe": ServerArgs.enable_ep_moe,
      "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+     "deepep_mode": ServerArgs.deepep_mode,
      "device": ServerArgs.device,
      "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
      "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
-     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
      "enable_flashmla": ServerArgs.enable_flashmla,
      "disable_radix_cache": ServerArgs.disable_radix_cache,
      "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
      "chunked_prefill_size": ServerArgs.chunked_prefill_size,
+     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
+     "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
  }
 
  logger = logging.getLogger(__name__)
@@ -143,165 +147,185 @@ class FINISH_ABORT(BaseFinishReason):
  }
 
 
+ class Modality(Enum):
+     IMAGE = auto()
+     MULTI_IMAGES = auto()
+     VIDEO = auto()
+     AUDIO = auto()
+
+
  @dataclasses.dataclass
- class MultimodalInputs:
-     """The image related inputs."""
+ class MultimodalDataItem:
+     """
+     A single multimodal data, from a single image/video/audio or other
+     """
+
+     modality: Modality
+
+     hash: int = None
+     pad_value: int = None
 
-     pixel_values: Union[torch.Tensor, np.array]
-     data_hashes: Optional[list] = None
-     image_sizes: Optional[list] = None
+     aspect_ratio_id: Optional[List[torch.Tensor]] = None
+     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+     image_sizes: Tuple[int, int] = None
      image_offsets: Optional[list] = None
+
+     # the real data, pixel_values or audio_features
+     # data: Union[List[torch.Tensor], List[np.array]]
+     pixel_values: Union[torch.Tensor, np.array] = None
+     image_grid_thws: Union[torch.Tensor, np.array] = None
+     video_grid_thws: Union[torch.Tensor, np.array] = None
+
+     image_emb_mask: Optional[torch.Tensor] = None
+     image_spatial_crop: Optional[torch.Tensor] = None
+     second_per_grid_ts: Optional[List[torch.Tensor]] = None
+
+     # [num_images, (n, w, h)]
+     tgt_size: Tuple[int, int] = None
+
+     audio_features: Union[torch.Tensor, np.array] = None
+     audio_feature_lens: Optional[List[torch.Tensor]] = None
+
+     @staticmethod
+     def is_empty_list(l):
+         if l is None:
+             return True
+         return len([item for item in flatten_nested_list(l) if item is not None]) == 0
+
+     def set_pad_value(self):
+         """
+         Set the pad value after first hashign the data
+         """
+
+         def hash_feature(f):
+             if isinstance(f, list):
+                 return hash(tuple(flatten_nested_list(f)))
+             elif isinstance(f, np.ndarray):
+                 arr = np.ascontiguousarray(f)
+                 arr_bytes = arr.tobytes()
+                 return hash(arr_bytes)
+             return hash(f)
+
+         if self.is_audio():
+             self.hash = hash_feature(self.audio_features)
+         else:
+             self.hash = hash_feature(self.pixel_values)
+
+         assert self.hash is not None
+         self.pad_value = self.hash % (1 << 30)
+
+     def is_audio(self):
+         return (
+             self.modality == Modality.AUDIO
+         ) and not MultimodalDataItem.is_empty_list(self.audio_features)
+
+     def is_image(self):
+         return (
+             self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
+         ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+
+     def is_video(self):
+         return (
+             self.modality == Modality.VIDEO
+         ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+
+     def validate(self):
+         ...
+         # TODO
+
+
+ @dataclasses.dataclass
+ class MultimodalInputs:
+     """The multimodal data related inputs."""
+
+     # items of data
+     mm_items: List[MultimodalDataItem]
      image_pad_len: Optional[list] = None
-     pad_values: Optional[list] = None
-     modalities: Optional[list] = None
      num_image_tokens: Optional[int] = None
 
-     # Llava related
-     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
-     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
-
      # QWen2-VL related
-     # [num_of_images, t, h, w]
-     image_grid_thws: torch.Tensor = None
      mrope_position_delta: Optional[torch.Tensor] = None
-     # Qwen2-VL video related
-     video_token_id: Optional[int] = None
-     video_grid_thws: List[Tuple[int, int, int]] = None
-     second_per_grid_ts: Optional[List[torch.Tensor]] = None
 
-     # deepseek vl2 related
-     images_emb_mask: Optional[List[torch.Tensor]] = None
-     image_spatial_crop: Optional[List[torch.Tensor]] = None
-
-     # The id of the single-image placeholder token
+     # image
      im_token_id: Optional[torch.Tensor] = None
-
-     # All the images in the batch should share the same special image
-     # bound token ids.
      im_start_id: Optional[int] = None
      im_end_id: Optional[int] = None
      slice_start_id: Optional[int] = None
      slice_end_id: Optional[int] = None
-     # [num_images, 2 (w, h)]
-     tgt_sizes: Optional[list] = None
+
+     # video
+     video_token_id: Optional[int] = None
 
      # audio
      audio_start_id: Optional[torch.Tensor] = None
      audio_end_id: Optional[torch.Tensor] = None
-     audio_features: Optional[List[torch.Tensor]] = None
-     audio_feature_lens: Optional[List[torch.Tensor]] = None
 
      @staticmethod
      def from_dict(obj: dict):
          ret = MultimodalInputs(
-             pixel_values=obj["pixel_values"],
-             data_hashes=obj["data_hashes"],
+             mm_items=obj["mm_items"],
          )
 
+         assert isinstance(ret.mm_items, list)
+         ret.mm_items = [
+             item
+             for item in ret.mm_items
+             if item.is_audio() or item.is_image() or item.is_video()
+         ]
+
+         assert len(ret.mm_items) != 0
+
          # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
          # Please note that if the `input_ids` is later used in the model forward,
         # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
          # errors in cuda kernels. See also llava.py for example.
-         ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]
+         for item in ret.mm_items:
+             item.set_pad_value()
 
          optional_args = [
-             "image_sizes",
              "modalities",
-             "aspect_ratio_ids",
-             "aspect_ratio_mask",
-             "image_grid_thws",
-             "images_emb_mask",
-             "image_spatial_crop",
              "im_token_id",
              "im_start_id",
              "im_end_id",
              "slice_start_id",
              "slice_end_id",
-             "tgt_sizes",
              "audio_start_id",
              "audio_end_id",
-             "audio_features",
-             "audio_feature_lens",
          ]
          for arg in optional_args:
              if arg in obj:
                  setattr(ret, arg, obj[arg])
 
-         # validate
-         assert (
-             isinstance(ret.pixel_values, torch.Tensor)
-             or isinstance(ret.pixel_values, np.ndarray)
-             or isinstance(ret.pixel_values, list)
-         )
-
-         assert ret.audio_features is None or isinstance(ret.audio_features, list)
-
          return ret
 
      def contains_image_inputs(self) -> bool:
          """ """
-         return self.pixel_values is not None and self.pixel_values != []
+         return any(item.is_image() for item in self.mm_items)
 
      def contains_audio_inputs(self) -> bool:
          """ """
-         return self.audio_features is not None and self.audio_features != []
+         return any(item.is_audio() for item in self.mm_items)
+
+     def collect_image_inputs(self) -> List[torch.Tensor]:
+         return [item.pixel_values for item in self.mm_items if item.is_image()]
 
      def merge(self, other: MultimodalInputs):
          """
          merge image inputs when requests are being merged
          """
-         if isinstance(self.pixel_values, list):
-             # in some rare cases, pixel values are list of patches with different shapes
-             # e.g. minicpm
-             self.pixel_values += other.pixel_values
-         else:
-             assert (
-                 self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
-             ), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
-             self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
-
-         # args would be stacked along first dim
-         # usually these are already tensors
-         stack_args = [
-             # TODO: merge with image_grid_thws, basically the same thing
-             "tgt_sizes",
-             "image_spatial_crop",
-         ]
-         for arg in stack_args:
-             if getattr(self, arg, None) is None:
-                 setattr(self, arg, getattr(other, arg, None))
-             elif getattr(other, arg, None) is not None:
-                 # self and other both not None
-                 setattr(
-                     self,
-                     arg,
-                     torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
-                 )
-
-         if self.image_grid_thws is None:
-             self.image_grid_thws = other.image_grid_thws
-         elif other.image_grid_thws is not None:
-             self.image_grid_thws = torch.concat(
-                 [self.image_grid_thws, other.image_grid_thws]
-             )
 
          # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
          # Please note that if the `input_ids` is later used in the model forward,
          # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
          # errors in cuda kernels. See also llava.py for example.
-         self.data_hashes += other.data_hashes
-         self.pad_values = [x % (1 << 30) for x in self.data_hashes]
 
          # args needed to be merged
          optional_args = [
-             "audio_features",
-             "image_sizes",
+             "items",
              "image_offsets",
              "image_pad_len",
              # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
-             "aspect_ratio_ids",
-             "aspect_ratio_mask",
-             "images_emb_mask",
          ]
          for arg in optional_args:
              self_arg = getattr(self, arg, None)
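With this hunk, the pad values that stand in for multimodal placeholder tokens are derived per `MultimodalDataItem` rather than per request. A self-contained sketch of that derivation, mirroring `MultimodalDataItem.set_pad_value` above (numpy features are hashed over their raw bytes and the hash is folded into the fake-token range below 2**30):

```python
import numpy as np


def derive_pad_value(pixel_values) -> int:
    """Sketch of per-item pad-value derivation, mirroring set_pad_value above."""
    if isinstance(pixel_values, np.ndarray):
        data_hash = hash(np.ascontiguousarray(pixel_values).tobytes())
    else:
        data_hash = hash(tuple(pixel_values))
    # Fold into [0, 2**30) so the value can be used as a fake token id when
    # building radix-cache keys; it must still be clamped to the vocab size
    # before any real model forward (see the comment in from_dict above).
    return data_hash % (1 << 30)


# Two identical images hash to the same pad value, so their requests can
# share a prefix in the radix cache.
img = np.zeros((3, 224, 224), dtype=np.float32)
assert derive_pad_value(img) == derive_pad_value(img.copy())
```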
@@ -599,6 +623,7 @@ class Req:
          self.extend_logprob_start_len = 0
          self.is_chunked = 0
          self.req_pool_idx = None
+         self.already_computed = 0
 
      def __repr__(self):
          return (
@@ -740,11 +765,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          )
          return req_pool_indices
 
-     def alloc_token_slots(self, num_tokens: int):
+     def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
          if self.token_to_kv_pool_allocator.available_size() < num_tokens:
              if self.tree_cache is not None:
                  self.tree_cache.evict(num_tokens)
 
+         if backup_state:
+             state = self.token_to_kv_pool_allocator.backup_state()
+
          out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
          if out_cache_loc is None:
              phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
@@ -758,7 +786,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                  self.tree_cache.pretty_print()
              raise RuntimeError(error_msg)
 
-         return out_cache_loc
+         if backup_state:
+             return out_cache_loc, state
+         else:
+             return out_cache_loc
 
      def alloc_paged_token_slots_extend(
          self,
@@ -766,6 +797,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          seq_lens: torch.Tensor,
          last_loc: torch.Tensor,
          extend_num_tokens: int,
+         backup_state: bool = False,
      ):
          if (
              self.token_to_kv_pool_allocator.available_size()
@@ -778,6 +810,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                  + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
              )
 
+         if backup_state:
+             state = self.token_to_kv_pool_allocator.backup_state()
+
          out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
              prefix_lens, seq_lens, last_loc, extend_num_tokens
          )
@@ -791,23 +826,31 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              )
              logger.error(error_msg)
              raise RuntimeError(error_msg)
-         return out_cache_loc
+
+         if backup_state:
+             return out_cache_loc, state
+         else:
+             return out_cache_loc
 
      def alloc_paged_token_slots_decode(
          self,
          seq_lens: torch.Tensor,
          last_loc: torch.Tensor,
+         backup_state: bool = False,
      ):
-         if (
-             self.token_to_kv_pool_allocator.available_size()
-             < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-         ):
-             if self.tree_cache is not None:
+         if self.tree_cache is not None:
+             if (
+                 self.token_to_kv_pool_allocator.available_size()
+                 < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+             ):
                  self.tree_cache.evict(
                      len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                  )
-         out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
 
+         if backup_state:
+             state = self.token_to_kv_pool_allocator.backup_state()
+
+         out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
          if out_cache_loc is None:
              error_msg = (
                  f"Decode out of memory. Try to lower your batch size.\n"
@@ -818,7 +861,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              )
              logger.error(error_msg)
              raise RuntimeError(error_msg)
-         return out_cache_loc
+
+         if backup_state:
+             return out_cache_loc, state
+         else:
+             return out_cache_loc
 
      def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
          self.encoder_lens_cpu = []
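All three allocation helpers now accept `backup_state=True` and additionally return a snapshot taken via `token_to_kv_pool_allocator.backup_state()` before allocating. A hedged sketch of the caller pattern this enables; the `restore_state` counterpart and the downstream work are assumptions here, not part of this hunk:

```python
# Hypothetical caller sketch: snapshot the KV allocator before a speculative
# allocation and roll back if the downstream step fails. `restore_state` is
# assumed to be the counterpart of backup_state(); it is not shown in this diff,
# and do_something_with() is a placeholder for the real work.
def try_alloc_with_rollback(batch, num_tokens: int):
    out_cache_loc, state = batch.alloc_token_slots(num_tokens, backup_state=True)
    try:
        return do_something_with(out_cache_loc)
    except Exception:
        batch.token_to_kv_pool_allocator.restore_state(state)
        raise
```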
@@ -938,8 +985,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                  # If req.input_embeds is already a list, append its content directly
                  input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
-             if req.is_retracted:
-                 req.already_computed = 0
              req.cached_tokens += pre_len - req.already_computed
              req.already_computed = seq_len
              req.is_retracted = False
@@ -1095,17 +1140,25 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          # TODO (lianmin): Revisit this. It should be seq_len - 1
          self.extend_logprob_start_lens.extend([0] * running_bs)
 
-     def check_decode_mem(self, buf_multiplier=1):
-         bs = len(self.reqs) * buf_multiplier
-         if self.token_to_kv_pool_allocator.available_size() >= bs:
-             return True
+     def new_page_count_next_decode(self):
+         page_size = self.token_to_kv_pool_allocator.page_size
+         if page_size == 1:
+             return len(self.reqs)
+         return sum(1 for req in self.reqs if req.seqlen % page_size == 0)
 
-         self.tree_cache.evict(bs)
+     def check_decode_mem(self, buf_multiplier=1):
+         tokens_required = (
+             self.new_page_count_next_decode()
+             * buf_multiplier
+             * self.token_to_kv_pool_allocator.page_size
+         )
 
-         if self.token_to_kv_pool_allocator.available_size() >= bs:
+         if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
              return True
 
-         return False
+         self.tree_cache.evict(tokens_required)
+
+         return self.token_to_kv_pool_allocator.available_size() >= tokens_required
 
      def retract_decode(self, server_args: ServerArgs):
          """Retract the decoding requests when there is not enough memory."""
@@ -1167,7 +1220,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                  self.req_to_token_pool.free(req.req_pool_idx)
              else:
                  # TODO: apply more fine-grained retraction
-                 last_uncached_pos = len(req.prefix_indices)
+                 last_uncached_pos = (
+                     len(req.prefix_indices) // server_args.page_size
+                 ) * server_args.page_size
                  token_indices = self.req_to_token_pool.req_to_token[
                      req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                  ]
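Retraction is now page-aligned: the number of cached prefix tokens is rounded down to a multiple of `page_size` before computing the range of slots to release, so partially filled pages are freed along with the uncached suffix. A quick worked example of the rounding (values are hypothetical):

```python
# The retraction start is floored to a page boundary: with page_size=16,
# a prefix of 37 cached tokens frees everything from position 32 onward.
page_size = 16
prefix_len = 37  # hypothetical number of cached prefix tokens
last_uncached_pos = (prefix_len // page_size) * page_size
assert last_uncached_pos == 32
```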
@@ -1373,21 +1428,25 @@
 
      def get_model_worker_batch(self) -> ModelWorkerBatch:
          if self.forward_mode.is_decode_or_idle():
-             if (
-                 global_server_args_dict["enable_flashinfer_mla"]
-                 or global_server_args_dict["enable_flashmla"]
-                 or global_server_args_dict["attention_backend"] == "fa3"
-             ):
-                 decode_seq_lens = self.seq_lens.cpu()
-             else:
-                 decode_seq_lens = None
              extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
          else:
-             decode_seq_lens = None
              extend_seq_lens = self.extend_lens
              extend_prefix_lens = self.prefix_lens
              extend_logprob_start_lens = self.extend_logprob_start_lens
 
+         # Create seq_lens_cpu when needed
+         if (
+             (
+                 global_server_args_dict["use_mla_backend"]
+                 and global_server_args_dict["attention_backend"] == "flashinfer"
+             )
+             or global_server_args_dict["enable_flashmla"]
+             or global_server_args_dict["attention_backend"] == "fa3"
+         ):
+             seq_lens_cpu = self.seq_lens.cpu()
+         else:
+             seq_lens_cpu = None
+
          if self.sampling_info:
              if self.has_grammar:
                  self.sampling_info.grammars = [req.grammar for req in self.reqs]
@@ -1410,7 +1469,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              global_num_tokens=self.global_num_tokens,
              global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
              can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
-             decode_seq_lens=decode_seq_lens,
+             seq_lens_cpu=seq_lens_cpu,
              extend_num_tokens=self.extend_num_tokens,
              extend_seq_lens=extend_seq_lens,
              extend_prefix_lens=extend_prefix_lens,
@@ -1471,6 +1530,7 @@ class ModelWorkerBatch:
      req_pool_indices: torch.Tensor
      # The sequence length
      seq_lens: torch.Tensor
+     seq_lens_cpu: Optional[torch.Tensor]
      # The indices of output tokens in the token_to_kv_pool_allocator
      out_cache_loc: torch.Tensor
 
@@ -1487,9 +1547,6 @@ class ModelWorkerBatch:
      global_num_tokens_for_logprob: Optional[List[int]]
      can_run_dp_cuda_graph: bool
 
-     # For decode
-     decode_seq_lens: Optional[torch.Tensor]
-
      # For extend
      extend_num_tokens: Optional[int]
      extend_seq_lens: Optional[List[int]]
@@ -112,7 +112,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
112
112
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
113
113
  from sglang.srt.mem_cache.radix_cache import RadixCache
114
114
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
115
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
115
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
116
116
  from sglang.srt.server_args import PortArgs, ServerArgs
117
117
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
118
118
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -1110,7 +1110,7 @@ class Scheduler(
1110
1110
  )
1111
1111
  if memory_leak:
1112
1112
  msg = (
1113
- "KV cache pool leak detected! "
1113
+ "token_to_kv_pool_allocator memory leak detected! "
1114
1114
  f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
1115
1115
  f"{self.token_to_kv_pool_allocator.available_size()=}\n"
1116
1116
  f"{self.tree_cache.evictable_size()=}\n"
@@ -1121,7 +1121,7 @@ class Scheduler(
1121
1121
 
1122
1122
  if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
1123
1123
  msg = (
1124
- "Memory pool leak detected!"
1124
+ "req_to_token_pool memory leak detected!"
1125
1125
  f"available_size={len(self.req_to_token_pool.free_slots)}, "
1126
1126
  f"total_size={self.req_to_token_pool.size}\n"
1127
1127
  )
@@ -1282,7 +1282,7 @@ class Scheduler(
1282
1282
  ]
1283
1283
 
1284
1284
  if self.enable_hierarchical_cache:
1285
- self.tree_cache.read_to_load_cache()
1285
+ self.tree_cache.ready_to_load_cache()
1286
1286
 
1287
1287
  if adder.new_chunked_req is not None:
1288
1288
  assert self.chunked_req is None
@@ -736,7 +736,7 @@ class TokenizerManager:
736
736
  self.auto_create_handle_loop()
737
737
  assert (
738
738
  self.server_args.dp_size == 1
739
- ), "dp_size must be for update weights from distributed"
739
+ ), "dp_size must be 1 for update weights from distributed"
740
740
 
741
741
  # This means that weight sync
742
742
  # cannot run while requests are in progress.
@@ -1,11 +1,6 @@
1
- import json
2
1
  import logging
3
- import time
4
- from collections import defaultdict
5
2
  from http import HTTPStatus
6
- from typing import Dict, List, Optional, Tuple
7
-
8
- import torch
3
+ from typing import Optional
9
4
 
10
5
  from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
11
6