sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (98)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/configs/internvl.py +3 -0
  3. sglang/srt/configs/model_config.py +7 -0
  4. sglang/srt/constrained/base_grammar_backend.py +10 -2
  5. sglang/srt/constrained/xgrammar_backend.py +7 -5
  6. sglang/srt/conversation.py +16 -1
  7. sglang/srt/debug_utils/__init__.py +0 -0
  8. sglang/srt/debug_utils/dump_comparator.py +131 -0
  9. sglang/srt/debug_utils/dumper.py +108 -0
  10. sglang/srt/debug_utils/text_comparator.py +172 -0
  11. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  12. sglang/srt/disaggregation/mooncake/conn.py +16 -0
  13. sglang/srt/disaggregation/prefill.py +13 -1
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +13 -1
  16. sglang/srt/entrypoints/openai/protocol.py +3 -1
  17. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  18. sglang/srt/entrypoints/openai/serving_chat.py +132 -79
  19. sglang/srt/function_call/ebnf_composer.py +10 -3
  20. sglang/srt/function_call/function_call_parser.py +2 -0
  21. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  22. sglang/srt/function_call/qwen3_coder_detector.py +1 -0
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  24. sglang/srt/layers/attention/vision.py +56 -8
  25. sglang/srt/layers/layernorm.py +26 -1
  26. sglang/srt/layers/logits_processor.py +14 -3
  27. sglang/srt/layers/moe/ep_moe/layer.py +323 -242
  28. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  33. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  34. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  35. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  36. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  37. sglang/srt/layers/moe/topk.py +90 -24
  38. sglang/srt/layers/multimodal.py +11 -8
  39. sglang/srt/layers/quantization/fp8.py +25 -247
  40. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  41. sglang/srt/layers/quantization/modelopt_quant.py +27 -10
  42. sglang/srt/layers/quantization/unquant.py +24 -76
  43. sglang/srt/layers/quantization/w4afp8.py +68 -17
  44. sglang/srt/lora/lora_registry.py +93 -29
  45. sglang/srt/managers/cache_controller.py +9 -7
  46. sglang/srt/managers/data_parallel_controller.py +4 -0
  47. sglang/srt/managers/io_struct.py +12 -0
  48. sglang/srt/managers/mm_utils.py +154 -35
  49. sglang/srt/managers/multimodal_processor.py +3 -14
  50. sglang/srt/managers/schedule_batch.py +14 -8
  51. sglang/srt/managers/scheduler.py +64 -1
  52. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  53. sglang/srt/managers/tokenizer_manager.py +80 -15
  54. sglang/srt/managers/tp_worker.py +8 -0
  55. sglang/srt/mem_cache/hiradix_cache.py +5 -2
  56. sglang/srt/model_executor/model_runner.py +83 -27
  57. sglang/srt/models/deepseek_v2.py +75 -84
  58. sglang/srt/models/glm4_moe.py +1035 -0
  59. sglang/srt/models/glm4_moe_nextn.py +167 -0
  60. sglang/srt/models/interns1.py +328 -0
  61. sglang/srt/models/internvl.py +143 -47
  62. sglang/srt/models/llava.py +9 -5
  63. sglang/srt/models/minicpmo.py +4 -1
  64. sglang/srt/models/qwen2_moe.py +2 -2
  65. sglang/srt/models/qwen3_moe.py +17 -71
  66. sglang/srt/multimodal/processors/base_processor.py +20 -6
  67. sglang/srt/multimodal/processors/clip.py +2 -2
  68. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  69. sglang/srt/multimodal/processors/gemma3.py +2 -2
  70. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  71. sglang/srt/multimodal/processors/internvl.py +21 -8
  72. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  73. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  74. sglang/srt/multimodal/processors/llava.py +4 -4
  75. sglang/srt/multimodal/processors/minicpm.py +2 -3
  76. sglang/srt/multimodal/processors/mlama.py +2 -2
  77. sglang/srt/multimodal/processors/mllama4.py +18 -111
  78. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  79. sglang/srt/multimodal/processors/pixtral.py +2 -2
  80. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  81. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  82. sglang/srt/multimodal/processors/vila.py +3 -1
  83. sglang/srt/poll_based_barrier.py +31 -0
  84. sglang/srt/reasoning_parser.py +2 -1
  85. sglang/srt/server_args.py +65 -6
  86. sglang/srt/two_batch_overlap.py +8 -3
  87. sglang/srt/utils.py +96 -1
  88. sglang/srt/weight_sync/utils.py +119 -0
  89. sglang/test/runners.py +4 -0
  90. sglang/test/test_utils.py +118 -5
  91. sglang/utils.py +19 -0
  92. sglang/version.py +1 -1
  93. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
  94. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
  95. sglang/srt/debug_utils.py +0 -74
  96. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
  97. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
  98. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py

@@ -56,7 +56,11 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import (
+    DeepEPMoE,
+    get_moe_impl_class,
+    use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
@@ -302,15 +306,19 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk = TopK(
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-            renormalize=config.norm_topk_prob,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
+        self.topk = (
+            TopK(
+                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+                renormalize=config.norm_topk_prob,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
+            )
+            if not use_flashinfer_trtllm_moe
+            else None
         )

         self.experts = get_moe_impl_class()(
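
Note: the hunk above makes the TopK module optional. When the FlashInfer TRT-LLM MoE path is active, expert routing happens inside the fused kernel, so no host-side TopK is constructed and self.topk is None. A standalone sketch of the construction pattern; SimpleTopK and use_fused_router are illustrative stand-ins, not sglang's TopK or flag:

import torch
from torch import nn

class SimpleTopK(nn.Module):
    # Illustrative top-k router, not sglang's TopK.
    def __init__(self, top_k: int, renormalize: bool = True):
        super().__init__()
        self.top_k = top_k
        self.renormalize = renormalize

    def forward(self, hidden_states, router_logits):
        weights, idx = torch.topk(router_logits, self.top_k, dim=-1)
        if self.renormalize:
            # Renormalize the selected expert weights to sum to 1.
            weights = weights / weights.sum(dim=-1, keepdim=True)
        return weights, idx

use_fused_router = False  # stand-in for use_flashinfer_trtllm_moe
topk = SimpleTopK(top_k=8) if not use_fused_router else None
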
@@ -332,10 +340,22 @@ class DeepseekV2MoE(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-                    enable_flashinfer_moe=True,
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["enable_flashinfer_moe"]
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
+                else {}
+            ),
+            **(
+                dict(
+                    renormalize=config.norm_topk_prob,
+                    use_grouped_topk=True,
+                    num_expert_group=config.n_group,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    topk_group=config.topk_group,
+                    correction_bias=self.gate.e_score_correction_bias,
+                )
+                if use_flashinfer_trtllm_moe
                 else {}
             ),
         )
@@ -455,10 +475,12 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            topk_output = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(
-                hidden_states=hidden_states, topk_output=topk_output
-            )
+            kwargs = {"hidden_states": hidden_states}
+            if self.topk is not None:
+                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+            else:
+                kwargs["router_logits"] = router_logits
+            final_hidden_states = self.experts(**kwargs)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
@@ -478,10 +500,12 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, topk_output=topk_output
-        )
+        kwargs = {"hidden_states": hidden_states}
+        if self.topk is not None:
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        else:
+            kwargs["router_logits"] = router_logits
+        final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
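
Note: both forward paths above now build the expert call dynamically, passing the precomputed top-k routing when a TopK module exists and raw router logits otherwise, so the fused kernel can route internally. A minimal standalone sketch of that call-site pattern; the experts signature here is an assumption for illustration:

def run_experts(experts, topk, hidden_states, router_logits):
    # Build kwargs conditionally: route on the host module when `topk`
    # exists, otherwise defer routing to the fused MoE kernel.
    kwargs = {"hidden_states": hidden_states}
    if topk is not None:
        kwargs["topk_output"] = topk(hidden_states, router_logits)
    else:
        kwargs["router_logits"] = router_logits
    return experts(**kwargs)
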
@@ -570,41 +594,13 @@ class DeepseekV2MoE(nn.Module):
             topk_weights = torch.empty(
                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
             )
-        if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states=hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
+
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
             forward_batch=forward_batch,
         )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                hidden_states=final_hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )

         if shared_output is not None:
             x = shared_output
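
Note: the deleted dispatch/combine block reflects an ownership change. The dispatcher now lives on the experts module (visible as self.experts.deepep_dispatcher in the hunks below), which dispatches and combines internally. A hedged sketch of that encapsulation; the attribute names follow the hunks, but the signatures are assumptions:

class ExpertsWithDispatcher:
    # Sketch only: the experts module owns its dispatcher, so the caller
    # passes plain hidden_states/topk_idx/topk_weights as in the new diff.
    def __init__(self, moe_impl, deepep_dispatcher, ep_size: int):
        self.moe_impl = moe_impl
        self.deepep_dispatcher = deepep_dispatcher
        self.ep_size = ep_size

    def __call__(self, hidden_states, topk_idx, topk_weights, forward_batch):
        if self.ep_size > 1:
            dispatch_output = self.deepep_dispatcher.dispatch(
                hidden_states=hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                forward_batch=forward_batch,
            )
            out = self.moe_impl(dispatch_output=dispatch_output)
            return self.deepep_dispatcher.combine(
                hidden_states=out,
                topk_idx=dispatch_output.topk_idx,
                topk_weights=dispatch_output.topk_weights,
                forward_batch=forward_batch,
            )
        # Non-EP path: run the expert computation directly (assumed).
        return self.moe_impl(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
        )
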
@@ -665,8 +661,7 @@ class DeepseekV2MoE(nn.Module):

     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            self.deepep_dispatcher.dispatch_a(
+            self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
@@ -679,46 +674,32 @@ class DeepseekV2MoE(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            (
-                state.hidden_states_experts_input,
-                state.topk_idx_dispatched,
-                state.topk_weights_dispatched,
-                state.reorder_topk_ids,
-                state.num_recv_tokens_per_expert,
-                state.seg_indptr,
-                state.masked_m,
-                state.expected_m,
-            ) = self.deepep_dispatcher.dispatch_b(
+            state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts(
-            hidden_states=state.pop("hidden_states_experts_input"),
-            topk_idx=state.topk_idx_dispatched,
-            topk_weights=state.topk_weights_dispatched,
-            reorder_topk_ids=state.pop("reorder_topk_ids"),
-            seg_indptr=state.pop("seg_indptr"),
-            masked_m=state.pop("masked_m"),
-            expected_m=state.pop("expected_m"),
-            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_batch=state.forward_batch,
+        state.hidden_states_experts_output = self.experts.moe_impl(
+            dispatch_output=state.dispatch_output,
         )

     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.deepep_dispatcher.combine_a(
+            self.experts.deepep_dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.pop("topk_idx_dispatched"),
-                topk_weights=state.pop("topk_weights_dispatched"),
+                topk_idx=state.dispatch_output.topk_idx,
+                topk_weights=state.dispatch_output.topk_weights,
                 forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
+            state.pop("dispatch_output")

     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
-                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            state.hidden_states_after_combine = (
+                self.experts.deepep_dispatcher.combine_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
             )

     def op_output(self, state):
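
Note: the eight-field tuple previously threaded through state is collapsed into a single dispatch_output object; the hunk above confirms it exposes at least topk_idx and topk_weights. A hedged sketch of such a container, with the remaining fields assumed to mirror the removed tuple unpacking:

from typing import List, NamedTuple
import torch

class DispatchOutput(NamedTuple):
    # Only topk_idx and topk_weights are visible in the hunk above; the
    # other fields are assumptions mirroring the old eight-element tuple.
    hidden_states: torch.Tensor
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    reorder_topk_ids: torch.Tensor
    num_recv_tokens_per_expert: List[int]
    seg_indptr: torch.Tensor
    masked_m: torch.Tensor
    expected_m: int

A single object also makes the two-batch-overlap state easier to clean up: one state.pop("dispatch_output") instead of popping eight entries.
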
@@ -901,7 +882,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.disable_chunked_prefix_cache = global_server_args_dict[
             "disable_chunked_prefix_cache"
         ]
-        self.attention_backend = global_server_args_dict["attention_backend"]
+
+        self.current_attention_backend = (
+            None  # Attention backend used by current forward batch
+        )
         self.rocm_fused_decode_mla = get_bool_env_var(
             "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
         )
@@ -985,9 +969,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA

-        if self.attention_backend == "ascend":
+        # Determine attention backend used by current forward batch
+        if forward_batch.forward_mode.is_decode_or_idle():
+            attention_backend = global_server_args_dict["decode_attention_backend"]
+        else:
+            attention_backend = global_server_args_dict["prefill_attention_backend"]
+        self.current_attention_backend = attention_backend
+
+        if attention_backend == "ascend":
             return AttnForwardMethod.MLA
-        elif self.attention_backend == "flashinfer":
+        elif attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
                 not self.flashinfer_mla_disable_ragged
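
Note: with the cached self.attention_backend gone, the backend is resolved per forward batch: decode/idle batches use the decode backend and all other modes use the prefill backend, which allows configurations with a different attention kernel for prefill and decode. The selection logic, isolated as a small helper (a sketch; the model inlines it as shown above):

def resolve_attention_backend(forward_batch, server_args: dict) -> str:
    # Decode/idle batches may be served by a different attention kernel
    # than prefill/extend batches.
    if forward_batch.forward_mode.is_decode_or_idle():
        return server_args["decode_attention_backend"]
    return server_args["prefill_attention_backend"]
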
@@ -999,7 +990,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA
             else:
                 return _dispatch_mla_subtype()
-        elif self.attention_backend == "fa3":
+        elif attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
             if forward_batch.extend_prefix_lens_cpu is not None:
                 sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
@@ -1016,7 +1007,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
-        elif self.attention_backend == "aiter":
+        elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
@@ -1264,9 +1255,9 @@ class DeepseekV2AttentionMLA(nn.Module):
         self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
         if (
-            self.attention_backend == "fa3"
-            or self.attention_backend == "flashinfer"
-            or self.attention_backend == "cutlass_mla"
+            self.current_attention_backend == "fa3"
+            or self.current_attention_backend == "flashinfer"
+            or self.current_attention_backend == "cutlass_mla"
         ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe