sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/metrics/collector.py

@@ -145,6 +145,7 @@ class SchedulerStats:
     num_prefill_infight_queue_reqs: int = 0
     num_decode_prealloc_queue_reqs: int = 0
     num_decode_transfer_queue_reqs: int = 0
+    total_retracted_reqs: int = 0


 class SchedulerMetricsCollector:
@@ -219,6 +220,13 @@ class SchedulerMetricsCollector:
             multiprocess_mode="mostrecent",
         )

+        self.total_retracted_reqs = Gauge(
+            name="sglang:total_retracted_reqs",
+            documentation="The total number of retracted requests due to kvcache full.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
         # Disaggregation queue metrics
         self.num_prefill_prealloc_queue_reqs = Gauge(
             name="sglang:num_prefill_prealloc_queue_reqs",
@@ -279,6 +287,7 @@ class SchedulerMetricsCollector:
         self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
         self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
+        self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs)

         # Disaggregation metrics
         self._log_gauge(
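The added gauge follows the same prometheus_client pattern as the other scheduler gauges. A minimal, self-contained sketch of that pattern (illustrative metric and label names, not the sglang collector itself):

from prometheus_client import Gauge, generate_latest

# Illustrative gauge mirroring the pattern above; the sglang collector additionally
# passes multiprocess_mode="mostrecent" so the last value written wins when several
# scheduler processes export metrics.
example_retracted_reqs = Gauge(
    name="example_total_retracted_reqs",
    documentation="The total number of retracted requests due to kvcache full.",
    labelnames=["model_name"],
)
example_retracted_reqs.labels(model_name="demo-model").set(3)
print(generate_latest().decode())  # the gauge appears in the /metrics exposition text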
sglang/srt/model_executor/cuda_graph_runner.py

@@ -29,9 +29,9 @@ from torch.profiler import ProfilerActivity, profile
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
+from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
@@ -167,8 +167,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     # is very small. We add more values here to make sure we capture the maximum bs.
     capture_bs += [model_runner.req_to_token_pool.size]

+    mul_base = 1
+
     if server_args.enable_two_batch_overlap:
-        capture_bs = [bs for bs in capture_bs if bs % 2 == 0]
+        mul_base *= 2
+
+    if require_gathered_buffer(server_args):
+        mul_base *= get_attention_tp_size()
+
+    capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]

     if server_args.cuda_graph_max_bs:
         capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
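To make the new mul_base filter concrete, a tiny worked example with made-up numbers: with two-batch overlap enabled (factor 2) and an attention TP size of 4 when a gathered buffer is required, only capture batch sizes divisible by 8 survive.

capture_bs = [1, 2, 4, 8, 12, 16, 24, 32]  # hypothetical candidate capture sizes
mul_base = 1
mul_base *= 2  # server_args.enable_two_batch_overlap
mul_base *= 4  # assumed get_attention_tp_size() when require_gathered_buffer(...) is true
print([bs for bs in capture_bs if bs % mul_base == 0])  # [8, 16, 24, 32]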
@@ -264,7 +271,7 @@ class CudaGraphRunner:
         if self.enable_torch_compile:
             set_torch_compile_config()

-        if self.model_runner.server_args.lora_paths is not None:
+        if self.model_runner.server_args.enable_lora:
             self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)

         # Graph inputs
@@ -306,20 +313,37 @@
             self.encoder_lens = None

         if self.require_gathered_buffer:
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_num_token,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
             if self.require_mlp_tp_gather:
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token * self.dp_size,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (1,), dtype=torch.int32
+                )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None

         self.custom_mask = torch.ones(
             (
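For intuition on the new capture-time buffer shapes, a rough sketch with made-up sizes (not the real model dimensions): the MLP-TP-gather path now reserves max_num_token * dp_size rows, while the attention-TP-gather path keeps a single padded slice.

import torch

max_num_token, dp_size, hidden_size = 64, 4, 16  # assumed sizes for illustration
mlp_tp_gather_buffer = torch.zeros((max_num_token * dp_size, hidden_size))
attn_tp_gather_buffer = torch.zeros((max_num_token, hidden_size))
print(mlp_tp_gather_buffer.shape, attn_tp_gather_buffer.shape)  # (256, 16) and (64, 16)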
@@ -342,9 +366,9 @@
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max(forward_batch.global_num_tokens_cpu)
             )
         else:
             cuda_graph_bs = forward_batch.batch_size
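The sum-to-max switch in can_run reflects the new max-length DP padding: every DP rank is padded to the largest per-rank token count, so the CUDA graph is keyed off that maximum rather than the global sum. A small illustration with made-up counts:

global_num_tokens_cpu = [3, 7, 5, 2]  # hypothetical per-DP-rank token counts
print(sum(global_num_tokens_cpu))     # 17 -> lookup key before this change
print(max(global_num_tokens_cpu))     # 7  -> lookup key now; each rank is padded to 7 tokens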
@@ -480,16 +504,19 @@
         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=input_ids.device,
                 )
             )
-            global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens] * self.dp_size,
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -498,10 +525,15 @@
                     dtype=torch.int32,
                     device=input_ids.device,
                 )
-            global_num_tokens = self.global_num_tokens_gpu
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
             gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
-            global_num_tokens = None
             gathered_buffer = None

         spec_info = self.get_spec_info(num_tokens)
@@ -510,11 +542,10 @@
             spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
         )

-        if self.model_runner.server_args.lora_paths is not None:
-            # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
-            # different logic to handle lora, so we need to set `lora_paths` to a list of non-None
-            # values if lora is enabled.
-            lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
+        if self.model_runner.server_args.enable_lora:
+            # It is safe to capture CUDA graph using empty LoRA path, as the LoRA kernels will always be launched whenever
+            # `--enable-lora` is set to True (and return immediately if the LoRA path is empty for perf optimization).
+            lora_paths = [None] * bs
         else:
             lora_paths = None

@@ -532,7 +563,9 @@
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens_gpu=global_num_tokens,
+            global_num_tokens_gpu=self.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -636,12 +669,13 @@

         # Pad
         if self.require_mlp_tp_gather:
-            total_batch_size = (
-                sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens / self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max_num_tokens
             )
-            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
@@ -671,7 +705,8 @@
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
         if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
         if enable_num_token_non_padded(self.model_runner.server_args):
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
         if self.enable_two_batch_overlap:
sglang/srt/model_executor/forward_batch_info.py

@@ -38,6 +38,11 @@ import torch
 import triton
 import triton.language as tl

+from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
+    get_attention_dp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import (
     flatten_nested_list,
@@ -48,6 +53,7 @@ from sglang.srt.utils import (

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+    from sglang.srt.layers.logits_processor import LogitsProcessorOutput
     from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -78,6 +84,9 @@ class ForwardMode(IntEnum):
     # It is now used for triggering the sampling_info_done event for the first prefill batch.
     DUMMY_FIRST = auto()

+    # Split Prefill for PD multiplexing
+    SPLIT_PREFILL = auto()
+
     def is_prefill(self):
         return self.is_extend()

@@ -98,6 +107,9 @@
     def is_idle(self):
         return self == ForwardMode.IDLE

+    def is_decode_or_idle(self):
+        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+
     def is_target_verify(self):
         return self == ForwardMode.TARGET_VERIFY

@@ -121,8 +133,8 @@
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST

-    def is_decode_or_idle(self):
-        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+    def is_split_prefill(self):
+        return self == ForwardMode.SPLIT_PREFILL


 @total_ordering
@@ -194,6 +206,14 @@ class ForwardBatch:
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

+    # For split prefill
+    # intermediate values for split prefill
+    hidden_states: torch.Tensor = None
+    residual: torch.Tensor = None
+    model_specific_states: Dict[str, any] = None
+    # current split index of layer
+    split_index: int = 0
+
     # For MLA chunked prefix cache used in chunked prefill
     # Tell attention backend whether the kv cache needs to be attended in current pass
     attn_attend_prefix_cache: Optional[bool] = None
@@ -229,7 +249,7 @@
     lora_paths: Optional[List[str]] = None

     # For input embeddings
-    input_embeds: Optional[torch.tensor] = None
+    input_embeds: Optional[torch.Tensor] = None

     # For cross-encoder model
     token_type_ids: Optional[torch.Tensor] = None
@@ -248,6 +268,8 @@
     # Has to be None when cuda graph is captured.
     global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+    # The padding mode for DP attention
+    dp_padding_mode: Optional[DPPaddingMode] = None
     # for extend, local start pos and num tokens is different in logits processor
     # this will be computed in get_dp_local_info
     # this will be recomputed in LogitsMetadata.from_forward_batch
@@ -273,7 +295,7 @@
     # For two-batch overlap
     tbo_split_seq_index: Optional[int] = None
     tbo_parent_token_range: Optional[Tuple[int, int]] = None
-    tbo_children: Optional[List["ForwardBatch"]] = None
+    tbo_children: Optional[List[ForwardBatch]] = None

     @classmethod
     def init_new(
@@ -327,20 +349,38 @@
             len(batch.input_ids), dtype=torch.int32
         ).to(device, non_blocking=True)

-        # For DP attention
+        # For MLP sync
         if batch.global_num_tokens is not None:
-
-            spec_num_draft_tokens = (
-                batch.spec_num_draft_tokens
-                if batch.spec_num_draft_tokens is not None
-                else 1
+            from sglang.srt.speculative.eagle_utils import (
+                EagleDraftInput,
+                EagleVerifyInput,
             )
-            global_num_tokens = [
-                x * spec_num_draft_tokens for x in batch.global_num_tokens
-            ]
-            global_num_tokens_for_logprob = [
-                x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
-            ]
+
+            assert batch.global_num_tokens_for_logprob is not None
+            # process global_num_tokens and global_num_tokens_for_logprob
+            if batch.spec_info is not None:
+                if isinstance(batch.spec_info, EagleDraftInput):
+                    global_num_tokens = [
+                        x * batch.spec_info.num_tokens_per_batch
+                        for x in batch.global_num_tokens
+                    ]
+                    global_num_tokens_for_logprob = [
+                        x * batch.spec_info.num_tokens_for_logprob_per_batch
+                        for x in batch.global_num_tokens_for_logprob
+                    ]
+                else:
+                    assert isinstance(batch.spec_info, EagleVerifyInput)
+                    global_num_tokens = [
+                        x * batch.spec_info.draft_token_num
+                        for x in batch.global_num_tokens
+                    ]
+                    global_num_tokens_for_logprob = [
+                        x * batch.spec_info.draft_token_num
+                        for x in batch.global_num_tokens_for_logprob
+                    ]
+            else:
+                global_num_tokens = batch.global_num_tokens
+                global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob

             ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
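The speculative-decoding branches above scale each rank's token count by a mode-specific factor instead of a single spec_num_draft_tokens. A worked example with made-up numbers, assuming a verify batch whose draft_token_num is 4:

global_num_tokens = [2, 3]  # hypothetical per-DP-rank request counts
draft_token_num = 4         # assumed EagleVerifyInput.draft_token_num
print([x * draft_token_num for x in global_num_tokens])  # [8, 12] tokens per rank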
@@ -352,15 +392,8 @@
             global_num_tokens_for_logprob, dtype=torch.int64
         ).to(device, non_blocking=True)

-            sum_len = sum(global_num_tokens)
-            ret.gathered_buffer = torch.zeros(
-                (sum_len, model_runner.model_config.hidden_size),
-                dtype=model_runner.dtype,
-                device=device,
-            )
-
         if ret.forward_mode.is_idle():
-            ret.positions = torch.empty((0,), device=device)
+            ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
             TboForwardBatchPreparer.prepare(
                 ret, is_draft_worker=model_runner.is_draft_worker
             )
@@ -405,7 +438,7 @@
         ret._compute_mrope_positions(model_runner, batch)

         # Init lora information
-        if model_runner.server_args.lora_paths is not None:
+        if model_runner.server_args.enable_lora:
             model_runner.lora_manager.prepare_lora_batch(ret)

         TboForwardBatchPreparer.prepare(
@@ -560,6 +593,158 @@
             )
             self.prefix_chunk_kv_indices.append(chunk_kv_indices)

+    def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
+        if value == 0:
+            return torch.cat(
+                [tensor, tensor.new_zeros(size - tensor.shape[0], *tensor.shape[1:])],
+                dim=0,
+            )
+        else:
+            return torch.cat(
+                [
+                    tensor,
+                    tensor.new_full((size - tensor.shape[0], *tensor.shape[1:]), value),
+                ],
+                dim=0,
+            )
+
+    def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
+
+        from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+        assert self.global_num_tokens_cpu is not None
+        assert self.global_num_tokens_for_logprob_cpu is not None
+
+        global_num_tokens = self.global_num_tokens_cpu
+        sync_group_size = len(global_num_tokens)
+        attn_tp_size = get_attention_tp_size()
+
+        for i in range(sync_group_size):
+            # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
+            # there is no reduce-scatter in LM logprob, so we do not need to adjust the padded length for logprob
+            global_num_tokens[i] = (
+                (global_num_tokens[i] - 1) // attn_tp_size + 1
+            ) * attn_tp_size
+
+        dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+        self.dp_padding_mode = dp_padding_mode
+
+        if dp_padding_mode.is_max_len():
+            # when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
+            # where transferred tokens should be padded to the same length.
+            max_num_tokens = max(global_num_tokens)
+            global_num_tokens = [max_num_tokens] * sync_group_size
+            buffer_len = max_num_tokens * sync_group_size
+        else:
+            buffer_len = sum(global_num_tokens)
+
+        self.gathered_buffer = torch.zeros(
+            (buffer_len, model_runner.model_config.hidden_size),
+            dtype=model_runner.dtype,
+            device=model_runner.device,
+        )
+
+        bs = self.batch_size
+        if len(global_num_tokens) > 1:
+            num_tokens = global_num_tokens[get_attention_dp_rank()]
+        else:
+            num_tokens = global_num_tokens[0]
+
+        # padding
+        self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
+        self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
+
+        seq_len_fill_value = (
+            model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
+        )
+        self.seq_lens = self._pad_tensor_to_size(
+            self.seq_lens, bs, value=seq_len_fill_value
+        )
+        if self.seq_lens_cpu is not None:
+            self.seq_lens_cpu = self._pad_tensor_to_size(
+                self.seq_lens_cpu, bs, value=seq_len_fill_value
+            )
+
+        self.out_cache_loc = self._pad_tensor_to_size(self.out_cache_loc, num_tokens)
+        if self.encoder_lens is not None:
+            self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs)
+        self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
+        self.global_num_tokens_cpu = global_num_tokens
+        self.global_num_tokens_gpu = self.global_num_tokens_gpu.new_tensor(
+            global_num_tokens
+        )
+
+        if self.mrope_positions is not None:
+            self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
+
+        if self.extend_seq_lens is not None:
+            self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
+
+        if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
+            spec_info = self.spec_info
+            self.output_cache_loc_backup = self.out_cache_loc
+            self.hidden_states_backup = spec_info.hidden_states
+            if spec_info.topk_p is not None:
+                spec_info.topk_p = self._pad_tensor_to_size(spec_info.topk_p, bs)
+            if spec_info.topk_index is not None:
+                spec_info.topk_index = self._pad_tensor_to_size(
+                    spec_info.topk_index, bs
+                )
+            if spec_info.accept_length is not None:
+                spec_info.accept_length = self._pad_tensor_to_size(
+                    spec_info.accept_length, bs
+                )
+            spec_info.hidden_states = self._pad_tensor_to_size(
+                spec_info.hidden_states, num_tokens
+            )
+
+    def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
+
+        bs = self.batch_size
+
+        if self.spec_info is not None:
+            if self.forward_mode.is_decode():  # draft
+                num_tokens = self.hidden_states_backup.shape[0]
+                self.positions = self.positions[:num_tokens]
+                self.seq_lens = self.seq_lens[:bs]
+                self.req_pool_indices = self.req_pool_indices[:bs]
+                if self.seq_lens_cpu is not None:
+                    self.seq_lens_cpu = self.seq_lens_cpu[:bs]
+                logits_output.next_token_logits = logits_output.next_token_logits[
+                    :num_tokens
+                ]
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+            elif self.forward_mode.is_target_verify():  # verify
+                num_tokens = bs * self.spec_info.draft_token_num
+                logits_output.next_token_logits = logits_output.next_token_logits[
+                    :num_tokens
+                ]
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+            elif self.forward_mode.is_draft_extend():  # draft extend
+                self.spec_info.accept_length = self.spec_info.accept_length[:bs]
+                logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+            elif self.forward_mode.is_extend() or self.forward_mode.is_idle():
+                logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+
+            if hasattr(self, "hidden_states_backup"):
+                self.spec_info.hidden_states = self.hidden_states_backup
+            if hasattr(self, "output_cache_loc_backup"):
+                self.out_cache_loc = self.output_cache_loc_backup
+
+        elif self.forward_mode.is_decode() or self.forward_mode.is_idle():
+            logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+            if logits_output.hidden_states is not None:
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+        elif self.forward_mode.is_extend():
+            num_tokens = self.seq_lens_sum
+            logits_output.next_token_logits = logits_output.next_token_logits[
+                :num_tokens
+            ]
+            if logits_output.hidden_states is not None:
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+
     # Here we suppose the length of each chunk is equal
     # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
     # num_prefix_chunks = cdiv(1024, 256) = 4
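A standalone sketch of the two padding steps prepare_mlp_sync_batch performs, using assumed sizes and a simplified stand-in for _pad_tensor_to_size: round each rank's token count up to a multiple of the attention TP size, then right-pad the rank-local tensors to the chosen length.

import torch

def pad_to_size(tensor: torch.Tensor, size: int, value: int = 0) -> torch.Tensor:
    # Right-pad dim 0 with `value` until the tensor has `size` rows.
    pad_rows = size - tensor.shape[0]
    return torch.cat(
        [tensor, tensor.new_full((pad_rows, *tensor.shape[1:]), value)], dim=0
    )

attn_tp_size = 4
global_num_tokens = [3, 7]  # made-up per-rank token counts
global_num_tokens = [((n - 1) // attn_tp_size + 1) * attn_tp_size for n in global_num_tokens]
print(global_num_tokens)    # [4, 8]: each count rounded up to a multiple of attn_tp_size

# Under max-length padding every rank pads to max(global_num_tokens); this rank has 3 real tokens.
input_ids = torch.tensor([11, 12, 13])
print(pad_to_size(input_ids, max(global_num_tokens)))  # tensor([11, 12, 13, 0, 0, 0, 0, 0])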