sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45

@@ -23,8 +23,13 @@ from sglang.srt.layers.moe import (
     get_moe_runner_backend,
     should_use_flashinfer_trtllm_moe,
 )
+from sglang.srt.layers.moe.token_dispatcher.standard import (
+    CombineInput,
+    StandardDispatcher,
+)
 from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
@@ -152,16 +157,6 @@ class FusedMoE(torch.nn.Module):
         self.expert_map_cpu = None
         self.expert_map_gpu = None

-        self.moe_runner_config = MoeRunnerConfig(
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
-            gemm1_alpha=gemm1_alpha,
-            gemm1_clamp_limit=gemm1_clamp_limit,
-        )
-
         enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()

         if enable_flashinfer_cutlass_moe and quant_config is None:
@@ -175,6 +170,8 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
+        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
         if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
@@ -194,13 +191,6 @@ class FusedMoE(torch.nn.Module):
         self.use_presharded_weights = use_presharded_weights

         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
-                self.use_triton_kernels
-            )
-        else:
-            self.quant_method = quant_config.get_quant_method(self, prefix)
-            assert self.quant_method is not None

         self.quant_config = quant_config
         self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
@@ -211,12 +201,40 @@ class FusedMoE(torch.nn.Module):
             and self.use_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
+        self.hidden_size = hidden_size
+
+        self.moe_runner_config = MoeRunnerConfig(
+            num_experts=num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            intermediate_size_per_partition=self.intermediate_size_per_partition,
+            layer_id=layer_id,
+            top_k=top_k,
+            num_fused_shared_experts=num_fused_shared_experts,
+            params_dtype=params_dtype,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            inplace=inplace,
+            no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
+            gemm1_alpha=gemm1_alpha,
+            gemm1_clamp_limit=gemm1_clamp_limit,
+        )
+
+        if quant_config is None:
+            self.quant_method: FusedMoEMethodBase = UnquantizedFusedMoEMethod(
+                self.use_triton_kernels
+            )
+        else:
+            self.quant_method: FusedMoEMethodBase = quant_config.get_quant_method(
+                self, prefix
+            )
+        assert self.quant_method is not None
+
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
             hidden_size=hidden_size,
-            # FIXME: figure out which intermediate_size to use
-            intermediate_size=self.intermediate_size_per_partition,
             intermediate_size_per_partition=self.intermediate_size_per_partition,
             params_dtype=params_dtype,
             weight_loader=(
@@ -227,6 +245,9 @@ class FusedMoE(torch.nn.Module):
             with_bias=with_bias,
         )

+        self.quant_method.create_moe_runner(self, self.moe_runner_config)
+        self.dispatcher = StandardDispatcher()
+
     def _load_per_tensor_weight_scale(
         self,
         shard_id: str,
@@ -592,9 +613,12 @@ class FusedMoE(torch.nn.Module):
        loaded_weight = loaded_weight.to(param.data.device)

        if (
-            "compressed" in self.quant_method.__class__.__name__.lower()
-            and param.data[expert_id] != 1
-            and (param.data[expert_id] - loaded_weight).abs() > 1e-5
+            (
+                "compressed" in self.quant_method.__class__.__name__.lower()
+                or "w4afp8" in self.quant_config.get_name()
+            )
+            and (param.data[expert_id] != 1).any()
+            and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
        ):
            raise ValueError(
                "input_scales of w1 and w3 of a layer "
@@ -808,16 +832,17 @@ class FusedMoE(torch.nn.Module):
        elif TopKOutputChecker.format_is_triton_kernel(topk_output):
            raise NotImplementedError()

-        # Matrix multiply.
-        with use_symmetric_memory(get_tp_group()) as sm:
+        dispatch_output = self.dispatcher.dispatch(
+            hidden_states=hidden_states, topk_output=topk_output
+        )

-            final_hidden_states = self.quant_method.apply(
-                layer=self,
-                x=hidden_states,
-                topk_output=topk_output,
-                moe_runner_config=self.moe_runner_config,
-            )
-            sm.tag(final_hidden_states)
+        # TODO: consider using symmetric memory
+        combine_input = self.quant_method.apply(
+            layer=self,
+            dispatch_output=dispatch_output,
+        )
+
+        final_hidden_states = self.dispatcher.combine(combine_input)

        final_hidden_states = final_hidden_states[
            ..., :origin_hidden_states_dim
@@ -952,7 +977,6 @@ class FlashInferFusedMoE(FusedMoE):
            layer=self,
            x=hidden_states,
            topk_output=topk_output,
-            moe_runner_config=self.moe_runner_config,
        )

        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
@@ -1052,16 +1076,3 @@ class FlashInferFP4MoE(FusedMoE):
        )[0]

        return result
-
-
-def get_fused_moe_impl_class():
-    """Factory function to get the appropriate FusedMoE implementation class."""
-    if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
-        # Use FP4 variant when FP4 quantization is enabled
-        return FlashInferFP4MoE
-    elif should_use_flashinfer_trtllm_moe():
-        # Use regular FlashInfer variant for non-FP4 FlashInfer cases
-        return FlashInferFusedMoE
-    else:
-        # Default case
-        return FusedMoE
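
The layer.py hunks above replace the direct quant_method.apply(x=..., topk_output=...) call with an explicit dispatch/combine pair around the quant method, and move MoeRunnerConfig construction plus quant_method selection after the hidden_size rounding. A rough sketch of the resulting forward flow (illustrative class, not the sglang implementation; only the call sequence is taken from the diff):

class _FusedMoEFlowSketch:
    def __init__(self, quant_method, dispatcher):
        self.quant_method = quant_method   # e.g. UnquantizedFusedMoEMethod(...)
        self.dispatcher = dispatcher       # StandardDispatcher() per the diff

    def forward(self, hidden_states, topk_output):
        # 1. Route/permute tokens toward the selected experts.
        dispatch_output = self.dispatcher.dispatch(
            hidden_states=hidden_states, topk_output=topk_output
        )
        # 2. The quant method now consumes a DispatchOutput and returns a
        #    CombineInput (previously it took x= and topk_output= directly).
        combine_input = self.quant_method.apply(
            layer=self, dispatch_output=dispatch_output
        )
        # 3. Un-permute/reduce back to the original token order.
        return self.dispatcher.combine(combine_input)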

sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0

@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import Tuple
+
+import torch
+import triton
+
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+
+def moe_align_block_size(
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Aligns the token distribution across experts to be compatible with block
+    size for matrix multiplication.
+
+    Parameters:
+    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
+      top-k expert indices for each token.
+    - block_size: The block size used in block matrix multiplication.
+    - num_experts: The total number of experts.
+
+    Returns:
+    - sorted_token_ids: A tensor containing the sorted token indices according
+      to their allocated expert.
+    - expert_ids: A tensor indicating the assigned expert index for each block.
+    - num_tokens_post_padded: The total number of tokens after padding,
+      ensuring divisibility by block_size.
+
+    This function pads the number of tokens that each expert needs to process
+    so that it is divisible by block_size.
+    Padding ensures that during block matrix multiplication, the dimensions
+    align correctly.
+
+    Example:
+    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
+    block_size = 4, and num_experts = 4:
+    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
+      with each expert needing to process 3 tokens.
+    - As block_size is 4, we pad 1 token for each expert.
+    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
+    - Then append padding tokens [12, 12, 12, 12] for each block.
+    - After sorting by expert index, we obtain token_ids
+      [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
+      Tokens 12 are non-existent (padding) and are ignored in
+      the subsequent matrix multiplication.
+    - The padding ensures that the total number of tokens is now divisible
+      by block_size for proper block matrix operations.
+    """
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
+    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+
+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
+    cumsum_buffer = torch.empty(
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
+    )
+
+    # Threshold based on benchmark results
+    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+    if not fuse_sorted_ids_padding:
+        sorted_ids.fill_(topk_ids.numel())
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts + 1,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        cumsum_buffer,
+        fuse_sorted_ids_padding,
+    )
+    return sorted_ids, expert_ids, num_tokens_post_pad
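
The docstring's worked example can be checked by hand: 12 routed (token, expert) pairs across 4 experts with block_size = 4 means each expert's 3 tokens are padded to 4. A small standalone sketch of that arithmetic in plain PyTorch (it does not call the sgl_kernel-backed function above; expert ids 1..4 follow the docstring example):

import torch

topk_ids = torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]])
block_size, num_experts = 4, 4

total_tokens = topk_ids.numel()  # 12 routed (token, expert) pairs
tokens_per_expert = torch.bincount(topk_ids.flatten(), minlength=num_experts + 1)
padded_per_expert = ((tokens_per_expert + block_size - 1) // block_size) * block_size
print(tokens_per_expert.tolist())  # [0, 3, 3, 3, 3]
print(padded_per_expert.tolist())  # [0, 4, 4, 4, 4] -> 16 tokens after padding

# Worst-case buffer size allocated by moe_align_block_size above:
max_num_tokens_padded = total_tokens + (num_experts + 1) * (block_size - 1)
print(max_num_tokens_padded)  # 12 + 5 * 3 = 27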

sglang/srt/layers/moe/moe_runner/__init__.py +2 -1

@@ -1,3 +1,4 @@
 from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.runner import MoeRunner

-__all__ = ["MoeRunnerConfig"]
+__all__ = ["MoeRunnerConfig", "MoeRunner"]

sglang/srt/layers/moe/moe_runner/base.py +274 -1

@@ -1,9 +1,41 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
+
+import torch
+
+from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner.triton import (
+        TritonRunnerCore,
+        TritonRunnerInput,
+        TritonRunnerOutput,
+    )
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        CombineInputFormat,
+        DispatchOutput,
+        DispatchOutputFormat,
+    )


 @dataclass
 class MoeRunnerConfig:
+
+    # MoE parameters
+    num_experts: Optional[int] = None
+    num_local_experts: Optional[int] = None
+    hidden_size: Optional[int] = None
+    intermediate_size_per_partition: Optional[int] = None
+    layer_id: Optional[int] = None
+    top_k: Optional[int] = None
+    num_fused_shared_experts: Optional[int] = None
+    params_dtype: Optional[torch.dtype] = None
+
+    # Runner configuration
     activation: str = "silu"
     apply_router_weight_on_input: bool = False
     inplace: bool = True
@@ -11,3 +43,244 @@ class MoeRunnerConfig:
     routed_scaling_factor: Optional[float] = None
     gemm1_alpha: Optional[float] = None
     gemm1_clamp_limit: Optional[float] = None
+
+
+@dataclass
+class RunnerInput(ABC):
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerInput]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+class RunnerOutput(ABC):
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerOutput]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+@dataclass
+class MoeQuantInfo(ABC):
+    """Moe quantization data."""
+
+    pass
+
+
+class MoeRunnerCore(ABC):
+
+    def __init__(self, config: MoeRunnerConfig):
+        self.config = config
+
+    @abstractmethod
+    def run(
+        self, runner_input: RunnerInput, quant_info: MoeQuantInfo, running_state: dict
+    ) -> RunnerOutput:
+        pass
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerCore]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+class FusedOpPool:
+
+    _fused_funcs: dict[str, Callable] = {}
+
+    @classmethod
+    def register_fused_func(
+        cls, a2a_backend_name: str, runner_backend_name: str, fused_func: Callable
+    ):
+        key = (a2a_backend_name, runner_backend_name)
+        if key in cls._fused_funcs:
+            raise ValueError(
+                f"Fused function for {a2a_backend_name} to {runner_backend_name} is already registered."
+            )
+        assert MoeA2ABackend(
+            a2a_backend_name
+        ), f"Invalid dispatch name: {a2a_backend_name}"
+        assert MoeRunnerBackend(
+            runner_backend_name
+        ), f"Invalid runner name: {runner_backend_name}"
+        cls._fused_funcs[key] = fused_func
+
+    @classmethod
+    def get_fused_func(cls, dispatch_name: str, runner_name: str) -> Optional[Callable]:
+        key = (dispatch_name, runner_name)
+        fused_func = cls._fused_funcs.get(key)
+        return fused_func
+
+
+class PermuteMethodPool:
+
+    _pre_permute_methods: dict[
+        Tuple[DispatchOutputFormat, MoeRunnerBackend], Callable
+    ] = {}
+    _post_permute_methods: dict[
+        Tuple[MoeRunnerBackend, CombineInputFormat], Callable
+    ] = {}
+
+    @classmethod
+    def register_pre_permute(
+        cls,
+        dispatch_output_name: str,
+        runner_backend_name: str,
+        permute_func: Callable,
+    ):
+        """
+        Register a customized pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+        :param dispatch_output_name: The DispatchOutputFormat name.
+        :param runner_backend_name: The MoeRunnerBackend name.
+        :param permute_func: The permute function to register.
+        """
+        # TODO: check if registration is valid
+        key = (dispatch_output_name, runner_backend_name)
+        if key in cls._pre_permute_methods:
+            raise ValueError(
+                f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
+            )
+        cls._pre_permute_methods[key] = permute_func
+
+    @classmethod
+    def register_post_permute(
+        cls,
+        runner_backend_name: str,
+        combine_input_name: str,
+        permute_func: Callable,
+    ):
+        """
+        Register a customized post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+        :param runner_backend_name: The MoeRunnerBackend name.
+        :param combine_input_name: The CombineInputFormat name.
+        :param permute_func: The permute function to register.
+        """
+        # TODO: check if registration is valid
+        key = (runner_backend_name, combine_input_name)
+        if key in cls._post_permute_methods:
+            raise ValueError(
+                f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
+            )
+        cls._post_permute_methods[key] = permute_func
+
+    @classmethod
+    def get_pre_permute(
+        cls,
+        dispatch_output_format: DispatchOutputFormat,
+        runner_input_format: MoeRunnerBackend,
+    ) -> Callable:
+        """
+        Retrieve the pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+        :param dispatch_output_format: The DispatchOutputFormat type.
+        :param runner_input_format: The MoeRunnerBackend type.
+        :return: The registered permute function or None if not found.
+        """
+        key = (dispatch_output_format, runner_input_format)
+        pre_permute_func = cls._pre_permute_methods.get(key)
+        assert (
+            pre_permute_func is not None
+        ), f"Pre-permute function for {dispatch_output_format} to {runner_input_format} is not registered"
+        return pre_permute_func
+
+    @classmethod
+    def get_post_permute(
+        cls,
+        runner_output_format: MoeRunnerBackend,
+        combine_input_format: CombineInputFormat,
+    ) -> Callable:
+        """
+        Retrieve the post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+        :param runner_output_format: The MoeRunnerBackend type.
+        :param combine_input_format: The CombineInputFormat type.
+        :return: The registered permute function or None if not found.
+        """
+        key = (runner_output_format, combine_input_format)
+        post_permute_func = cls._post_permute_methods.get(key)
+        assert (
+            post_permute_func is not None
+        ), f"Post-permute function for {runner_output_format} to {combine_input_format} is not registered"
+        return post_permute_func
+
+
+def register_fused_func(
+    a2a_backend_name: str,
+    runner_backend_name: str,
+) -> Callable:
+    """
+    Decorator to register a fused function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+    :param a2a_backend_name: The A2A backend name.
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :return: The decorator function.
+    """
+
+    def decorator(fused_func: Callable):
+        FusedOpPool.register_fused_func(
+            a2a_backend_name, runner_backend_name, fused_func
+        )
+        return fused_func
+
+    return decorator
+
+
+def register_pre_permute(
+    dispatch_output_name: str,
+    runner_backend_name: str,
+) -> Callable:
+    """
+    Decorator to register a pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+    :param dispatch_output_name: The DispatchOutputFormat name.
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :return: The decorator function.
+    """

+    def decorator(
+        permute_func: Callable[
+            [DispatchOutput, MoeQuantInfo, MoeRunnerConfig, dict], RunnerInput
+        ]
+    ) -> Callable:
+
+        PermuteMethodPool.register_pre_permute(
+            dispatch_output_name, runner_backend_name, permute_func
+        )
+        return permute_func
+
+    return decorator
+
+
+def register_post_permute(
+    runner_backend_name: str,
+    combine_input_name: str,
+) -> Callable:
+    """
+    Decorator to register a post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :param combine_input_name: The CombineInputFormat name.
+    :return: The decorator function.
+    """
+
+    def decorator(
+        permute_func: Callable[
+            [RunnerOutput, MoeQuantInfo, MoeRunnerConfig, dict], CombineInput
+        ]
+    ) -> Callable:
+        PermuteMethodPool.register_post_permute(
+            runner_backend_name, combine_input_name, permute_func
+        )
+        return permute_func
+
+    return decorator
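
base.py above defines the plumbing the runner uses: pre-permute functions turn a DispatchOutput into a backend-specific RunnerInput, post-permute functions turn the RunnerOutput back into a CombineInput, and an optional fused function can bypass both. A hedged sketch of how a backend might register against these pools (the "standard" and "triton" format names are assumptions for illustration; the real keys come from DispatchOutputFormat / MoeRunnerBackend values):

from sglang.srt.layers.moe.moe_runner.base import (
    register_post_permute,
    register_pre_permute,
)

@register_pre_permute("standard", "triton")  # assumed format names
def standard_to_triton(dispatch_output, quant_info, runner_config, running_state):
    # Permute tokens into the layout the Triton kernels expect and stash
    # anything needed later (e.g. scatter indices) in running_state.
    ...

@register_post_permute("triton", "standard")  # assumed format names
def triton_to_standard(runner_output, quant_info, runner_config, running_state):
    # Undo the permutation and wrap the result as a CombineInput.
    ...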

sglang/srt/layers/moe/moe_runner/runner.py +80 -0

@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import logging
+import os
+from typing import TYPE_CHECKING
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    FusedOpPool,
+    MoeRunnerConfig,
+    PermuteMethodPool,
+)
+from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
+from sglang.srt.layers.moe.utils import get_moe_a2a_backend
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
+    from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
+    from sglang.srt.layers.moe.utils import MoeRunnerBackend
+
+logger = logging.getLogger(__name__)
+
+
+class MoeRunner:
+
+    def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
+        self.runner_backend = runner_backend
+        self.config = config
+
+        self.fused_func = None
+
+        if runner_backend.is_triton():
+            self.runner_core = TritonRunnerCore(config)
+        else:
+            raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
+
+        a2a_backend_name = get_moe_a2a_backend().value
+        runner_backend_name = runner_backend.value
+
+        self.fused_func = FusedOpPool.get_fused_func(
+            a2a_backend_name, runner_backend_name
+        )
+
+        SGLANG_CI_DISABLE_MOE_FUSED_FUNC = os.environ.get(
+            "SGLANG_CI_DISABLE_MOE_FUSED_FUNC", "0"
+        )
+        if SGLANG_CI_DISABLE_MOE_FUSED_FUNC == "1":
+            logger.info(
+                "SGLANG_CI_DISABLE_MOE_FUSED_FUNC is set to 1, disabling fused func"
+            )
+            self.fused_func = None
+
+    def run(
+        self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo
+    ) -> CombineInput:
+
+        if self.fused_func is not None:
+            return self.fused_func(dispatch_output, quant_info, self.config)
+
+        dispatch_format = dispatch_output.format.value
+        runner_format = self.runner_core.runner_backend.value
+        self.pre_permute_func = PermuteMethodPool.get_pre_permute(
+            dispatch_format, runner_format
+        )
+
+        running_state = {}
+        runner_input = self.pre_permute_func(
+            dispatch_output, quant_info, self.config, running_state
+        )
+        runner_output = self.runner_core.run(runner_input, quant_info, running_state)
+
+        runner_format = self.runner_core.runner_backend.value
+        combine_format = dispatch_output.format.value
+        self.post_permute_func = PermuteMethodPool.get_post_permute(
+            runner_format, combine_format
+        )
+        combine_input = self.post_permute_func(
+            runner_output, quant_info, self.config, running_state
+        )
+
+        return combine_input
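
Combined with the layer.py changes, the intended wiring is: FusedMoE.__init__ calls quant_method.create_moe_runner(self, self.moe_runner_config), and quant_method.apply forwards the DispatchOutput to MoeRunner.run. A rough sketch of a quant method built that way, assuming the Triton backend (the class name and the _build_quant_info helper are illustrative, not sglang APIs):

from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
from sglang.srt.layers.moe.utils import MoeRunnerBackend

class _QuantMethodSketch:
    """Illustrative shape of a FusedMoEMethodBase implementation."""

    def create_moe_runner(self, layer, moe_runner_config: MoeRunnerConfig):
        # Called once from FusedMoE.__init__ after create_weights().
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(self, layer, dispatch_output):
        # quant_info would package the layer's (possibly quantized) weights
        # for the runner core; _build_quant_info is a hypothetical helper.
        quant_info = self._build_quant_info(layer)
        return self.runner.run(dispatch_output, quant_info)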