sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. sglang/bench_serving.py +72 -10
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/deepseekvl2.py +10 -1
  4. sglang/srt/configs/model_config.py +6 -16
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/custom_op.py +5 -0
  7. sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
  8. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  9. sglang/srt/distributed/parallel_state.py +32 -5
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/entrypoints/http_server.py +7 -1
  12. sglang/srt/entrypoints/verl_engine.py +2 -0
  13. sglang/srt/function_call_parser.py +0 -1
  14. sglang/srt/layers/attention/flashattention_backend.py +582 -125
  15. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  17. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  18. sglang/srt/layers/dp_attention.py +12 -1
  19. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  20. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  21. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  26. sglang/srt/layers/moe/topk.py +79 -6
  27. sglang/srt/layers/quantization/__init__.py +137 -165
  28. sglang/srt/layers/quantization/awq.py +200 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  30. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  31. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  32. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  33. sglang/srt/layers/quantization/gptq.py +30 -40
  34. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  35. sglang/srt/layers/quantization/utils.py +1 -1
  36. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  37. sglang/srt/lora/backend/base_backend.py +4 -4
  38. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  39. sglang/srt/lora/backend/triton_backend.py +5 -8
  40. sglang/srt/lora/layers.py +19 -33
  41. sglang/srt/lora/lora_manager.py +20 -7
  42. sglang/srt/lora/mem_pool.py +12 -6
  43. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  44. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  45. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  46. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  47. sglang/srt/lora/utils.py +6 -0
  48. sglang/srt/managers/cache_controller.py +34 -11
  49. sglang/srt/managers/io_struct.py +4 -2
  50. sglang/srt/managers/mm_utils.py +202 -156
  51. sglang/srt/managers/multimodal_processor.py +0 -2
  52. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  53. sglang/srt/managers/multimodal_processors/clip.py +44 -0
  54. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  55. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  56. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  57. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  58. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  59. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  61. sglang/srt/managers/schedule_batch.py +185 -127
  62. sglang/srt/managers/scheduler.py +29 -23
  63. sglang/srt/managers/tokenizer_manager.py +1 -2
  64. sglang/srt/managers/tp_worker.py +3 -0
  65. sglang/srt/managers/utils.py +1 -6
  66. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  67. sglang/srt/mem_cache/memory_pool.py +72 -6
  68. sglang/srt/mem_cache/paged_allocator.py +39 -0
  69. sglang/srt/metrics/collector.py +23 -53
  70. sglang/srt/model_executor/cuda_graph_runner.py +16 -13
  71. sglang/srt/model_executor/forward_batch_info.py +10 -10
  72. sglang/srt/model_executor/model_runner.py +64 -59
  73. sglang/srt/model_loader/loader.py +19 -1
  74. sglang/srt/model_loader/weight_utils.py +6 -3
  75. sglang/srt/models/clip.py +568 -0
  76. sglang/srt/models/deepseek_janus_pro.py +12 -17
  77. sglang/srt/models/deepseek_v2.py +339 -123
  78. sglang/srt/models/deepseek_vl2.py +105 -104
  79. sglang/srt/models/gemma3_causal.py +12 -2
  80. sglang/srt/models/gemma3_mm.py +20 -80
  81. sglang/srt/models/llama.py +4 -1
  82. sglang/srt/models/llava.py +31 -19
  83. sglang/srt/models/llavavid.py +16 -7
  84. sglang/srt/models/minicpmo.py +63 -147
  85. sglang/srt/models/minicpmv.py +17 -27
  86. sglang/srt/models/mllama.py +29 -14
  87. sglang/srt/models/qwen2.py +9 -6
  88. sglang/srt/models/qwen2_5_vl.py +21 -31
  89. sglang/srt/models/qwen2_vl.py +20 -21
  90. sglang/srt/openai_api/adapter.py +106 -93
  91. sglang/srt/openai_api/protocol.py +10 -5
  92. sglang/srt/patch_torch.py +71 -0
  93. sglang/srt/platforms/interface.py +371 -0
  94. sglang/srt/server_args.py +120 -25
  95. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  96. sglang/srt/speculative/eagle_utils.py +140 -28
  97. sglang/srt/speculative/eagle_worker.py +94 -25
  98. sglang/srt/utils.py +137 -51
  99. sglang/test/runners.py +27 -2
  100. sglang/test/test_custom_ops.py +55 -0
  101. sglang/test/test_utils.py +14 -27
  102. sglang/utils.py +2 -2
  103. sglang/version.py +1 -1
  104. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
  105. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
  106. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  107. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  108. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from enum import Enum, auto
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -51,7 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -70,14 +72,16 @@ global_server_args_dict = {
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "deepep_mode": ServerArgs.deepep_mode,
     "device": ServerArgs.device,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
-    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
     "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
+    "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
+    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
 }
 
 logger = logging.getLogger(__name__)
@@ -143,165 +147,185 @@ class FINISH_ABORT(BaseFinishReason):
         }
 
 
+class Modality(Enum):
+    IMAGE = auto()
+    MULTI_IMAGES = auto()
+    VIDEO = auto()
+    AUDIO = auto()
+
+
 @dataclasses.dataclass
-class MultimodalInputs:
-    """The image related inputs."""
+class MultimodalDataItem:
+    """
+    A single multimodal data, from a single image/video/audio or other
+    """
+
+    modality: Modality
+
+    hash: int = None
+    pad_value: int = None
 
-    pixel_values: Union[torch.Tensor, np.array]
-    data_hashes: Optional[list] = None
-    image_sizes: Optional[list] = None
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+    image_sizes: Tuple[int, int] = None
     image_offsets: Optional[list] = None
+
+    # the real data, pixel_values or audio_features
+    # data: Union[List[torch.Tensor], List[np.array]]
+    pixel_values: Union[torch.Tensor, np.array] = None
+    image_grid_thws: Union[torch.Tensor, np.array] = None
+    video_grid_thws: Union[torch.Tensor, np.array] = None
+
+    image_emb_mask: Optional[torch.Tensor] = None
+    image_spatial_crop: Optional[torch.Tensor] = None
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None
+
+    # [num_images, (n, w, h)]
+    tgt_size: Tuple[int, int] = None
+
+    audio_features: Union[torch.Tensor, np.array] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def is_empty_list(l):
+        if l is None:
+            return True
+        return len([item for item in flatten_nested_list(l) if item is not None]) == 0
+
+    def set_pad_value(self):
+        """
+        Set the pad value after first hashign the data
+        """
+
+        def hash_feature(f):
+            if isinstance(f, list):
+                return hash(tuple(flatten_nested_list(f)))
+            elif isinstance(f, np.ndarray):
+                arr = np.ascontiguousarray(f)
+                arr_bytes = arr.tobytes()
+                return hash(arr_bytes)
+            return hash(f)
+
+        if self.is_audio():
+            self.hash = hash_feature(self.audio_features)
+        else:
+            self.hash = hash_feature(self.pixel_values)
+
+        assert self.hash is not None
+        self.pad_value = self.hash % (1 << 30)
+
+    def is_audio(self):
+        return (
+            self.modality == Modality.AUDIO
+        ) and not MultimodalDataItem.is_empty_list(self.audio_features)
+
+    def is_image(self):
+        return (
+            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
+        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+
+    def is_video(self):
+        return (
+            self.modality == Modality.VIDEO
+        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+
+    def validate(self):
+        ...
+        # TODO
+
+
+@dataclasses.dataclass
+class MultimodalInputs:
+    """The multimodal data related inputs."""
+
+    # items of data
+    mm_items: List[MultimodalDataItem]
     image_pad_len: Optional[list] = None
-    pad_values: Optional[list] = None
-    modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
 
-    # Llava related
-    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
-
     # QWen2-VL related
-    # [num_of_images, t, h, w]
-    image_grid_thws: torch.Tensor = None
     mrope_position_delta: Optional[torch.Tensor] = None
-    # Qwen2-VL video related
-    video_token_id: Optional[int] = None
-    video_grid_thws: List[Tuple[int, int, int]] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None
 
-    # deepseek vl2 related
-    images_emb_mask: Optional[List[torch.Tensor]] = None
-    image_spatial_crop: Optional[List[torch.Tensor]] = None
-
-    # The id of the single-image placeholder token
+    # image
     im_token_id: Optional[torch.Tensor] = None
-
-    # All the images in the batch should share the same special image
-    # bound token ids.
     im_start_id: Optional[int] = None
     im_end_id: Optional[int] = None
     slice_start_id: Optional[int] = None
     slice_end_id: Optional[int] = None
-    # [num_images, 2 (w, h)]
-    tgt_sizes: Optional[list] = None
+
+    # video
+    video_token_id: Optional[int] = None
 
     # audio
    audio_start_id: Optional[torch.Tensor] = None
     audio_end_id: Optional[torch.Tensor] = None
-    audio_features: Optional[List[torch.Tensor]] = None
-    audio_feature_lens: Optional[List[torch.Tensor]] = None
 
     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
-            pixel_values=obj["pixel_values"],
-            data_hashes=obj["data_hashes"],
+            mm_items=obj["mm_items"],
         )
 
+        assert isinstance(ret.mm_items, list)
+        ret.mm_items = [
+            item
+            for item in ret.mm_items
+            if item.is_audio() or item.is_image() or item.is_video()
+        ]
+
+        assert len(ret.mm_items) != 0
+
         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
         # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
         # errors in cuda kernels. See also llava.py for example.
-        ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]
+        for item in ret.mm_items:
+            item.set_pad_value()
 
         optional_args = [
-            "image_sizes",
             "modalities",
-            "aspect_ratio_ids",
-            "aspect_ratio_mask",
-            "image_grid_thws",
-            "images_emb_mask",
-            "image_spatial_crop",
             "im_token_id",
             "im_start_id",
             "im_end_id",
             "slice_start_id",
             "slice_end_id",
-            "tgt_sizes",
             "audio_start_id",
             "audio_end_id",
-            "audio_features",
-            "audio_feature_lens",
         ]
         for arg in optional_args:
             if arg in obj:
                 setattr(ret, arg, obj[arg])
 
-        # validate
-        assert (
-            isinstance(ret.pixel_values, torch.Tensor)
-            or isinstance(ret.pixel_values, np.ndarray)
-            or isinstance(ret.pixel_values, list)
-        )
-
-        assert ret.audio_features is None or isinstance(ret.audio_features, list)
-
         return ret
 
     def contains_image_inputs(self) -> bool:
         """ """
-        return self.pixel_values is not None and self.pixel_values != []
+        return any(item.is_image() for item in self.mm_items)
 
     def contains_audio_inputs(self) -> bool:
         """ """
-        return self.audio_features is not None and self.audio_features != []
+        return any(item.is_audio() for item in self.mm_items)
+
+    def collect_image_inputs(self) -> List[torch.Tensor]:
+        return [item.pixel_values for item in self.mm_items if item.is_image()]
 
     def merge(self, other: MultimodalInputs):
         """
         merge image inputs when requests are being merged
         """
-        if isinstance(self.pixel_values, list):
-            # in some rare cases, pixel values are list of patches with different shapes
-            # e.g. minicpm
-            self.pixel_values += other.pixel_values
-        else:
-            assert (
-                self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
-            ), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
-            self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
-
-        # args would be stacked along first dim
-        # usually these are already tensors
-        stack_args = [
-            # TODO: merge with image_grid_thws, basically the same thing
-            "tgt_sizes",
-            "image_spatial_crop",
-        ]
-        for arg in stack_args:
-            if getattr(self, arg, None) is None:
-                setattr(self, arg, getattr(other, arg, None))
-            elif getattr(other, arg, None) is not None:
-                # self and other both not None
-                setattr(
-                    self,
-                    arg,
-                    torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
-                )
-
-        if self.image_grid_thws is None:
-            self.image_grid_thws = other.image_grid_thws
-        elif other.image_grid_thws is not None:
-            self.image_grid_thws = torch.concat(
-                [self.image_grid_thws, other.image_grid_thws]
-            )
 
         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
         # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
         # errors in cuda kernels. See also llava.py for example.
-        self.data_hashes += other.data_hashes
-        self.pad_values = [x % (1 << 30) for x in self.data_hashes]
 
         # args needed to be merged
         optional_args = [
-            "audio_features",
-            "image_sizes",
+            "items",
             "image_offsets",
             "image_pad_len",
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
-            "aspect_ratio_ids",
-            "aspect_ratio_mask",
-            "images_emb_mask",
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
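
The hunk above replaces the batch-wide parallel lists in MultimodalInputs with a list of per-item MultimodalDataItem objects, each carrying its own modality, raw features, hash, and pad_value. A minimal usage sketch follows; it assumes this hunk belongs to sglang/srt/managers/schedule_batch.py (item 61 in the file list above), and the dummy numpy image and assertions are purely illustrative, not part of the package:

import numpy as np

from sglang.srt.managers.schedule_batch import (
    Modality,
    MultimodalDataItem,
    MultimodalInputs,
)

# One item per image/audio clip instead of batch-wide parallel lists.
item = MultimodalDataItem(
    modality=Modality.IMAGE,
    pixel_values=np.zeros((3, 224, 224), dtype=np.float32),  # illustrative data
)
item.set_pad_value()  # hashes pixel_values; pad_value = hash % (1 << 30)

mm_inputs = MultimodalInputs.from_dict({"mm_items": [item]})
assert mm_inputs.contains_image_inputs()
assert not mm_inputs.contains_audio_inputs()

Keeping the hash and pad_value on each item mirrors the old pad_values list computed from data_hashes, but lets the radix cache key its prefix matching on individual images or audio clips rather than on one batch-wide list.
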
@@ -599,6 +623,7 @@ class Req:
         self.extend_logprob_start_len = 0
         self.is_chunked = 0
         self.req_pool_idx = None
+        self.already_computed = 0
 
     def __repr__(self):
         return (
@@ -740,11 +765,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
         return req_pool_indices
 
-    def alloc_token_slots(self, num_tokens: int):
+    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
         if self.token_to_kv_pool_allocator.available_size() < num_tokens:
             if self.tree_cache is not None:
                 self.tree_cache.evict(num_tokens)
 
+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
         if out_cache_loc is None:
             phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
@@ -758,7 +786,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 self.tree_cache.pretty_print()
             raise RuntimeError(error_msg)
 
-        return out_cache_loc
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc
 
     def alloc_paged_token_slots_extend(
         self,
@@ -766,6 +797,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         seq_lens: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
+        backup_state: bool = False,
     ):
         if (
             self.token_to_kv_pool_allocator.available_size()
@@ -778,6 +810,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                     + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                 )
 
+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
             prefix_lens, seq_lens, last_loc, extend_num_tokens
         )
@@ -791,23 +826,31 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
-        return out_cache_loc
+
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc
 
     def alloc_paged_token_slots_decode(
         self,
         seq_lens: torch.Tensor,
         last_loc: torch.Tensor,
+        backup_state: bool = False,
     ):
-        if (
-            self.token_to_kv_pool_allocator.available_size()
-            < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-        ):
-            if self.tree_cache is not None:
+        if self.tree_cache is not None:
+            if (
+                self.token_to_kv_pool_allocator.available_size()
+                < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+            ):
                 self.tree_cache.evict(
                     len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                 )
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
 
+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
         if out_cache_loc is None:
             error_msg = (
                 f"Decode out of memory. Try to lower your batch size.\n"
@@ -818,7 +861,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
-        return out_cache_loc
+
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc
 
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
@@ -938,8 +985,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
-            if req.is_retracted:
-                req.already_computed = 0
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
@@ -1095,17 +1140,25 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
-    def check_decode_mem(self, buf_multiplier=1):
-        bs = len(self.reqs) * buf_multiplier
-        if self.token_to_kv_pool_allocator.available_size() >= bs:
-            return True
+    def new_page_count_next_decode(self):
+        page_size = self.token_to_kv_pool_allocator.page_size
+        if page_size == 1:
+            return len(self.reqs)
+        return sum(1 for req in self.reqs if req.seqlen % page_size == 0)
 
-        self.tree_cache.evict(bs)
+    def check_decode_mem(self, buf_multiplier=1):
+        tokens_required = (
+            self.new_page_count_next_decode()
+            * buf_multiplier
+            * self.token_to_kv_pool_allocator.page_size
+        )
 
-        if self.token_to_kv_pool_allocator.available_size() >= bs:
+        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
             return True
 
-        return False
+        self.tree_cache.evict(tokens_required)
+
+        return self.token_to_kv_pool_allocator.available_size() >= tokens_required
 
     def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
@@ -1167,7 +1220,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 self.req_to_token_pool.free(req.req_pool_idx)
             else:
                 # TODO: apply more fine-grained retraction
-                last_uncached_pos = len(req.prefix_indices)
+                last_uncached_pos = (
+                    len(req.prefix_indices) // server_args.page_size
+                ) * server_args.page_size
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                 ]
@@ -1373,20 +1428,25 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
     def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
-            if (
-                global_server_args_dict["enable_flashinfer_mla"]
-                or global_server_args_dict["enable_flashmla"]
-            ):
-                decode_seq_lens = self.seq_lens.cpu()
-            else:
-                decode_seq_lens = None
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
-            decode_seq_lens = None
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
 
+        # Create seq_lens_cpu when needed
+        if (
+            (
+                global_server_args_dict["use_mla_backend"]
+                and global_server_args_dict["attention_backend"] == "flashinfer"
+            )
+            or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "fa3"
+        ):
+            seq_lens_cpu = self.seq_lens.cpu()
+        else:
+            seq_lens_cpu = None
+
         if self.sampling_info:
             if self.has_grammar:
                 self.sampling_info.grammars = [req.grammar for req in self.reqs]
@@ -1409,7 +1469,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
-            decode_seq_lens=decode_seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1470,6 +1530,7 @@ class ModelWorkerBatch:
     req_pool_indices: torch.Tensor
     # The sequence length
     seq_lens: torch.Tensor
+    seq_lens_cpu: Optional[torch.Tensor]
    # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor
 
@@ -1486,9 +1547,6 @@ class ModelWorkerBatch:
     global_num_tokens_for_logprob: Optional[List[int]]
     can_run_dp_cuda_graph: bool
 
-    # For decode
-    decode_seq_lens: Optional[torch.Tensor]
-
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
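
The check_decode_mem rewrite above (hunk @@ -1095,17 +1140,25 @@) switches from reserving one token slot per running request to reserving whole pages: a request only needs a fresh page for its next decoded token when its current length sits exactly on a page boundary. A standalone sketch of that arithmetic, using made-up sequence lengths (plain Python, not sglang code):

# Mirrors new_page_count_next_decode()/check_decode_mem() from the hunk above.
def new_page_count_next_decode(seq_lens, page_size):
    if page_size == 1:
        return len(seq_lens)
    # Only sequences that exactly fill their last page need a new page next step.
    return sum(1 for seqlen in seq_lens if seqlen % page_size == 0)

seq_lens = [16, 17, 31, 32]   # hypothetical current lengths of four running requests
page_size = 16
pages_needed = new_page_count_next_decode(seq_lens, page_size)   # 2 (lengths 16 and 32)
tokens_required = pages_needed * page_size                       # 32 free slots must exist
print(pages_needed, tokens_required)

With page_size == 1 this degenerates to the old behavior of one slot per request; with larger pages it avoids over-reserving for requests that still have room in their current page.
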