sglang-0.5.2rc1-py3-none-any.whl → sglang-0.5.3-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that public registry.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py

@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -35,12 +39,15 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
+    from sglang.srt.single_batch_overlap import DownGemmOverlapArgs

 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -68,6 +75,17 @@ except ImportError:
 # Initialize logger for the module
 logger = logging.getLogger(__name__)

+CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
+)
+USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
+    "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
+)
+# TODO make it true by default when the DeepEP PR is merged
+CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false"
+)
+
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]

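The three module-level flags above are read from the environment once, at import time, so they must be exported before sglang is imported. A minimal sketch of the boolean-flag convention; parse_bool_flag below is a hypothetical stand-in for sglang.srt.utils.get_bool_env_var, whose exact accepted spellings may differ:

    import os

    def parse_bool_flag(name: str, default: str = "false") -> bool:
        # Hypothetical stand-in: treat "true"/"1" (case-insensitive) as enabled.
        return os.getenv(name, default).lower() in ("true", "1")

    # Example: opt in to the CUTLASS backend for FP4 GEMM before importing sglang.
    os.environ["SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"] = "1"
    assert parse_bool_flag("SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM")

Per the defaults in the hunk, SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE starts enabled, while SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH stays off until the DeepEP change referenced in the TODO lands.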
@@ -322,7 +340,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -338,7 +356,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):

         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -348,7 +369,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):

         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -414,28 +438,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values

         # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
-        # where the first intermediate_size rows are w1, the next are w3
-        intermediate_size = layer.w13_weight.shape[1] // 2
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
         for expert_id in range(layer.w13_weight.shape[0]):
             start = 0
             for shard_id in range(2):  # w1 and w3
                 # Dequantize using the original scale for this shard
                 dq_weight = per_tensor_dequantize(
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
                 # Requantize using the combined max scale
                 (
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     _,
                 ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

-                start += intermediate_size
+                start += intermediate_size_per_partition

         # Update the scale parameter to be per-expert instead of per-shard
         layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
@@ -457,29 +481,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
-            w1_scale=layer.w13_weight_scale,
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )

+        return self.runner.run(dispatch_output, quant_info)
+

 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
@@ -517,6 +543,39 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return ["hf_quant_config.json"]

+    @staticmethod
+    def common_group_size(cfg: dict) -> int:
+        """Return the unique group_size across the config; raise if missing/mismatched."""
+        sizes = set()
+
+        # Top-level and 'quantization' block
+        v = cfg.get("group_size")
+        if isinstance(v, int):
+            sizes.add(v)
+        q = cfg.get("quantization")
+        if isinstance(q, dict):
+            v = q.get("group_size")
+            if isinstance(v, int):
+                sizes.add(v)
+
+        # config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
+        for g in (cfg.get("config_groups") or {}).values():
+            if isinstance(g, dict):
+                v = g.get("group_size")
+                if isinstance(v, int):
+                    sizes.add(v)
+                for sub in g.values():
+                    if isinstance(sub, dict):
+                        v = sub.get("group_size")
+                        if isinstance(v, int):
+                            sizes.add(v)
+
+        if not sizes:
+            raise ValueError("No group_size found in config.")
+        if len(sizes) > 1:
+            raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
+        return next(iter(sizes))
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
         # Handle two different config formats:
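common_group_size accepts both the flat checkpoint layout and the nested config_groups layout, and fails loudly on ambiguity, which the two from_config call sites below rely on. A small usage sketch; the behavior follows directly from the method body above:

    from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config

    # Flat layout: group_size at the top level.
    assert ModelOptFp4Config.common_group_size({"group_size": 16}) == 16

    # Nested layout: group_size inside a config_groups sub-dict.
    cfg = {"config_groups": {"group_0": {"weights": {"group_size": 16}}}}
    assert ModelOptFp4Config.common_group_size(cfg) == 16

    # Conflicting values raise instead of silently picking one.
    try:
        ModelOptFp4Config.common_group_size(
            {"group_size": 16, "quantization": {"group_size": 32}}
        )
    except ValueError as e:
        print(e)  # Inconsistent group_size values: [16, 32]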
@@ -549,7 +608,7 @@ class ModelOptFp4Config(QuantizationConfig):
             else:
                 kv_cache_quant_algo = "auto"

-            group_size = config.get("group_size")
+            group_size = ModelOptFp4Config.common_group_size(config)
             exclude_modules = config.get("ignore", [])
         else:
             # Fall back to nested format (hf_quant_config.json - legacy format)
@@ -559,7 +618,7 @@ class ModelOptFp4Config(QuantizationConfig):
                 kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
                 if not kv_cache_quant_algo:
                     kv_cache_quant_algo = "auto"
-                group_size = quant_config.get("group_size")
+                group_size = ModelOptFp4Config.common_group_size(config)
                 exclude_modules = quant_config.get("exclude_modules", [])
             except (ValueError, KeyError):
                 raise ValueError(
@@ -595,16 +654,21 @@ class ModelOptFp4Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str, exclude_modules: list):
         import regex as re

+        fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
+        prefix_split = prefix.split(".")
         for pattern in exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
                 return True
-
-            # Check if the last part of the excluded pattern is contained in the last part of the prefix
-            # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
-            pattern_last_part = pattern.split(".")[-1]
-            prefix_last_part = prefix.split(".")[-1]
-            if pattern_last_part in prefix_last_part:
+            elif (
+                pattern_split[-1] in fused_patterns
+                and pattern_split[-1] in prefix_split[-1]
+            ):
+                # Check if the last part of the excluded pattern is contained in the last part of the prefix
+                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
+                assert len(prefix_split) == 5 and len(pattern_split) == 5
                 return True
         return False

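The rewritten exclusion check narrows the old "last part is a substring" fallback to the four known fused projection names and asserts the expected five-segment module path. Worked through on the fused DeepSeek-style attention weight; this is a standalone re-run of the two checks, not sglang code:

    import regex as re

    pattern = "model.layers.*.self_attn.kv_a_proj_with_mqa"
    prefix = "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa"

    regex_str = pattern.replace(".", r"\.").replace("*", r".*")
    assert re.fullmatch(regex_str, prefix) is None          # exact pattern match fails
    assert pattern.split(".")[-1] in prefix.split(".")[-1]  # fused-name check catches it
    assert len(prefix.split(".")) == len(pattern.split(".")) == 5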
@@ -788,14 +852,25 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         if enable_flashinfer_fp4_gemm:
             w = layer.weight.T
             w_scale_interleaved = layer.weight_scale_interleaved.T
-            out = fp4_gemm(
-                x_fp4,
-                w,
-                x_scale_interleaved,
-                w_scale_interleaved,
-                layer.alpha,
-                output_dtype,
-            )
+            if USE_CUTLASS_BACKEND_FOR_FP4_GEMM:
+                out = fp4_gemm(
+                    x_fp4,
+                    w,
+                    x_scale_interleaved,
+                    w_scale_interleaved,
+                    layer.alpha,
+                    output_dtype,
+                    backend="cutlass",
+                )
+            else:
+                out = fp4_gemm(
+                    x_fp4,
+                    w,
+                    x_scale_interleaved,
+                    w_scale_interleaved,
+                    layer.alpha,
+                    output_dtype,
+                )
             if bias is not None:
                 out = out + bias
             return out.view(*output_shape)
@@ -826,6 +901,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         """Access the global enable_flashinfer_cutlass_moe setting."""
         return get_moe_runner_backend().is_flashinfer_cutlass()

+    @property
+    def enable_flashinfer_cutedsl_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
+        """Access the global enable_flashinfer_cutedsl_moe setting."""
+        return get_moe_runner_backend().is_flashinfer_cutedsl()
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -937,15 +1019,17 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )

         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
+            data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w13_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w13_input_scale", w13_input_scale)

         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
+            data=torch.empty(layer.num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w2_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w2_input_scale", w2_input_scale)

     def swizzle_blockscale(self, scale: torch.Tensor):
@@ -1128,6 +1212,37 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        elif self.enable_flashinfer_cutedsl_moe:
+            # All-expert-one-input-scale is mathematically different from default per-expert-input-scale
+            # Thus we allow users to switch the flag to do thorough testing
+            if CUTEDSL_MOE_SCALAR_INPUT_SCALE:
+                w13_input_scale = (
+                    layer.w13_input_scale.max()
+                    .to(torch.float32)
+                    .repeat(layer.w13_input_scale.shape[0])
+                )
+            else:
+                w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
+                    torch.float32
+                )
+
+            w2_input_scale = layer.w2_input_scale
+
+            def _slice_scale(w):
+                assert w.shape == (layer.num_experts,)
+                assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts
+                return w[
+                    layer.moe_ep_rank
+                    * layer.num_local_experts : (layer.moe_ep_rank + 1)
+                    * layer.num_local_experts
+                ]
+
+            w13_input_scale = _slice_scale(w13_input_scale)
+            w2_input_scale = _slice_scale(w2_input_scale)
+
+            if CUTEDSL_MOE_NVFP4_DISPATCH:
+                assert torch.all(w13_input_scale == w13_input_scale[0])
+                w13_input_scale = w13_input_scale[0]
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
             w2_input_scale = layer.w2_input_scale
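Because the input-scale parameters are now allocated for all num_experts (see the create_weights hunk above), _slice_scale cuts the globally loaded per-expert vector back down to this rank's contiguous block of local experts. The arithmetic, worked through for 8 experts on an expert-parallel group of 2:

    import torch

    num_experts, moe_ep_size = 8, 2
    num_local_experts = num_experts // moe_ep_size
    w = torch.arange(num_experts, dtype=torch.float32)  # global per-expert scales

    for moe_ep_rank in range(moe_ep_size):
        local = w[moe_ep_rank * num_local_experts : (moe_ep_rank + 1) * num_local_experts]
        print(moe_ep_rank, local.tolist())
    # 0 [0.0, 1.0, 2.0, 3.0]
    # 1 [4.0, 5.0, 6.0, 7.0]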
@@ -1210,8 +1325,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 layer.w13_weight_scale,
             )

-            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
-
         else:
             # CUTLASS processing - handle w13 and w2 separately

@@ -1228,7 +1341,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

             # Both flashinfer cutlass and regular cutlass use same processing for w2
-            logger.info_once("Applied weight processing for both w13 and w2")

             # Set up CUTLASS MoE parameters
             device = layer.w13_weight.device
@@ -1245,21 +1357,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
         return self.enable_flashinfer_cutlass_moe

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."

+        moe_runner_config = self.moe_runner_config
+
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))

         if self.enable_flashinfer_cutlass_moe:
             assert (
@@ -1312,13 +1435,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)

         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

@@ -1339,4 +1461,50 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-        return output
+        return StandardCombineInput(hidden_states=output)
+
+    def apply_without_routing_weights(
+        self,
+        layer: FusedMoE,
+        x: torch.Tensor,
+        masked_m: torch.Tensor,
+        moe_runner_config: MoeRunnerConfig,
+        down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
+    ) -> torch.Tensor:
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        assert self.enable_flashinfer_cutedsl_moe, "only support flashinfer cutedsl moe"
+        assert (
+            not moe_runner_config.apply_router_weight_on_input
+        ), "apply_router_weight_on_input is not supported for Flashinfer"
+
+        from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
+            flashinfer_cutedsl_moe_masked,
+        )
+
+        out = flashinfer_cutedsl_moe_masked(
+            hidden_states=x,
+            input_global_scale=(
+                None if CUTEDSL_MOE_NVFP4_DISPATCH else layer.w13_input_scale_quant
+            ),
+            w1=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alpha=layer.g1_alphas,
+            w2=layer.w2_weight,
+            a2_global_scale=layer.w2_input_scale_quant,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alpha=layer.g2_alphas,
+            masked_m=masked_m,
+            **(
+                dict(
+                    down_sm_count=down_gemm_overlap_args.num_sms,
+                    down_signals=down_gemm_overlap_args.signal,
+                    down_start_event=down_gemm_overlap_args.start_event,
+                )
+                if down_gemm_overlap_args is not None
+                else {}
+            ),
+        )
+        return out
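The **( ... if ... else {}) tail of the flashinfer_cutedsl_moe_masked call forwards the three overlap arguments only when single-batch-overlap scheduling supplies a DownGemmOverlapArgs, so the kernel keeps its own defaults otherwise. The idiom in isolation, with made-up names:

    def kernel(x, down_sm_count=None, down_signals=None):
        # Stand-in for a kernel with optional overlap controls.
        return x, down_sm_count, down_signals

    overlap = None  # or an object carrying .num_sms / .signal
    result = kernel(
        42,
        **(
            dict(down_sm_count=overlap.num_sms, down_signals=overlap.signal)
            if overlap is not None
            else {}
        ),
    )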
sglang/srt/layers/quantization/moe_wna16.py

@@ -9,6 +9,8 @@ import torch

 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )


 def get_weight_perm(num_bits: int):
@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.register_parameter(key, param)
             set_weight_attrs(param, extra_weight_attrs)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."

         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp

-        return fused_experts(
-            x,
-            layer.w13_qweight,
-            layer.w2_qweight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_qweight,
+            w2_weight=layer.w2_qweight,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
-            w1_scale=layer.w13_scales,
+            w13_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
-            w1_zp=layer.w13_qzeros if has_zp else None,
+            w13_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
         )
+        return self.runner.run(dispatch_output, quant_info)

     @staticmethod
     def get_weight_loader(layer, weight_loader):