sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396) hide show
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 64,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 64,
14
+ "GROUP_SIZE_M": 1,
15
+ "num_warps": 4,
16
+ "num_stages": 3
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 64,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 64,
30
+ "GROUP_SIZE_M": 1,
31
+ "num_warps": 4,
32
+ "num_stages": 4
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 64,
38
+ "GROUP_SIZE_M": 1,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 64,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 64,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 64,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 64,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 64,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 64,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 64,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 16,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 64,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 4,
104
+ "num_stages": 2
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 32,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 16,
119
+ "num_warps": 8,
120
+ "num_stages": 2
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 64,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 64,
127
+ "num_warps": 8,
128
+ "num_stages": 2
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 64,
132
+ "BLOCK_SIZE_N": 64,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 64,
135
+ "num_warps": 4,
136
+ "num_stages": 2
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 32,
143
+ "num_warps": 8,
144
+ "num_stages": 3
145
+ }
146
+ }
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 64,
6
+ "GROUP_SIZE_M": 1,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 64,
13
+ "BLOCK_SIZE_K": 64,
14
+ "GROUP_SIZE_M": 1,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 128,
21
+ "BLOCK_SIZE_K": 64,
22
+ "GROUP_SIZE_M": 1,
23
+ "num_warps": 4,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 128,
29
+ "BLOCK_SIZE_K": 64,
30
+ "GROUP_SIZE_M": 1,
31
+ "num_warps": 4,
32
+ "num_stages": 4
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 64,
38
+ "GROUP_SIZE_M": 1,
39
+ "num_warps": 4,
40
+ "num_stages": 3
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 64,
46
+ "GROUP_SIZE_M": 1,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 128,
53
+ "BLOCK_SIZE_K": 64,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 128,
61
+ "BLOCK_SIZE_K": 64,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 128,
69
+ "BLOCK_SIZE_K": 64,
70
+ "GROUP_SIZE_M": 1,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 64,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 3
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 64,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 3
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 64,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 3
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 16,
100
+ "BLOCK_SIZE_N": 128,
101
+ "BLOCK_SIZE_K": 64,
102
+ "GROUP_SIZE_M": 1,
103
+ "num_warps": 4,
104
+ "num_stages": 2
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 32,
108
+ "BLOCK_SIZE_N": 128,
109
+ "BLOCK_SIZE_K": 64,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 64,
116
+ "BLOCK_SIZE_N": 128,
117
+ "BLOCK_SIZE_K": 64,
118
+ "GROUP_SIZE_M": 64,
119
+ "num_warps": 8,
120
+ "num_stages": 2
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 64,
124
+ "BLOCK_SIZE_N": 128,
125
+ "BLOCK_SIZE_K": 64,
126
+ "GROUP_SIZE_M": 32,
127
+ "num_warps": 8,
128
+ "num_stages": 2
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 64,
132
+ "BLOCK_SIZE_N": 64,
133
+ "BLOCK_SIZE_K": 64,
134
+ "GROUP_SIZE_M": 64,
135
+ "num_warps": 4,
136
+ "num_stages": 2
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 128,
140
+ "BLOCK_SIZE_N": 128,
141
+ "BLOCK_SIZE_K": 64,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 8,
144
+ "num_stages": 3
145
+ }
146
+ }
@@ -1,3 +1,4 @@
1
+ # NOTE: this file will be separated into sglang/srt/layers/moe/moe_runner/triton_utils.py
1
2
  # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py
2
3
 
3
4
  """Fused MoE kernel."""
@@ -6,13 +7,12 @@ from __future__ import annotations
6
7
 
7
8
  import functools
8
9
  import os
9
- from typing import List, Optional
10
+ from typing import TYPE_CHECKING, List, Optional
10
11
 
11
12
  import torch
12
13
  import triton.language as tl
13
14
 
14
15
  from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
15
- from sglang.srt.layers.moe.topk import StandardTopKOutput
16
16
  from sglang.srt.utils import (
17
17
  cpu_has_amx_support,
18
18
  direct_register_custom_op,
@@ -26,6 +26,9 @@ from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_c
26
26
  from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton
27
27
  from .moe_align_block_size import moe_align_block_size
28
28
 
29
+ if TYPE_CHECKING:
30
+ from sglang.srt.layers.moe.topk import StandardTopKOutput
31
+
29
32
  _is_hip = is_hip()
30
33
  _is_cuda = is_cuda()
31
34
  _is_cpu_amx_available = cpu_has_amx_support()
@@ -43,7 +43,7 @@ def get_moe_configs(
43
43
  be picked and the associated configuration chosen to invoke the kernel.
44
44
  """
45
45
  # Supported Triton versions, should be sorted from the newest to the oldest
46
- supported_triton_versions = ["3.3.1", "3.2.0", "3.1.0"]
46
+ supported_triton_versions = ["3.4.0", "3.3.1", "3.2.0", "3.1.0"]
47
47
 
48
48
  # First look up if an optimized configuration is available in the configs
49
49
  # directory
@@ -51,10 +51,14 @@ def get_moe_configs(
51
51
 
52
52
  # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
53
53
  # so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
54
+ config_dir = os.environ.get(
55
+ "SGLANG_MOE_CONFIG_DIR", os.path.dirname(os.path.realpath(__file__))
56
+ )
57
+
54
58
  triton_version = triton.__version__
55
59
  version_dir = f"triton_{triton_version.replace('.', '_')}"
56
60
  config_file_path = os.path.join(
57
- os.path.dirname(os.path.realpath(__file__)),
61
+ config_dir,
58
62
  "configs",
59
63
  version_dir,
60
64
  json_file_name,
@@ -75,7 +79,7 @@ def get_moe_configs(
75
79
  if try_triton_version == triton_version:
76
80
  continue
77
81
  try_config_file_path = os.path.join(
78
- os.path.dirname(os.path.realpath(__file__)),
82
+ config_dir,
79
83
  "configs",
80
84
  f"triton_{try_triton_version.replace('.', '_')}",
81
85
  json_file_name,
@@ -735,29 +735,32 @@ def _moe_sum_reduce_kernel(
735
735
  token_block_id = tl.program_id(0)
736
736
  dim_block_id = tl.program_id(1)
737
737
 
738
- token_start = token_block_id * BLOCK_M
739
- token_end = min((token_block_id + 1) * BLOCK_M, token_num)
738
+ offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
739
+ offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
740
740
 
741
- dim_start = dim_block_id * BLOCK_DIM
742
- dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
741
+ mask_token = offs_token < token_num
742
+ mask_dim = offs_dim < hidden_dim
743
743
 
744
- offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
744
+ base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
745
745
 
746
- for token_index in range(token_start, token_end):
747
- accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
748
- input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
749
- for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
750
- tmp = tl.load(
751
- input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
752
- )
753
- accumulator += tmp
754
- accumulator = accumulator * routed_scaling_factor
755
- store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
756
- tl.store(
757
- store_t_ptr,
758
- accumulator.to(input_ptr.dtype.element_ty),
759
- mask=offs_dim < dim_end,
746
+ accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
747
+
748
+ for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
749
+ tile = tl.load(
750
+ base_ptrs + i * input_stride_1,
751
+ mask=mask_token[:, None] & mask_dim[None, :],
752
+ other=0.0,
760
753
  )
754
+ accumulator += tile.to(tl.float32)
755
+ accumulator *= routed_scaling_factor
756
+
757
+ # -------- Write back --------
758
+ store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
759
+ tl.store(
760
+ store_ptrs,
761
+ accumulator.to(input_ptr.dtype.element_ty),
762
+ mask=mask_token[:, None] & mask_dim[None, :],
763
+ )
761
764
 
762
765
 
763
766
  def moe_sum_reduce_triton(
@@ -772,7 +775,7 @@ def moe_sum_reduce_triton(
772
775
  BLOCK_M = 1
773
776
  BLOCK_DIM = 2048
774
777
  NUM_STAGE = 1
775
- num_warps = 8
778
+ num_warps = 16
776
779
 
777
780
  grid = (
778
781
  triton.cdiv(token_num, BLOCK_M),
@@ -11,20 +11,21 @@ from sglang.srt.distributed import (
11
11
  get_moe_expert_parallel_world_size,
12
12
  get_moe_tensor_parallel_rank,
13
13
  get_moe_tensor_parallel_world_size,
14
- get_tp_group,
15
14
  tensor_model_parallel_all_reduce,
16
15
  )
17
- from sglang.srt.distributed.device_communicators.pynccl_allocator import (
18
- use_symmetric_memory,
19
- )
20
16
  from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
21
17
  from sglang.srt.layers.moe import (
22
18
  MoeRunnerConfig,
23
19
  get_moe_runner_backend,
24
20
  should_use_flashinfer_trtllm_moe,
25
21
  )
22
+ from sglang.srt.layers.moe.token_dispatcher.standard import (
23
+ StandardDispatcher,
24
+ StandardDispatchOutput,
25
+ )
26
26
  from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
27
27
  from sglang.srt.layers.quantization.base_config import (
28
+ FusedMoEMethodBase,
28
29
  QuantizationConfig,
29
30
  QuantizeMethodBase,
30
31
  )
@@ -68,16 +69,6 @@ if should_use_flashinfer_trtllm_moe():
68
69
  logger = logging.getLogger(__name__)
69
70
 
70
71
 
71
- def _is_fp4_quantization_enabled():
72
- """Check if ModelOpt FP4 quantization is enabled."""
73
- try:
74
- # Use the same simple check that works for class selection
75
- quantization = global_server_args_dict.get("quantization")
76
- return quantization == "modelopt_fp4"
77
- except:
78
- return False
79
-
80
-
81
72
  def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
82
73
  # Guess tokens per expert assuming perfect expert distribution first.
83
74
  num_tokens_per_expert = (num_tokens * top_k) // num_experts
@@ -152,16 +143,6 @@ class FusedMoE(torch.nn.Module):
152
143
  self.expert_map_cpu = None
153
144
  self.expert_map_gpu = None
154
145
 
155
- self.moe_runner_config = MoeRunnerConfig(
156
- activation=activation,
157
- apply_router_weight_on_input=apply_router_weight_on_input,
158
- inplace=inplace,
159
- no_combine=no_combine,
160
- routed_scaling_factor=routed_scaling_factor,
161
- gemm1_alpha=gemm1_alpha,
162
- gemm1_clamp_limit=gemm1_clamp_limit,
163
- )
164
-
165
146
  enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
166
147
 
167
148
  if enable_flashinfer_cutlass_moe and quant_config is None:
@@ -175,8 +156,7 @@ class FusedMoE(torch.nn.Module):
175
156
  self.moe_tp_rank = get_moe_tensor_parallel_rank()
176
157
  assert num_experts % self.moe_ep_size == 0
177
158
  self.num_local_experts = num_experts // self.moe_ep_size
178
- self.start_expert_id = self.moe_ep_rank * self.num_local_experts
179
- self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
159
+
180
160
  if self.moe_ep_size > 1:
181
161
  # TODO(ch-wan): support shared experts fusion
182
162
  # Create a tensor of size num_experts filled with -1
@@ -196,13 +176,6 @@ class FusedMoE(torch.nn.Module):
196
176
  self.use_presharded_weights = use_presharded_weights
197
177
 
198
178
  self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
199
- if quant_config is None:
200
- self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
201
- self.use_triton_kernels
202
- )
203
- else:
204
- self.quant_method = quant_config.get_quant_method(self, prefix)
205
- assert self.quant_method is not None
206
179
 
207
180
  self.quant_config = quant_config
208
181
  self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
@@ -213,12 +186,36 @@ class FusedMoE(torch.nn.Module):
213
186
  and self.use_flashinfer_mxfp4_moe
214
187
  ):
215
188
  hidden_size = round_up(hidden_size, 256)
189
+ self.hidden_size = hidden_size
190
+
191
+ self.moe_runner_config = MoeRunnerConfig(
192
+ num_experts=num_experts,
193
+ num_local_experts=self.num_local_experts,
194
+ hidden_size=hidden_size,
195
+ intermediate_size_per_partition=self.intermediate_size_per_partition,
196
+ layer_id=layer_id,
197
+ top_k=top_k,
198
+ num_fused_shared_experts=num_fused_shared_experts,
199
+ params_dtype=params_dtype,
200
+ activation=activation,
201
+ apply_router_weight_on_input=apply_router_weight_on_input,
202
+ inplace=inplace,
203
+ no_combine=no_combine,
204
+ routed_scaling_factor=routed_scaling_factor,
205
+ gemm1_alpha=gemm1_alpha,
206
+ gemm1_clamp_limit=gemm1_clamp_limit,
207
+ )
208
+
209
+ self.quant_method: Optional[FusedMoEMethodBase] = None
210
+ if quant_config is not None:
211
+ self.quant_method = quant_config.get_quant_method(self, prefix)
212
+ if self.quant_method is None:
213
+ self.quant_method = UnquantizedFusedMoEMethod(self.use_triton_kernels)
214
+
216
215
  self.quant_method.create_weights(
217
216
  layer=self,
218
217
  num_experts=self.num_local_experts,
219
218
  hidden_size=hidden_size,
220
- # FIXME: figure out which intermediate_size to use
221
- intermediate_size=self.intermediate_size_per_partition,
222
219
  intermediate_size_per_partition=self.intermediate_size_per_partition,
223
220
  params_dtype=params_dtype,
224
221
  weight_loader=(
@@ -229,6 +226,16 @@ class FusedMoE(torch.nn.Module):
229
226
  with_bias=with_bias,
230
227
  )
231
228
 
229
+ self.quant_method.create_moe_runner(self, self.moe_runner_config)
230
+ self.dispatcher = StandardDispatcher()
231
+
232
+ self.should_fuse_routed_scaling_factor_in_topk = isinstance(
233
+ self.quant_method, ModelOptNvFp4FusedMoEMethod
234
+ ) or (
235
+ isinstance(self.quant_method, Fp8MoEMethod)
236
+ and self.quant_method.use_cutlass_fused_experts_fp8
237
+ )
238
+
232
239
  def _load_per_tensor_weight_scale(
233
240
  self,
234
241
  shard_id: str,
@@ -522,10 +529,12 @@ class FusedMoE(torch.nn.Module):
522
529
  shard_id: str,
523
530
  expert_id: int,
524
531
  ) -> None:
532
+ # WARN: This makes the `expert_id` mean "local" and "global" in different cases
533
+ if not getattr(param, "_sglang_require_global_experts", False):
534
+ expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
535
+ if expert_id == -1:
536
+ return
525
537
 
526
- expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
527
- if expert_id == -1:
528
- return
529
538
  self._weight_loader_impl(
530
539
  param=param,
531
540
  loaded_weight=loaded_weight,
@@ -563,7 +572,10 @@ class FusedMoE(torch.nn.Module):
563
572
  )
564
573
 
565
574
  # Flashinfer assumes w31 format for w13_weight. Same for the scales.
566
- if should_use_flashinfer_trtllm_moe():
575
+ if should_use_flashinfer_trtllm_moe() and (
576
+ isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
577
+ or isinstance(self.quant_method, Fp8MoEMethod)
578
+ ):
567
579
  shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
568
580
 
569
581
  WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
@@ -594,8 +606,10 @@ class FusedMoE(torch.nn.Module):
594
606
  loaded_weight = loaded_weight.to(param.data.device)
595
607
 
596
608
  if (
597
- "compressed" in self.quant_method.__class__.__name__.lower()
598
- or "w4afp8" in self.quant_config.get_name()
609
+ (
610
+ "compressed" in self.quant_method.__class__.__name__.lower()
611
+ or "w4afp8" in self.quant_config.get_name()
612
+ )
599
613
  and (param.data[expert_id] != 1).any()
600
614
  and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
601
615
  ):
@@ -811,16 +825,17 @@ class FusedMoE(torch.nn.Module):
811
825
  elif TopKOutputChecker.format_is_triton_kernel(topk_output):
812
826
  raise NotImplementedError()
813
827
 
814
- # Matrix multiply.
815
- with use_symmetric_memory(get_tp_group()) as sm:
828
+ dispatch_output = self.dispatcher.dispatch(
829
+ hidden_states=hidden_states, topk_output=topk_output
830
+ )
816
831
 
817
- final_hidden_states = self.quant_method.apply(
818
- layer=self,
819
- x=hidden_states,
820
- topk_output=topk_output,
821
- moe_runner_config=self.moe_runner_config,
822
- )
823
- sm.tag(final_hidden_states)
832
+ # TODO: consider using symmetric memory
833
+ combine_input = self.quant_method.apply(
834
+ layer=self,
835
+ dispatch_output=dispatch_output,
836
+ )
837
+
838
+ final_hidden_states = self.dispatcher.combine(combine_input)
824
839
 
825
840
  final_hidden_states = final_hidden_states[
826
841
  ..., :origin_hidden_states_dim
@@ -923,12 +938,6 @@ class FusedMoE(torch.nn.Module):
923
938
  for shard_id in ["w1", "w2", "w3"]
924
939
  ]
925
940
 
926
- def should_fuse_routed_scaling_factor_in_topk(self):
927
- return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
928
- isinstance(self.quant_method, Fp8MoEMethod)
929
- and self.quant_method.use_cutlass_fused_experts_fp8
930
- )
931
-
932
941
 
933
942
  class FlashInferFusedMoE(FusedMoE):
934
943
  def __init__(self, *args, **kwargs):
@@ -953,9 +962,9 @@ class FlashInferFusedMoE(FusedMoE):
953
962
  # Matrix multiply.
954
963
  final_hidden_states = self.quant_method.apply_with_router_logits(
955
964
  layer=self,
956
- x=hidden_states,
957
- topk_output=topk_output,
958
- moe_runner_config=self.moe_runner_config,
965
+ dispatch_output=StandardDispatchOutput(
966
+ hidden_states=hidden_states, topk_output=topk_output
967
+ ),
959
968
  )
960
969
 
961
970
  if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
@@ -1055,16 +1064,3 @@ class FlashInferFP4MoE(FusedMoE):
1055
1064
  )[0]
1056
1065
 
1057
1066
  return result
1058
-
1059
-
1060
- def get_fused_moe_impl_class():
1061
- """Factory function to get the appropriate FusedMoE implementation class."""
1062
- if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
1063
- # Use FP4 variant when FP4 quantization is enabled
1064
- return FlashInferFP4MoE
1065
- elif should_use_flashinfer_trtllm_moe():
1066
- # Use regular FlashInfer variant for non-FP4 FlashInfer cases
1067
- return FlashInferFusedMoE
1068
- else:
1069
- # Default case
1070
- return FusedMoE
@@ -1,3 +1,4 @@
1
1
  from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
2
+ from sglang.srt.layers.moe.moe_runner.runner import MoeRunner
2
3
 
3
- __all__ = ["MoeRunnerConfig"]
4
+ __all__ = ["MoeRunnerConfig", "MoeRunner"]