sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (84)
  1. sglang/lang/interpreter.py +1 -1
  2. sglang/srt/configs/internvl.py +6 -0
  3. sglang/srt/configs/model_config.py +2 -1
  4. sglang/srt/disaggregation/mini_lb.py +2 -2
  5. sglang/srt/distributed/parallel_state.py +46 -41
  6. sglang/srt/entrypoints/engine.py +1 -1
  7. sglang/srt/entrypoints/http_server.py +5 -1
  8. sglang/srt/entrypoints/openai/protocol.py +3 -3
  9. sglang/srt/entrypoints/openai/serving_chat.py +3 -3
  10. sglang/srt/entrypoints/openai/serving_completions.py +3 -1
  11. sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
  12. sglang/srt/entrypoints/openai/serving_responses.py +1 -1
  13. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  14. sglang/srt/layers/attention/aiter_backend.py +93 -68
  15. sglang/srt/layers/communicator.py +45 -7
  16. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  17. sglang/srt/layers/moe/ep_moe/layer.py +2 -7
  18. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  19. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  21. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  24. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  25. sglang/srt/layers/moe/utils.py +0 -1
  26. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
  27. sglang/srt/layers/quantization/modelopt_quant.py +35 -2
  28. sglang/srt/layers/quantization/mxfp4.py +4 -1
  29. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  30. sglang/srt/layers/quantization/quark/utils.py +97 -0
  31. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  32. sglang/srt/layers/quantization/w4afp8.py +30 -25
  33. sglang/srt/layers/rocm_linear_utils.py +44 -0
  34. sglang/srt/layers/rotary_embedding.py +0 -18
  35. sglang/srt/managers/cache_controller.py +42 -39
  36. sglang/srt/managers/detokenizer_manager.py +0 -34
  37. sglang/srt/managers/multi_tokenizer_mixin.py +48 -6
  38. sglang/srt/managers/schedule_policy.py +3 -2
  39. sglang/srt/managers/scheduler.py +7 -100
  40. sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
  41. sglang/srt/managers/template_manager.py +3 -3
  42. sglang/srt/managers/tokenizer_manager.py +1 -0
  43. sglang/srt/mem_cache/allocator.py +1 -1
  44. sglang/srt/mem_cache/hicache_storage.py +15 -10
  45. sglang/srt/mem_cache/hiradix_cache.py +16 -0
  46. sglang/srt/mem_cache/memory_pool_host.py +18 -11
  47. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  48. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +35 -6
  49. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
  50. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  51. sglang/srt/metrics/collector.py +12 -4
  52. sglang/srt/metrics/utils.py +48 -0
  53. sglang/srt/model_executor/forward_batch_info.py +16 -17
  54. sglang/srt/model_executor/model_runner.py +1 -1
  55. sglang/srt/models/deepseek_v2.py +245 -36
  56. sglang/srt/models/glm4_moe.py +10 -1
  57. sglang/srt/models/gpt_oss.py +5 -4
  58. sglang/srt/models/internvl.py +28 -0
  59. sglang/srt/models/longcat_flash.py +26 -15
  60. sglang/srt/models/longcat_flash_nextn.py +23 -15
  61. sglang/srt/models/minicpmv.py +165 -3
  62. sglang/srt/models/qwen2_moe.py +4 -1
  63. sglang/srt/models/qwen3.py +8 -2
  64. sglang/srt/models/qwen3_moe.py +39 -8
  65. sglang/srt/models/torch_native_llama.py +1 -1
  66. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  67. sglang/srt/server_args.py +79 -2
  68. sglang/srt/speculative/eagle_worker.py +158 -112
  69. sglang/srt/utils.py +12 -10
  70. sglang/test/few_shot_gsm8k.py +1 -0
  71. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  72. sglang/utils.py +1 -0
  73. sglang/version.py +1 -1
  74. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +2 -2
  75. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +83 -76
  76. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  77. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  78. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  79. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  80. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  81. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  82. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
  83. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
@@ -651,9 +651,6 @@ class LongcatFlashForCausalLM(nn.Module):
  ).T
  else:
  w = self_attn.kv_b_proj.weight
- # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
- # This may affect the accuracy of fp8 model.
- # Fix deepseek v3 blockwise bmm by using deep_gemm
  use_deep_gemm_bmm = False

  if w.dtype in (
@@ -790,6 +787,9 @@ class LongcatFlashForCausalLM(nn.Module):
  self.config.hidden_size / self.config.kv_lora_rank
  ) ** 0.5

+ # TODO(linguoyuan) EPMoE not support DEEPGEMM_BLACKWELL, DeepEP needs to be supported in the future
+ deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 = False
+
  if (
  deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
  and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
@@ -804,24 +804,35 @@ class LongcatFlashForCausalLM(nn.Module):
  for layer_id in range(self.config.num_hidden_layers):
  layer = self.model.layers[layer_id]
  for i in range(2):
- for module in [
- layer.self_attn[i].fused_qkv_a_proj_with_mqa,
- layer.self_attn[i].q_b_proj,
- layer.self_attn[i].kv_b_proj,
- layer.self_attn[i].o_proj,
- ]:
- requant_weight_ue8m0_inplace(
- module.weight, module.weight_scale_inv, weight_block_size
- )
+ self_attn = layer.self_attn[i]
+ module_list = [
+ self_attn.kv_b_proj,
+ self_attn.o_proj,
+ ]
+
+ if self.config.q_lora_rank is not None:
+ module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
+ module_list.append(self_attn.q_b_proj)
+ else:
+ module_list.append(self_attn.kv_a_proj_with_mqa)
+ module_list.append(self_attn.q_proj)
+
+ for module in module_list:
+ if hasattr(module, "weight_scale_inv"):
+ requant_weight_ue8m0_inplace(
+ module.weight, module.weight_scale_inv, weight_block_size
+ )
+
  mlp = layer.mlps[i]
  assert isinstance(mlp, LongcatFlashMLP)
  for module in [
  mlp.gate_up_proj,
  mlp.down_proj,
  ]:
- requant_weight_ue8m0_inplace(
- module.weight, module.weight_scale_inv, weight_block_size
- )
+ if hasattr(module, "weight_scale_inv"):
+ requant_weight_ue8m0_inplace(
+ module.weight, module.weight_scale_inv, weight_block_size
+ )

  for layer_id in range(self.config.num_hidden_layers):
  experts = layer.mlp.experts
@@ -344,9 +344,6 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
  ).T
  else:
  w = self_attn.kv_b_proj.weight
- # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
- # This may affect the accuracy of fp8 model.
- # Fix deepseek v3 blockwise bmm by using deep_gemm
  use_deep_gemm_bmm = False
  if w.dtype in (
  torch.float8_e4m3fn,
@@ -480,24 +477,35 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
  def _weight_requant_ue8m0(self):
  weight_block_size = self.quant_config.weight_block_size
  layer = self.model.decoder
- for module in [
- layer.self_attn.fused_qkv_a_proj_with_mqa,
- layer.self_attn.q_b_proj,
- layer.self_attn.kv_b_proj,
- layer.self_attn.o_proj,
- ]:
- requant_weight_ue8m0_inplace(
- module.weight, module.weight_scale_inv, weight_block_size
- )
+ self_attn = layer.self_attn
+ module_list = [
+ self_attn.kv_b_proj,
+ self_attn.o_proj,
+ ]
+
+ if self.config.q_lora_rank is not None:
+ module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
+ module_list.append(self_attn.q_b_proj)
+ else:
+ module_list.append(self_attn.kv_a_proj_with_mqa)
+ module_list.append(self_attn.q_proj)
+
+ for module in module_list:
+ if hasattr(module, "weight_scale_inv"):
+ requant_weight_ue8m0_inplace(
+ module.weight, module.weight_scale_inv, weight_block_size
+ )
+
  mlp = layer.mlps
  assert isinstance(mlp, LongcatFlashMLP)
  for module in [
  mlp.gate_up_proj,
  mlp.down_proj,
  ]:
- requant_weight_ue8m0_inplace(
- module.weight, module.weight_scale_inv, weight_block_size
- )
+ if hasattr(module, "weight_scale_inv"):
+ requant_weight_ue8m0_inplace(
+ module.weight, module.weight_scale_inv, weight_block_size
+ )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
@@ -54,6 +54,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.utils import set_default_torch_dtype
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.idefics2 import Idefics2VisionTransformer
+ from sglang.srt.models.llama import LlamaConfig, LlamaForCausalLM
  from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
  from sglang.srt.utils import add_prefix, flatten_nested_list

@@ -581,7 +582,7 @@ class MiniCPMBaseModel(nn.Module):

  def init_llm(
  self,
- config: Qwen2Config,
+ config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
  ) -> nn.Module:
@@ -774,7 +775,168 @@ class MiniCPMV2_6(MiniCPMBaseModel):
  return pattern.pad_input_tokens(input_ids, image_inputs)


- _SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
+ class MiniCPMV4_0(MiniCPMBaseModel):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+ # LoRA specific attributes
+ supported_lora_modules = [
+ # vision encoder
+ "fc1",
+ "fc2",
+ "out_proj",
+ # language model
+ "qkv_proj", # same name with vision encoder
+ "o_proj",
+ "gate_up_proj",
+ "down_proj",
+ # resampler
+ "kv_proj",
+ ]
+
+ # BitandBytes specific attributes
+ bitsandbytes_stacked_params_mapping = {
+ # shard_name, weight_name, index
+ "q_proj": ("qkv_proj", 0),
+ "k_proj": ("qkv_proj", 1),
+ "v_proj": ("qkv_proj", 2),
+ "gate_proj": ("gate_up_proj", 0),
+ "up_proj": ("gate_up_proj", 1),
+ }
+
+ embedding_modules = {}
+ embedding_padding_modules = []
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__(config=config, quant_config=quant_config, prefix=prefix)
+ assert self.version == (4, 0)
+
+ def init_llm(
+ self,
+ config: LlamaConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> nn.Module:
+ return LlamaForCausalLM(config=config, quant_config=quant_config, prefix=prefix)
+
+ def init_vision_module(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
+ ) -> nn.Module:
+ model = Idefics2VisionTransformer(
+ config=config.vision_config, quant_config=quant_config, prefix=prefix
+ )
+ if self.config.drop_vision_last_layer:
+ model.encoder.layers = model.encoder.layers[:-1]
+
+ setattr(model, "embed_dim", model.embeddings.embed_dim)
+ setattr(model, "patch_size", model.embeddings.patch_size)
+ return model
+
+ def init_resampler(
+ self,
+ embed_dim: int,
+ vision_dim: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> nn.Module:
+ with set_default_torch_dtype(torch.float16):
+ # The resampler in 2.6 remains consistent with the one in 2.5.
+ resampler = Resampler2_5(
+ num_queries=self.config.query_num,
+ embed_dim=embed_dim,
+ num_heads=embed_dim // 128,
+ kv_dim=vision_dim,
+ quant_config=quant_config,
+ prefix=prefix,
+ )
+
+ return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+
+ def get_vision_embedding(
+ self,
+ pixel_values: List[torch.Tensor],
+ patch_attn_mask: Optional[torch.Tensor] = None,
+ tgt_sizes: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ vision_embedding = self.vpm(
+ pixel_values,
+ patch_attention_mask=patch_attn_mask,
+ tgt_sizes=tgt_sizes,
+ )
+ return vision_embedding
+
+ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+ # list of tensors
+ pixel_values = flatten_nested_list([item.feature for item in items])
+ tgt_sizes = torch.stack(
+ flatten_nested_list([item.tgt_size for item in items]), dim=0
+ )
+ assert len(pixel_values) == tgt_sizes.shape[0]
+
+ device = self.vpm.embeddings.position_embedding.weight.device
+ dtype = self.vpm.embeddings.position_embedding.weight.dtype
+ all_pixel_values_lst = [
+ i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
+ ]
+
+ max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
+ assert isinstance(max_patches, int)
+ all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+ all_pixel_values_lst, batch_first=True, padding_value=0.0
+ )
+
+ B, L, _ = all_pixel_values.shape
+ all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+ patch_attn_mask = torch.zeros(
+ (B, 1, max_patches), dtype=torch.bool, device=device
+ )
+
+ tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
+ mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
+ patch_attn_mask[:, 0, :] = torch.arange(
+ patch_attn_mask.size(2), device=patch_attn_mask.device
+ ).unsqueeze(0) < mask_shapes.unsqueeze(1)
+
+ vision_embedding = self.vpm(
+ all_pixel_values.type(dtype),
+ patch_attention_mask=patch_attn_mask,
+ tgt_sizes=tgt_sizes,
+ )
+ return self.resampler(vision_embedding, tgt_sizes)
+
+ def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+ # Get all special token IDs
+ im_start_id: int = image_inputs.im_start_id
+ im_end_id: int = image_inputs.im_end_id
+ slice_start_id: int = image_inputs.slice_start_id
+ slice_end_id: int = image_inputs.slice_end_id
+
+ media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
+ pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+ return pattern.pad_input_tokens(input_ids, image_inputs)
+
+
+ _SUPPORT_VERSION = {
+ (2, 6): MiniCPMV2_6,
+ (4, 0): MiniCPMV4_0,
+ }


  class MiniCPMV:
@@ -809,7 +971,7 @@ class MiniCPMV:
  # Dispatch class based on version
  instance_class = _SUPPORT_VERSION.get(version)
  if instance_class is None:
- raise ValueError("Currently, MiniCPMV only supports versions 2.6")
+ raise ValueError("Currently, MiniCPMV only supports versions 2.6 and 4.0")

  try:
  minicpmv = instance_class(
@@ -105,11 +105,14 @@ class Qwen2MoeMLP(nn.Module):
  def forward(
  self,
  x,
+ should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
  ):
  gate_up, _ = self.gate_up_proj(x)
  x = self.act_fn(gate_up)
- x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
+ x, _ = self.down_proj(
+ x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+ )
  return x


@@ -24,7 +24,10 @@ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
  from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
- from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.model_loader.weight_utils import (
+ default_weight_loader,
+ maybe_remap_kv_scale_name,
+ )
  from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
  from sglang.srt.models.qwen2 import Qwen2Model
  from sglang.srt.utils import add_prefix, is_cuda
@@ -458,7 +461,10 @@ class Qwen3ForCausalLM(nn.Module):
  continue
  if name.startswith("model.vision_tower") and name not in params_dict:
  continue
-
+ if "scale" in name:
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
  for param_name, weight_name, shard_id in stacked_params_mapping:
  if weight_name not in name:
  continue
@@ -42,7 +42,10 @@ from sglang.srt.layers.linear import (
  RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.moe import get_moe_a2a_backend
+ from sglang.srt.layers.moe import (
+ get_moe_a2a_backend,
+ should_use_flashinfer_cutlass_moe_fp4_allgather,
+ )
  from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
  from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
  from sglang.srt.layers.moe.topk import TopK
@@ -57,10 +60,17 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
  from sglang.srt.models.qwen2_moe import Qwen2MoeModel
- from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
+ from sglang.srt.utils import (
+ add_prefix,
+ is_cuda,
+ is_flashinfer_available,
+ is_non_idle_and_non_empty,
+ )

  Qwen3MoeConfig = None

+ _is_flashinfer_available = is_flashinfer_available()
+
  logger = logging.getLogger(__name__)
  _is_cuda = is_cuda()

@@ -119,11 +129,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  self,
  hidden_states: torch.Tensor,
  forward_batch: Optional[ForwardBatch] = None,
+ should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
  ) -> torch.Tensor:

  if not get_moe_a2a_backend().is_deepep():
- return self.forward_normal(hidden_states, use_reduce_scatter)
+ return self.forward_normal(
+ hidden_states, should_allreduce_fusion, use_reduce_scatter
+ )
  else:
  return self.forward_deepep(hidden_states, forward_batch)

@@ -137,6 +150,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  def forward_normal(
  self,
  hidden_states: torch.Tensor,
+ should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
  ) -> torch.Tensor:
  num_tokens, hidden_dim = hidden_states.shape
@@ -146,7 +160,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  router_logits, _ = self.gate(hidden_states)
  topk_output = self.topk(hidden_states, router_logits)
  final_hidden_states = self.experts(hidden_states, topk_output)
- if self.tp_size > 1 and not use_reduce_scatter:
+ if (
+ self.tp_size > 1
+ and not should_allreduce_fusion
+ and not use_reduce_scatter
+ and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+ ):
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

  return final_hidden_states.view(num_tokens, hidden_dim)
@@ -500,6 +519,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
  input_layernorm=self.input_layernorm,
  post_attention_layernorm=self.post_attention_layernorm,
  allow_reduce_scatter=True,
+ is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
  )

  def forward(
@@ -525,17 +545,28 @@ class Qwen3MoeDecoderLayer(nn.Module):
  hidden_states, residual, forward_batch
  )

+ should_allreduce_fusion = (
+ self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+ forward_batch
+ )
+ )
+
  # For DP with padding, reduce scatter can be used instead of all-reduce.
  use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
  forward_batch
  )

- hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
-
- hidden_states, residual = self.layer_communicator.postprocess_layer(
- hidden_states, residual, forward_batch
+ hidden_states = self.mlp(
+ hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
  )

+ if should_allreduce_fusion:
+ hidden_states._sglang_needs_allreduce_fusion = True
+ else:
+ hidden_states, residual = self.layer_communicator.postprocess_layer(
+ hidden_states, residual, forward_batch
+ )
+
  return hidden_states, residual

  def op_comm_prepare_attn(
@@ -22,7 +22,7 @@ Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html

  Here is a quick example to enable TP:
  ```python
- from sglang.srt.model_parallel import tensor_parallel
+ from sglang.srt.layers.model_parallel import tensor_parallel

  device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
  tensor_parallel(model, device_mesh)
@@ -1,7 +1,7 @@
  import re
  from typing import Dict, Optional, Tuple, Type

- from sglang.srt.harmony_parser import HarmonyParser
+ from sglang.srt.parser.harmony_parser import HarmonyParser


  class StreamingParseResult:
sglang/srt/server_args.py CHANGED
@@ -26,7 +26,7 @@ from typing import List, Literal, Optional, Union
  from sglang.srt.function_call.function_call_parser import FunctionCallParser
  from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
  from sglang.srt.lora.lora_registry import LoRARef
- from sglang.srt.reasoning_parser import ReasoningParser
+ from sglang.srt.parser.reasoning_parser import ReasoningParser
  from sglang.srt.utils import (
  LORA_TARGET_ALL_MODULES,
  SUPPORTED_LORA_TARGET_MODULES,
@@ -195,6 +195,8 @@ class ServerArgs:
  bucket_inter_token_latency: Optional[List[float]] = None
  bucket_e2e_request_latency: Optional[List[float]] = None
  collect_tokens_histogram: bool = False
+ prompt_tokens_buckets: Optional[List[str]] = None
+ generation_tokens_buckets: Optional[List[str]] = None
  decode_log_interval: int = 40
  enable_request_time_stats_logging: bool = False
  kv_events_config: Optional[str] = None
@@ -1234,6 +1236,26 @@ class ServerArgs:
  default=ServerArgs.collect_tokens_histogram,
  help="Collect prompt/generation tokens histogram.",
  )
+ bucket_rule = (
+ "Supports 3 rule types: 'default' uses predefined buckets; 'tse <middle> <base> <count>' "
+ "generates two sides exponential distributed buckets (e.g., 'tse 1000 2 8' generates buckets "
+ "[984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]).); 'customer <value1> "
+ "<value2> ...' uses custom bucket values (e.g., 'customer 10 50 100 500')."
+ )
+ parser.add_argument(
+ "--prompt-tokens-buckets",
+ type=str,
+ nargs="+",
+ default=ServerArgs.prompt_tokens_buckets,
+ help=f"The buckets rule of prompt tokens. {bucket_rule}",
+ )
+ parser.add_argument(
+ "--generation-tokens-buckets",
+ type=str,
+ nargs="+",
+ default=ServerArgs.generation_tokens_buckets,
+ help=f"The buckets rule for generation tokens histogram. {bucket_rule}",
+ )
  parser.add_argument(
  "--gc-warning-threshold-secs",
  type=float,
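
Note on the new 'tse' bucket rule: the help text above documents that 'tse 1000 2 8' expands to [984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]. The minimal sketch below only reproduces that documented example; the formula (count // 2 exponential offsets of base**i on each side of middle) is inferred from the example, and the helper name is hypothetical. The shipped generator lives in sglang/srt/metrics/utils.py, which this release adds.

```python
# Hedged sketch, not the shipped implementation: reproduces the documented
# example for the 'tse' (two-sided exponential) histogram bucket rule.
def tse_buckets(middle: float, base: float, count: int) -> list:
    # count // 2 exponential offsets on each side of `middle`, plus `middle` itself
    offsets = [base**i for i in range(1, count // 2 + 1)]
    return sorted([middle - o for o in offsets] + [middle] + [middle + o for o in offsets])

assert tse_buckets(1000.0, 2.0, 8) == [
    984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0
]
```

Because both flags are declared with nargs="+", the rule is passed on the command line as separate tokens, e.g. --prompt-tokens-buckets tse 1000 2 8.
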
@@ -2185,6 +2207,12 @@ class ServerArgs:

  # Check multi tokenizer
  assert self.tokenizer_worker_num > 0, "Tokenizer worker num must >= 1"
+ self.validate_buckets_rule(
+ "--prompt-tokens-buckets", self.prompt_tokens_buckets
+ )
+ self.validate_buckets_rule(
+ "--generation-tokens-buckets", self.generation_tokens_buckets
+ )

  def check_lora_server_args(self):
  assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
@@ -2277,6 +2305,54 @@ class ServerArgs:
  f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
  )

+ def validate_buckets_rule(self, arg_name: str, buckets_rule: List[str]):
+ if not buckets_rule:
+ return
+
+ assert len(buckets_rule) > 0, f"{arg_name} cannot be empty list"
+ rule = buckets_rule[0]
+ assert rule in [
+ "tse",
+ "default",
+ "customer",
+ ], f"Unsupported {arg_name} rule type: '{rule}'. Must be one of: 'tse', 'default', 'customer'"
+
+ if rule == "tse":
+ assert (
+ len(buckets_rule) == 4
+ ), f"{arg_name} TSE rule requires exactly 4 parameters: ['tse', middle, base, count], got {len(buckets_rule)}"
+ try:
+ middle = float(buckets_rule[1])
+ base = float(buckets_rule[2])
+ count = int(buckets_rule[3])
+ except (ValueError, IndexError):
+ assert (
+ False
+ ), f"{arg_name} TSE rule parameters must be: ['tse', <float:middle>, <float:base>, <int:count>]"
+ assert base > 1, f"{arg_name} TSE base must be larger than 1, got: {base}"
+ assert count > 0, f"{arg_name} TSE count must be positive, got: {count}"
+ assert middle > 0, f"{arg_name} TSE middle must be positive, got: {middle}"
+
+ elif rule == "default":
+ assert (
+ len(buckets_rule) == 1
+ ), f"{arg_name} default rule should only have one parameter: ['default'], got {len(buckets_rule)}"
+
+ elif rule == "customer":
+ assert (
+ len(buckets_rule) >= 2
+ ), f"{arg_name} customer rule requires at least one bucket value: ['customer', value1, ...]"
+ try:
+ bucket_values = [float(x) for x in buckets_rule[1:]]
+ except ValueError:
+ assert False, f"{arg_name} customer rule bucket values must be numeric"
+ assert len(set(bucket_values)) == len(
+ bucket_values
+ ), f"{arg_name} customer rule bucket values should not contain duplicates"
+ assert all(
+ val >= 0 for val in bucket_values
+ ), f"{arg_name} customer rule bucket values should be non-negative"
+
  def model_specific_adjustments(self):
  hf_config = self.get_hf_config()
  model_arch = hf_config.architectures[0]
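
For reference, a hedged illustration of argument lists that the validation above accepts or rejects (values arrive as strings because the flags use nargs="+"):

```python
# Accepted by validate_buckets_rule, per the assertions in this hunk:
valid = [
    ["default"],                             # exactly one token: predefined buckets
    ["tse", "1000", "2", "8"],               # middle > 0, base > 1, count > 0
    ["customer", "10", "50", "100", "500"],  # unique, non-negative custom values
]
# Rejected: ["tse", "1000", "1", "8"] (base must be > 1),
#           ["customer", "10", "10"] (duplicate bucket values).
```
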
@@ -2336,7 +2412,8 @@ class ServerArgs:
  assert self.attention_backend in {
  "fa3",
  "aiter",
- }, "fa3 or aiter is required for Llama4 model"
+ "triton",
+ }, "fa3, aiter, or triton is required for Llama4 model"

  elif model_arch in [
  "Gemma2ForCausalLM",