sglang-0.5.2rc1-py3-none-any.whl → sglang-0.5.3rc0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -23,8 +23,14 @@ from sglang.srt.layers.moe import (
     get_moe_runner_backend,
     should_use_flashinfer_trtllm_moe,
 )
+from sglang.srt.layers.moe.token_dispatcher.standard import (
+    CombineInput,
+    StandardDispatcher,
+    StandardDispatchOutput,
+)
 from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
@@ -68,16 +74,6 @@ if should_use_flashinfer_trtllm_moe():
 logger = logging.getLogger(__name__)
 
 
-def _is_fp4_quantization_enabled():
-    """Check if ModelOpt FP4 quantization is enabled."""
-    try:
-        # Use the same simple check that works for class selection
-        quantization = global_server_args_dict.get("quantization")
-        return quantization == "modelopt_fp4"
-    except:
-        return False
-
-
 def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     # Guess tokens per expert assuming perfect expert distribution first.
     num_tokens_per_expert = (num_tokens * top_k) // num_experts
@@ -152,16 +148,6 @@ class FusedMoE(torch.nn.Module):
         self.expert_map_cpu = None
         self.expert_map_gpu = None
 
-        self.moe_runner_config = MoeRunnerConfig(
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
-            gemm1_alpha=gemm1_alpha,
-            gemm1_clamp_limit=gemm1_clamp_limit,
-        )
-
         enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
 
         if enable_flashinfer_cutlass_moe and quant_config is None:
@@ -196,13 +182,6 @@ class FusedMoE(torch.nn.Module):
         self.use_presharded_weights = use_presharded_weights
 
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
-                self.use_triton_kernels
-            )
-        else:
-            self.quant_method = quant_config.get_quant_method(self, prefix)
-        assert self.quant_method is not None
 
         self.quant_config = quant_config
         self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
@@ -213,12 +192,40 @@
             and self.use_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
+        self.hidden_size = hidden_size
+
+        self.moe_runner_config = MoeRunnerConfig(
+            num_experts=num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            intermediate_size_per_partition=self.intermediate_size_per_partition,
+            layer_id=layer_id,
+            top_k=top_k,
+            num_fused_shared_experts=num_fused_shared_experts,
+            params_dtype=params_dtype,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            inplace=inplace,
+            no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
+            gemm1_alpha=gemm1_alpha,
+            gemm1_clamp_limit=gemm1_clamp_limit,
+        )
+
+        if quant_config is None:
+            self.quant_method: FusedMoEMethodBase = UnquantizedFusedMoEMethod(
+                self.use_triton_kernels
+            )
+        else:
+            self.quant_method: FusedMoEMethodBase = quant_config.get_quant_method(
+                self, prefix
+            )
+        assert self.quant_method is not None
+
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
             hidden_size=hidden_size,
-            # FIXME: figure out which intermediate_size to use
-            intermediate_size=self.intermediate_size_per_partition,
             intermediate_size_per_partition=self.intermediate_size_per_partition,
             params_dtype=params_dtype,
             weight_loader=(
@@ -229,6 +236,9 @@
             with_bias=with_bias,
         )
 
+        self.quant_method.create_moe_runner(self, self.moe_runner_config)
+        self.dispatcher = StandardDispatcher()
+
     def _load_per_tensor_weight_scale(
         self,
         shard_id: str,
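Note: the quant method is now constructed after the MoeRunnerConfig is built, and every method is expected to bind its runner during layer construction via create_moe_runner (see the hunk above). A minimal sketch of the pattern, assuming a hypothetical FusedMoEMethodBase subclass; the real implementations live in the individual quantization backends:

    from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
    from sglang.srt.layers.moe.utils import MoeRunnerBackend

    class MyQuantMethod:  # hypothetical stand-in for a FusedMoEMethodBase subclass
        def create_moe_runner(self, layer, moe_runner_config: MoeRunnerConfig):
            # Build the runner once, at layer construction time; Triton is the
            # only runner core wired up in MoeRunner in this release.
            self.moe_runner_config = moe_runner_config
            self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)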
@@ -522,10 +532,12 @@
         shard_id: str,
         expert_id: int,
     ) -> None:
+        # WARN: This makes the `expert_id` mean "local" and "global" in different cases
+        if not getattr(param, "_sglang_require_global_experts", False):
+            expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+            if expert_id == -1:
+                return
 
-        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
-        if expert_id == -1:
-            return
 
         self._weight_loader_impl(
             param=param,
             loaded_weight=loaded_weight,
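Note: the guard above makes `expert_id` local for ordinary parameters and global for parameters that carry the flag. A hypothetical illustration (the attribute name comes from the hunk; where it is actually set depends on the loading path):

    import torch

    # A parameter tagged this way keeps global expert ids in weight_loader;
    # untagged parameters go through the global->local mapping, and loading
    # is skipped when the expert is not local (the mapping returns -1).
    param = torch.nn.Parameter(torch.empty(8, 16), requires_grad=False)
    param._sglang_require_global_experts = True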
@@ -594,8 +606,10 @@
         loaded_weight = loaded_weight.to(param.data.device)
 
         if (
-            "compressed" in self.quant_method.__class__.__name__.lower()
-            or "w4afp8" in self.quant_config.get_name()
+            (
+                "compressed" in self.quant_method.__class__.__name__.lower()
+                or "w4afp8" in self.quant_config.get_name()
+            )
             and (param.data[expert_id] != 1).any()
             and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
         ):
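Note: this is a precedence fix. Since `and` binds tighter than `or` in Python, the old condition parsed as `compressed or (w4afp8 and <value checks>)`, so it short-circuited to true for compressed-tensors methods regardless of the value checks; the added parentheses make the duplicate-weight checks apply to both the compressed and w4afp8 branches.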
@@ -811,16 +825,17 @@
         elif TopKOutputChecker.format_is_triton_kernel(topk_output):
             raise NotImplementedError()
 
-        # Matrix multiply.
-        with use_symmetric_memory(get_tp_group()) as sm:
+        dispatch_output = self.dispatcher.dispatch(
+            hidden_states=hidden_states, topk_output=topk_output
+        )
 
-            final_hidden_states = self.quant_method.apply(
-                layer=self,
-                x=hidden_states,
-                topk_output=topk_output,
-                moe_runner_config=self.moe_runner_config,
-            )
-            sm.tag(final_hidden_states)
+        # TODO: consider using symmetric memory
+        combine_input = self.quant_method.apply(
+            layer=self,
+            dispatch_output=dispatch_output,
+        )
+
+        final_hidden_states = self.dispatcher.combine(combine_input)
 
         final_hidden_states = final_hidden_states[
             ..., :origin_hidden_states_dim
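Note: the forward path is now an explicit three-stage pipeline: dispatch pairs the hidden states with the top-k routing result, the quant method maps that DispatchOutput to a CombineInput, and combine materializes the final hidden states (the symmetric-memory optimization is dropped for now, per the TODO). Condensed from the hunk above:

    dispatch_output = self.dispatcher.dispatch(
        hidden_states=hidden_states, topk_output=topk_output
    )
    combine_input = self.quant_method.apply(layer=self, dispatch_output=dispatch_output)
    final_hidden_states = self.dispatcher.combine(combine_input)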
@@ -953,9 +968,9 @@ class FlashInferFusedMoE(FusedMoE):
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply_with_router_logits(
             layer=self,
-            x=hidden_states,
-            topk_output=topk_output,
-            moe_runner_config=self.moe_runner_config,
+            dispatch_output=StandardDispatchOutput(
+                hidden_states=hidden_states, topk_output=topk_output
+            ),
         )
 
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
@@ -1055,16 +1070,3 @@ class FlashInferFP4MoE(FusedMoE):
         )[0]
 
         return result
-
-
-def get_fused_moe_impl_class():
-    """Factory function to get the appropriate FusedMoE implementation class."""
-    if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
-        # Use FP4 variant when FP4 quantization is enabled
-        return FlashInferFP4MoE
-    elif should_use_flashinfer_trtllm_moe():
-        # Use regular FlashInfer variant for non-FP4 FlashInfer cases
-        return FlashInferFusedMoE
-    else:
-        # Default case
-        return FusedMoE
sglang/srt/layers/moe/moe_runner/__init__.py

@@ -1,3 +1,4 @@
 from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.runner import MoeRunner
 
-__all__ = ["MoeRunnerConfig"]
+__all__ = ["MoeRunnerConfig", "MoeRunner"]
sglang/srt/layers/moe/moe_runner/base.py

@@ -1,9 +1,41 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
+
+import torch
+
+from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner.triton import (
+        TritonRunnerCore,
+        TritonRunnerInput,
+        TritonRunnerOutput,
+    )
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        CombineInputFormat,
+        DispatchOutput,
+        DispatchOutputFormat,
+    )
 
 
 @dataclass
 class MoeRunnerConfig:
+
+    # MoE parameters
+    num_experts: Optional[int] = None
+    num_local_experts: Optional[int] = None
+    hidden_size: Optional[int] = None
+    intermediate_size_per_partition: Optional[int] = None
+    layer_id: Optional[int] = None
+    top_k: Optional[int] = None
+    num_fused_shared_experts: Optional[int] = None
+    params_dtype: Optional[torch.dtype] = None
+
+    # Runner configuration
     activation: str = "silu"
     apply_router_weight_on_input: bool = False
     inplace: bool = True
@@ -11,3 +43,244 @@ class MoeRunnerConfig:
     routed_scaling_factor: Optional[float] = None
     gemm1_alpha: Optional[float] = None
     gemm1_clamp_limit: Optional[float] = None
+
+
+@dataclass
+class RunnerInput(ABC):
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerInput]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+class RunnerOutput(ABC):
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerOutput]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+@dataclass
+class MoeQuantInfo(ABC):
+    """Moe quantization data."""
+
+    pass
+
+
+class MoeRunnerCore(ABC):
+
+    def __init__(self, config: MoeRunnerConfig):
+        self.config = config
+
+    @abstractmethod
+    def run(
+        self, runner_input: RunnerInput, quant_info: MoeQuantInfo, running_state: dict
+    ) -> RunnerOutput:
+        pass
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerCore]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+class FusedOpPool:
+
+    _fused_funcs: dict[str, Callable] = {}
+
+    @classmethod
+    def register_fused_func(
+        cls, a2a_backend_name: str, runner_backend_name: str, fused_func: Callable
+    ):
+        key = (a2a_backend_name, runner_backend_name)
+        if key in cls._fused_funcs:
+            raise ValueError(
+                f"Fused function for {a2a_backend_name} to {runner_backend_name} is already registered."
+            )
+        assert MoeA2ABackend(
+            a2a_backend_name
+        ), f"Invalid dispatch name: {a2a_backend_name}"
+        assert MoeRunnerBackend(
+            runner_backend_name
+        ), f"Invalid runner name: {runner_backend_name}"
+        cls._fused_funcs[key] = fused_func
+
+    @classmethod
+    def get_fused_func(cls, dispatch_name: str, runner_name: str) -> Optional[Callable]:
+        key = (dispatch_name, runner_name)
+        fused_func = cls._fused_funcs.get(key)
+        return fused_func
+
+
+class PermuteMethodPool:
+
+    _pre_permute_methods: dict[
+        Tuple[DispatchOutputFormat, MoeRunnerBackend], Callable
+    ] = {}
+    _post_permute_methods: dict[
+        Tuple[MoeRunnerBackend, CombineInputFormat], Callable
+    ] = {}
+
+    @classmethod
+    def register_pre_permute(
+        cls,
+        dispatch_output_name: str,
+        runner_backend_name: str,
+        permute_func: Callable,
+    ):
+        """
+        Register a customized pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+        :param dispatch_output_name: The DispatchOutputFormat name.
+        :param runner_backend_name: The MoeRunnerBackend name.
+        :param permute_func: The permute function to register.
+        """
+        # TODO: check if registration is valid
+        key = (dispatch_output_name, runner_backend_name)
+        if key in cls._pre_permute_methods:
+            raise ValueError(
+                f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
+            )
+        cls._pre_permute_methods[key] = permute_func
+
+    @classmethod
+    def register_post_permute(
+        cls,
+        runner_backend_name: str,
+        combine_input_name: str,
+        permute_func: Callable,
+    ):
+        """
+        Register a customized post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+        :param runner_backend_name: The MoeRunnerBackend name.
+        :param combine_input_name: The CombineInputFormat name.
+        :param permute_func: The permute function to register.
+        """
+        # TODO: check if registration is valid
+        key = (runner_backend_name, combine_input_name)
+        if key in cls._post_permute_methods:
+            raise ValueError(
+                f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
+            )
+        cls._post_permute_methods[key] = permute_func
+
+    @classmethod
+    def get_pre_permute(
+        cls,
+        dispatch_output_format: DispatchOutputFormat,
+        runner_input_format: MoeRunnerBackend,
+    ) -> Callable:
+        """
+        Retrieve the pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+        :param dispatch_output_format: The DispatchOutputFormat type.
+        :param runner_input_format: The MoeRunnerBackend type.
+        :return: The registered permute function or None if not found.
+        """
+        key = (dispatch_output_format, runner_input_format)
+        pre_permute_func = cls._pre_permute_methods.get(key)
+        assert (
+            pre_permute_func is not None
+        ), f"Pre-permute function for {dispatch_output_format} to {runner_input_format} is not registered"
+        return pre_permute_func
+
+    @classmethod
+    def get_post_permute(
+        cls,
+        runner_output_format: MoeRunnerBackend,
+        combine_input_format: CombineInputFormat,
+    ) -> Callable:
+        """
+        Retrieve the post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+        :param runner_output_format: The MoeRunnerBackend type.
+        :param combine_input_format: The CombineInputFormat type.
+        :return: The registered permute function or None if not found.
+        """
+        key = (runner_output_format, combine_input_format)
+        post_permute_func = cls._post_permute_methods.get(key)
+        assert (
+            post_permute_func is not None
+        ), f"Post-permute function for {runner_output_format} to {combine_input_format} is not registered"
+        return post_permute_func
+
+
+def register_fused_func(
+    a2a_backend_name: str,
+    runner_backend_name: str,
+) -> Callable:
+    """
+    Decorator to register a fused function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+    :param a2a_backend_name: The A2A backend name.
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :return: The decorator function.
+    """
+
+    def decorator(fused_func: Callable):
+        FusedOpPool.register_fused_func(
+            a2a_backend_name, runner_backend_name, fused_func
+        )
+        return fused_func
+
+    return decorator
+
+
+def register_pre_permute(
+    dispatch_output_name: str,
+    runner_backend_name: str,
+) -> Callable:
+    """
+    Decorator to register a pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+    :param dispatch_output_name: The DispatchOutputFormat name.
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :return: The decorator function.
+    """
+
+    def decorator(
+        permute_func: Callable[
+            [DispatchOutput, MoeQuantInfo, MoeRunnerConfig, dict], RunnerInput
+        ]
+    ) -> Callable:
+
+        PermuteMethodPool.register_pre_permute(
+            dispatch_output_name, runner_backend_name, permute_func
+        )
+        return permute_func
+
+    return decorator
+
+
+def register_post_permute(
+    runner_backend_name: str,
+    combine_input_name: str,
+) -> Callable:
+    """
+    Decorator to register a post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :param combine_input_name: The CombineInputFormat name.
+    :return: The decorator function.
+    """
+
+    def decorator(
+        permute_func: Callable[
+            [RunnerOutput, MoeQuantInfo, MoeRunnerConfig, dict], CombineInput
+        ]
+    ) -> Callable:
+        PermuteMethodPool.register_post_permute(
+            runner_backend_name, combine_input_name, permute_func
+        )
+        return permute_func
+
+    return decorator
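Note: the pools are keyed by plain string names validated against the MoeA2ABackend / MoeRunnerBackend enums. A hedged usage sketch of the decorators, with "standard" and "triton" as assumed format/backend names (the real names come from the DispatchOutputFormat and MoeRunnerBackend values):

    from sglang.srt.layers.moe.moe_runner.base import (
        register_post_permute,
        register_pre_permute,
    )

    @register_pre_permute("standard", "triton")
    def standard_to_triton(dispatch_output, quant_info, runner_config, running_state):
        # Build a TritonRunnerInput from the dispatcher's output, stashing
        # anything the post-permute step will need into running_state.
        ...

    @register_post_permute("triton", "standard")
    def triton_to_standard(runner_output, quant_info, runner_config, running_state):
        # Turn the TritonRunnerOutput back into a CombineInput for combine().
        ...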
sglang/srt/layers/moe/moe_runner/runner.py (new file)

@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import logging
+import os
+from typing import TYPE_CHECKING
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    FusedOpPool,
+    MoeRunnerConfig,
+    PermuteMethodPool,
+)
+from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
+from sglang.srt.layers.moe.utils import get_moe_a2a_backend
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
+    from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
+    from sglang.srt.layers.moe.utils import MoeRunnerBackend
+
+logger = logging.getLogger(__name__)
+
+
+class MoeRunner:
+
+    def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
+        self.runner_backend = runner_backend
+        self.config = config
+
+        self.fused_func = None
+
+        if runner_backend.is_triton():
+            self.runner_core = TritonRunnerCore(config)
+        else:
+            raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
+
+        a2a_backend_name = get_moe_a2a_backend().value
+        runner_backend_name = runner_backend.value
+
+        self.fused_func = FusedOpPool.get_fused_func(
+            a2a_backend_name, runner_backend_name
+        )
+
+        SGLANG_CI_DISABLE_MOE_FUSED_FUNC = os.environ.get(
+            "SGLANG_CI_DISABLE_MOE_FUSED_FUNC", "0"
+        )
+        if SGLANG_CI_DISABLE_MOE_FUSED_FUNC == "1":
+            logger.info(
+                "SGLANG_CI_DISABLE_MOE_FUSED_FUNC is set to 1, disabling fused func"
+            )
+            self.fused_func = None
+
+    def run(
+        self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo
+    ) -> CombineInput:
+
+        if self.fused_func is not None:
+            return self.fused_func(dispatch_output, quant_info, self.config)
+
+        dispatch_format = dispatch_output.format.value
+        runner_format = self.runner_core.runner_backend.value
+        self.pre_permute_func = PermuteMethodPool.get_pre_permute(
+            dispatch_format, runner_format
+        )
+
+        running_state = {}
+        runner_input = self.pre_permute_func(
+            dispatch_output, quant_info, self.config, running_state
+        )
+        runner_output = self.runner_core.run(runner_input, quant_info, running_state)
+
+        runner_format = self.runner_core.runner_backend.value
+        combine_format = dispatch_output.format.value
+        self.post_permute_func = PermuteMethodPool.get_post_permute(
+            runner_format, combine_format
+        )
+        combine_input = self.post_permute_func(
+            runner_output, quant_info, self.config, running_state
+        )
+
+        return combine_input
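Note: run() takes the fused fast path when a function is registered for the active (A2A backend, runner backend) pair, and otherwise falls back to pre-permute -> core run -> post-permute. A hedged sketch of forcing the generic path, e.g. for CI (the config values are placeholders):

    import os

    # Must be set before MoeRunner is constructed: the fused function is
    # discarded in MoeRunner.__init__, not re-checked on every call.
    os.environ["SGLANG_CI_DISABLE_MOE_FUSED_FUNC"] = "1"

    from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
    from sglang.srt.layers.moe.utils import MoeRunnerBackend

    runner = MoeRunner(MoeRunnerBackend.TRITON, MoeRunnerConfig(top_k=8, num_experts=64))
    # combine_input = runner.run(dispatch_output, quant_info)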