sglang-0.5.2rc2-py3-none-any.whl → sglang-0.5.3rc0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import (
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )
 
 from sglang.srt.utils import is_cuda
 
@@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.linear import set_weight_attrs
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
-        intermediate_size = extra_weight_attrs.pop("intermediate_size")
-
-        self.is_k_full = (not self.quant_config.desc_act) or (
-            intermediate_size_per_partition == intermediate_size
-        )
+        self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1
 
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            w2_scales_size = (
-                intermediate_size
-                if self.quant_config.desc_act
-                else intermediate_size_per_partition
-            )
+            if self.quant_config.desc_act:
+                w2_scales_size = intermediate_size_per_partition
+            else:
+                w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
             scales_size2 = w2_scales_size // self.quant_config.group_size
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
@@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_scales", marlin_w2_scales)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         # Delay the import to avoid circular dependency
 
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         # The input must currently be float16
@@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
 
         topk_weights, topk_ids, router_logits = topk_output
 
-        return fused_marlin_moe(
+        output = fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
@@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            num_bits=self.quant_config.weight_bits,
            is_k_full=self.is_k_full,
        ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
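Note: the three hunks above migrate GPTQMarlinMoEMethod from the old `apply(x, topk_output, moe_runner_config)` signature to the dispatcher-based contract: the runner config is registered once via `create_moe_runner`, and `apply` now takes a `StandardDispatchOutput` and returns a `CombineInput`. A minimal caller-side sketch follows; it assumes `StandardDispatchOutput`/`StandardCombineInput` are simple containers exposing exactly the attributes used in the hunks (`hidden_states`, `topk_output`), and the helper name `run_quant_moe` is hypothetical.

```python
# Illustrative sketch only; it mirrors the contract visible in the hunks above
# rather than sglang's actual FusedMoE forward path.
from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput


def run_quant_moe(method, layer, hidden_states, topk_output, moe_runner_config):
    # Hypothetical helper: register the config once, then call the new apply().
    method.create_moe_runner(layer, moe_runner_config)
    dispatch_output = StandardDispatchOutput(
        hidden_states=hidden_states, topk_output=topk_output
    )
    combine_input = method.apply(layer, dispatch_output)
    return combine_input.hidden_states
```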
@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -35,12 +39,14 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -68,6 +74,10 @@ except ImportError:
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
 
+CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
+)
+
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
 
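Note: the new module-level flag above is read once at import time via sglang's `get_bool_env_var` helper. A hedged sketch of the equivalent standard-library check follows; the real helper's accepted truthy spellings may differ, and `_bool_env` is an illustrative stand-in.

```python
import os


def _bool_env(name: str, default: str = "false") -> bool:
    # Simplified stand-in for sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


# Same default as the hunk above: the scalar input scale is enabled unless the
# user exports SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE=false before launching.
CUTEDSL_MOE_SCALAR_INPUT_SCALE = _bool_env(
    "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
)
```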
@@ -322,7 +332,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -338,7 +348,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -348,7 +361,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -414,28 +430,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values
 
         # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
-        # where the first intermediate_size rows are w1, the next are w3
-        intermediate_size = layer.w13_weight.shape[1] // 2
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
         for expert_id in range(layer.w13_weight.shape[0]):
             start = 0
             for shard_id in range(2):  # w1 and w3
                 # Dequantize using the original scale for this shard
                 dq_weight = per_tensor_dequantize(
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
                 # Requantize using the combined max scale
                 (
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     _,
                 ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
 
-                start += intermediate_size
+                start += intermediate_size_per_partition
 
         # Update the scale parameter to be per-expert instead of per-shard
         layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
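Note: the loop above dequantizes each w1/w3 shard with its own per-tensor scale and requantizes it with the per-expert maximum, so a single scale can serve both shards. A self-contained numeric sketch of that identity in fp32 follows; the real code quantizes to FP8, which adds rounding and clamping, and the variable names here are purely illustrative.

```python
import torch

FP8_MAX = 448.0  # approximate max magnitude representable in float8_e4m3

w1 = torch.tensor([0.5, -1.0, 2.0])
w3 = torch.tensor([0.25, 4.0, -3.0])
s1, s3 = w1.abs().max() / FP8_MAX, w3.abs().max() / FP8_MAX  # per-shard scales

q1, q3 = w1 / s1, w3 / s3            # shard-wise quantized values
s_max = torch.maximum(s1, s3)        # combined per-expert scale

# Dequantize with the original shard scale, requantize with the shared max scale.
rq1, rq3 = (q1 * s1) / s_max, (q3 * s3) / s_max

# Reconstructing with the shared scale recovers the original weights (exactly in
# fp32; with FP8 storage the shard with the smaller scale loses some precision).
assert torch.allclose(rq1 * s_max, w1)
assert torch.allclose(rq3 * s_max, w3)
```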
@@ -457,29 +473,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
-            w1_scale=layer.w13_weight_scale,
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
 
+        return self.runner.run(dispatch_output, quant_info)
+
 
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
@@ -628,16 +646,21 @@ class ModelOptFp4Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str, exclude_modules: list):
         import regex as re
 
+        fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
+        prefix_split = prefix.split(".")
         for pattern in exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
                 return True
-
-            # Check if the last part of the excluded pattern is contained in the last part of the prefix
-            # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
-            pattern_last_part = pattern.split(".")[-1]
-            prefix_last_part = prefix.split(".")[-1]
-            if pattern_last_part in prefix_last_part:
+            elif (
+                pattern_split[-1] in fused_patterns
+                and pattern_split[-1] in prefix_split[-1]
+            ):
+                # Check if the last part of the excluded pattern is contained in the last part of the prefix
+                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
+                assert len(prefix_split) == 5 and len(pattern_split) == 5
                 return True
         return False
 
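Note: the rewritten matcher only falls back to substring matching for the known fused projection names. A standalone sketch of the same matching logic follows; it drops the path-depth assertion, uses the stdlib `re` module instead of the `regex` package, and everything else mirrors the hunk.

```python
import re

FUSED_PATTERNS = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]


def is_layer_excluded(prefix: str, exclude_modules: list) -> bool:
    # Mirrors the logic in the hunk above, minus the depth assertion.
    prefix_split = prefix.split(".")
    for pattern in exclude_modules:
        regex_str = pattern.replace(".", r"\.").replace("*", r".*")
        pattern_split = pattern.split(".")
        if re.fullmatch(regex_str, prefix):
            return True
        elif (
            pattern_split[-1] in FUSED_PATTERNS
            and pattern_split[-1] in prefix_split[-1]
        ):
            # Fused modules such as fused_qkv_a_proj_with_mqa contain
            # q_a_proj / kv_a_proj_with_mqa, so match on the last component.
            return True
    return False


# The fused projection is excluded because its name contains kv_a_proj_with_mqa.
assert is_layer_excluded(
    "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa",
    ["model.layers.*.self_attn.kv_a_proj_with_mqa"],
)
```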
@@ -859,6 +882,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         """Access the global enable_flashinfer_cutlass_moe setting."""
         return get_moe_runner_backend().is_flashinfer_cutlass()
 
+    @property
+    def enable_flashinfer_cutedsl_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
+        """Access the global enable_flashinfer_cutedsl_moe setting."""
+        return get_moe_runner_backend().is_flashinfer_cutedsl()
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -970,15 +1000,17 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
+            data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w13_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
+            data=torch.empty(layer.num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w2_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
     def swizzle_blockscale(self, scale: torch.Tensor):
@@ -1161,6 +1193,33 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        elif self.enable_flashinfer_cutedsl_moe:
+            # All-expert-one-input-scale is mathematically different from default per-expert-input-scale
+            # Thus we allow users to switch the flag to do thorough testing
+            if CUTEDSL_MOE_SCALAR_INPUT_SCALE:
+                w13_input_scale = (
+                    layer.w13_input_scale.max()
+                    .to(torch.float32)
+                    .repeat(layer.w13_input_scale.shape[0])
+                )
+            else:
+                w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
+                    torch.float32
+                )
+
+            w2_input_scale = layer.w2_input_scale
+
+            def _slice_scale(w):
+                assert w.shape == (layer.num_experts,)
+                assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts
+                return w[
+                    layer.moe_ep_rank
+                    * layer.num_local_experts : (layer.moe_ep_rank + 1)
+                    * layer.num_local_experts
+                ]
+
+            w13_input_scale = _slice_scale(w13_input_scale)
+            w2_input_scale = _slice_scale(w2_input_scale)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
             w2_input_scale = layer.w2_input_scale
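Note: the new `_slice_scale` helper above picks each EP rank's contiguous block of per-expert scales out of the globally sized tensor. A small worked sketch of the indexing follows; the concrete numbers are illustrative only.

```python
import torch

num_experts, ep_size = 8, 2
num_local_experts = num_experts // ep_size
scales = torch.arange(num_experts, dtype=torch.float32)  # one scale per global expert


def slice_scale(w: torch.Tensor, ep_rank: int) -> torch.Tensor:
    # Same indexing as _slice_scale in the hunk above: rank r keeps experts
    # [r * num_local_experts, (r + 1) * num_local_experts).
    assert w.shape == (num_experts,)
    return w[ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]


assert slice_scale(scales, 0).tolist() == [0.0, 1.0, 2.0, 3.0]
assert slice_scale(scales, 1).tolist() == [4.0, 5.0, 6.0, 7.0]
```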
@@ -1243,8 +1302,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 layer.w13_weight_scale,
             )
 
-            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
-
         else:
             # CUTLASS processing - handle w13 and w2 separately
 
@@ -1261,7 +1318,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
         # Both flashinfer cutlass and regular cutlass use same processing for w2
-        logger.info_once("Applied weight processing for both w13 and w2")
 
         # Set up CUTLASS MoE parameters
         device = layer.w13_weight.device
@@ -1278,21 +1334,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
         return self.enable_flashinfer_cutlass_moe
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
+        moe_runner_config = self.moe_runner_config
+
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))
 
         if self.enable_flashinfer_cutlass_moe:
             assert (
@@ -1345,13 +1412,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
@@ -1372,4 +1438,38 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-        return output
+        return StandardCombineInput(hidden_states=output)
+
+    def apply_without_routing_weights(
+        self,
+        layer: FusedMoE,
+        x: torch.Tensor,
+        masked_m: torch.Tensor,
+        moe_runner_config: MoeRunnerConfig,
+    ) -> torch.Tensor:
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        assert self.enable_flashinfer_cutedsl_moe, "only support flashinfer cutedsl moe"
+        assert (
+            not moe_runner_config.apply_router_weight_on_input
+        ), "apply_router_weight_on_input is not supported for Flashinfer"
+
+        from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
+            flashinfer_cutedsl_moe_masked,
+        )
+
+        out = flashinfer_cutedsl_moe_masked(
+            hidden_states=x,
+            input_global_scale=layer.w13_input_scale_quant,
+            w1=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alpha=layer.g1_alphas,
+            w2=layer.w2_weight,
+            a2_global_scale=layer.w2_input_scale_quant,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alpha=layer.g2_alphas,
+            masked_m=masked_m,
+        )
+        return out
@@ -9,6 +9,8 @@ import torch
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 
 def get_weight_perm(num_bits: int):
@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.register_parameter(key, param)
             set_weight_attrs(param, extra_weight_attrs)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp
 
-        return fused_experts(
-            x,
-            layer.w13_qweight,
-            layer.w2_qweight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_qweight,
+            w2_weight=layer.w2_qweight,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
-            w1_scale=layer.w13_scales,
+            w13_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
-            w1_zp=layer.w13_qzeros if has_zp else None,
+            w13_zp=layer.w13_qzeros if has_zp else None,
            w2_zp=layer.w2_qzeros if has_zp else None,
            block_shape=[0, layer.group_size],
        )
+        return self.runner.run(dispatch_output, quant_info)
 
     @staticmethod
     def get_weight_loader(layer, weight_loader):