sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 8,
+         "num_stages": 3
+     }
+ }
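The config above is one of the fused-MoE tuning tables added in this release: each top-level key is a token-batch size M, and each entry gives the Triton tiling parameters used for that batch size. Below is a minimal sketch of how such a table can be loaded and consulted; the nearest-key selection and the helper names are illustrative assumptions, not the exact sglang lookup code.

```python
import json


def load_moe_config(path: str) -> dict[int, dict]:
    """Load a fused-MoE tuning table keyed by token-batch size M (illustrative helper)."""
    with open(path) as f:
        raw = json.load(f)
    return {int(m): cfg for m, cfg in raw.items()}


def pick_config(table: dict[int, dict], num_tokens: int) -> dict:
    """Pick the entry whose key is closest to the actual batch size (assumed heuristic)."""
    best_m = min(table, key=lambda m: abs(m - num_tokens))
    return table[best_m]


if __name__ == "__main__":
    # Hypothetical path; the real files live under fused_moe_triton/configs/triton_<version>/.
    table = load_moe_config(
        "configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json"
    )
    print(pick_config(table, num_tokens=300))  # e.g. selects the "256" entry
```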
@@ -1,3 +1,4 @@
+ # NOTE: this file will be separated into sglang/srt/layers/moe/moe_runner/triton_utils.py
  # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py

  """Fused MoE kernel."""
@@ -6,13 +7,12 @@ from __future__ import annotations

  import functools
  import os
- from typing import List, Optional
+ from typing import TYPE_CHECKING, List, Optional

  import torch
  import triton.language as tl

  from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
- from sglang.srt.layers.moe.topk import StandardTopKOutput
  from sglang.srt.utils import (
      cpu_has_amx_support,
      direct_register_custom_op,
@@ -26,6 +26,9 @@ from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_c
  from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton
  from .moe_align_block_size import moe_align_block_size

+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.topk import StandardTopKOutput
+
  _is_hip = is_hip()
  _is_cuda = is_cuda()
  _is_cpu_amx_available = cpu_has_amx_support()
@@ -43,7 +43,7 @@ def get_moe_configs(
      be picked and the associated configuration chosen to invoke the kernel.
      """
      # Supported Triton versions, should be sorted from the newest to the oldest
-     supported_triton_versions = ["3.3.1", "3.2.0", "3.1.0"]
+     supported_triton_versions = ["3.4.0", "3.3.1", "3.2.0", "3.1.0"]

      # First look up if an optimized configuration is available in the configs
      # directory
@@ -51,10 +51,14 @@

      # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
      # so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
+     config_dir = os.environ.get(
+         "SGLANG_MOE_CONFIG_DIR", os.path.dirname(os.path.realpath(__file__))
+     )
+
      triton_version = triton.__version__
      version_dir = f"triton_{triton_version.replace('.', '_')}"
      config_file_path = os.path.join(
-         os.path.dirname(os.path.realpath(__file__)),
+         config_dir,
          "configs",
          version_dir,
          json_file_name,
@@ -75,7 +79,7 @@
          if try_triton_version == triton_version:
              continue
          try_config_file_path = os.path.join(
-             os.path.dirname(os.path.realpath(__file__)),
+             config_dir,
              "configs",
              f"triton_{try_triton_version.replace('.', '_')}",
              json_file_name,
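The hunks above introduce an `SGLANG_MOE_CONFIG_DIR` override so tuned fused-MoE configs can be loaded from a directory outside the installed package. A short sketch of how the lookup path is composed under this change; the directory and file name here are made-up examples.

```python
import os

# Illustrative assumption: a user-provided directory of tuned configs.
os.environ["SGLANG_MOE_CONFIG_DIR"] = "/opt/sglang-moe-configs"

# Mirrors the path construction in get_moe_configs() after this change.
config_dir = os.environ.get("SGLANG_MOE_CONFIG_DIR", "<package fused_moe_triton dir>")
triton_version = "3.4.0"  # whatever triton.__version__ reports at runtime
json_file_name = "E=256,N=512,device_name=NVIDIA_H20.json"  # hypothetical file

config_file_path = os.path.join(
    config_dir,
    "configs",
    f"triton_{triton_version.replace('.', '_')}",
    json_file_name,
)
print(config_file_path)
# /opt/sglang-moe-configs/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json
```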
@@ -735,29 +735,32 @@ def _moe_sum_reduce_kernel(
      token_block_id = tl.program_id(0)
      dim_block_id = tl.program_id(1)

-     token_start = token_block_id * BLOCK_M
-     token_end = min((token_block_id + 1) * BLOCK_M, token_num)
+     offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)

-     dim_start = dim_block_id * BLOCK_DIM
-     dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
+     mask_token = offs_token < token_num
+     mask_dim = offs_dim < hidden_dim

-     offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
+     base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]

-     for token_index in range(token_start, token_end):
-         accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-         input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-         for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-             tmp = tl.load(
-                 input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-             )
-             accumulator += tmp
-         accumulator = accumulator * routed_scaling_factor
-         store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-         tl.store(
-             store_t_ptr,
-             accumulator.to(input_ptr.dtype.element_ty),
-             mask=offs_dim < dim_end,
+     accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
+
+     for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+         tile = tl.load(
+             base_ptrs + i * input_stride_1,
+             mask=mask_token[:, None] & mask_dim[None, :],
+             other=0.0,
          )
+         accumulator += tile.to(tl.float32)
+     accumulator *= routed_scaling_factor
+
+     # -------- Write back --------
+     store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
+     tl.store(
+         store_ptrs,
+         accumulator.to(input_ptr.dtype.element_ty),
+         mask=mask_token[:, None] & mask_dim[None, :],
+     )


  def moe_sum_reduce_triton(
@@ -772,7 +775,7 @@
      BLOCK_M = 1
      BLOCK_DIM = 2048
      NUM_STAGE = 1
-     num_warps = 8
+     num_warps = 16

      grid = (
          triton.cdiv(token_num, BLOCK_M),
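The rewritten `_moe_sum_reduce_kernel` now processes a (BLOCK_M, BLOCK_DIM) tile per program instead of looping over tokens one at a time; functionally it still sums the top-k expert outputs for each token and applies the routed scaling factor. A plain PyTorch reference of that reduction, with illustrative shapes (this is not the kernel's actual API):

```python
import torch


def moe_sum_reduce_reference(
    x: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    """Reference semantics: x has shape (num_tokens, topk, hidden_dim)."""
    # Accumulate over the top-k dimension in fp32, then scale, as the kernel does.
    out = x.float().sum(dim=1) * routed_scaling_factor
    return out.to(x.dtype)


# Illustrative shapes only.
x = torch.randn(4, 8, 2048, dtype=torch.float16)
print(moe_sum_reduce_reference(x, routed_scaling_factor=1.0).shape)  # (4, 2048)
```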
@@ -11,20 +11,21 @@ from sglang.srt.distributed import (
      get_moe_expert_parallel_world_size,
      get_moe_tensor_parallel_rank,
      get_moe_tensor_parallel_world_size,
-     get_tp_group,
      tensor_model_parallel_all_reduce,
  )
- from sglang.srt.distributed.device_communicators.pynccl_allocator import (
-     use_symmetric_memory,
- )
  from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
  from sglang.srt.layers.moe import (
      MoeRunnerConfig,
      get_moe_runner_backend,
      should_use_flashinfer_trtllm_moe,
  )
+ from sglang.srt.layers.moe.token_dispatcher.standard import (
+     StandardDispatcher,
+     StandardDispatchOutput,
+ )
  from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
  from sglang.srt.layers.quantization.base_config import (
+     FusedMoEMethodBase,
      QuantizationConfig,
      QuantizeMethodBase,
  )
@@ -68,16 +69,6 @@ if should_use_flashinfer_trtllm_moe():
  logger = logging.getLogger(__name__)


- def _is_fp4_quantization_enabled():
-     """Check if ModelOpt FP4 quantization is enabled."""
-     try:
-         # Use the same simple check that works for class selection
-         quantization = global_server_args_dict.get("quantization")
-         return quantization == "modelopt_fp4"
-     except:
-         return False
-
-
  def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
      # Guess tokens per expert assuming perfect expert distribution first.
      num_tokens_per_expert = (num_tokens * top_k) // num_experts
@@ -152,16 +143,6 @@ class FusedMoE(torch.nn.Module):
          self.expert_map_cpu = None
          self.expert_map_gpu = None

-         self.moe_runner_config = MoeRunnerConfig(
-             activation=activation,
-             apply_router_weight_on_input=apply_router_weight_on_input,
-             inplace=inplace,
-             no_combine=no_combine,
-             routed_scaling_factor=routed_scaling_factor,
-             gemm1_alpha=gemm1_alpha,
-             gemm1_clamp_limit=gemm1_clamp_limit,
-         )
-
          enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()

          if enable_flashinfer_cutlass_moe and quant_config is None:
@@ -196,13 +177,6 @@ class FusedMoE(torch.nn.Module):
          self.use_presharded_weights = use_presharded_weights

          self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-         if quant_config is None:
-             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
-                 self.use_triton_kernels
-             )
-         else:
-             self.quant_method = quant_config.get_quant_method(self, prefix)
-             assert self.quant_method is not None

          self.quant_config = quant_config
          self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
@@ -213,12 +187,40 @@ class FusedMoE(torch.nn.Module):
              and self.use_flashinfer_mxfp4_moe
          ):
              hidden_size = round_up(hidden_size, 256)
+         self.hidden_size = hidden_size
+
+         self.moe_runner_config = MoeRunnerConfig(
+             num_experts=num_experts,
+             num_local_experts=self.num_local_experts,
+             hidden_size=hidden_size,
+             intermediate_size_per_partition=self.intermediate_size_per_partition,
+             layer_id=layer_id,
+             top_k=top_k,
+             num_fused_shared_experts=num_fused_shared_experts,
+             params_dtype=params_dtype,
+             activation=activation,
+             apply_router_weight_on_input=apply_router_weight_on_input,
+             inplace=inplace,
+             no_combine=no_combine,
+             routed_scaling_factor=routed_scaling_factor,
+             gemm1_alpha=gemm1_alpha,
+             gemm1_clamp_limit=gemm1_clamp_limit,
+         )
+
+         if quant_config is None:
+             self.quant_method: FusedMoEMethodBase = UnquantizedFusedMoEMethod(
+                 self.use_triton_kernels
+             )
+         else:
+             self.quant_method: FusedMoEMethodBase = quant_config.get_quant_method(
+                 self, prefix
+             )
+         assert self.quant_method is not None
+
          self.quant_method.create_weights(
              layer=self,
              num_experts=self.num_local_experts,
              hidden_size=hidden_size,
-             # FIXME: figure out which intermediate_size to use
-             intermediate_size=self.intermediate_size_per_partition,
              intermediate_size_per_partition=self.intermediate_size_per_partition,
              params_dtype=params_dtype,
              weight_loader=(
@@ -229,6 +231,16 @@ class FusedMoE(torch.nn.Module):
              with_bias=with_bias,
          )

+         self.quant_method.create_moe_runner(self, self.moe_runner_config)
+         self.dispatcher = StandardDispatcher()
+
+         self.should_fuse_routed_scaling_factor_in_topk = isinstance(
+             self.quant_method, ModelOptNvFp4FusedMoEMethod
+         ) or (
+             isinstance(self.quant_method, Fp8MoEMethod)
+             and self.quant_method.use_cutlass_fused_experts_fp8
+         )
+
      def _load_per_tensor_weight_scale(
          self,
          shard_id: str,
@@ -522,10 +534,12 @@
          shard_id: str,
          expert_id: int,
      ) -> None:
+         # WARN: This makes the `expert_id` mean "local" and "global" in different cases
+         if not getattr(param, "_sglang_require_global_experts", False):
+             expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+             if expert_id == -1:
+                 return

-         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
-         if expert_id == -1:
-             return
          self._weight_loader_impl(
              param=param,
              loaded_weight=loaded_weight,
@@ -563,7 +577,10 @@
          )

          # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-         if should_use_flashinfer_trtllm_moe():
+         if should_use_flashinfer_trtllm_moe() and (
+             isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+             or isinstance(self.quant_method, Fp8MoEMethod)
+         ):
              shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]

          WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
@@ -594,8 +611,10 @@
          loaded_weight = loaded_weight.to(param.data.device)

          if (
-             "compressed" in self.quant_method.__class__.__name__.lower()
-             or "w4afp8" in self.quant_config.get_name()
+             (
+                 "compressed" in self.quant_method.__class__.__name__.lower()
+                 or "w4afp8" in self.quant_config.get_name()
+             )
              and (param.data[expert_id] != 1).any()
              and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
          ):
@@ -811,16 +830,17 @@
          elif TopKOutputChecker.format_is_triton_kernel(topk_output):
              raise NotImplementedError()

-         # Matrix multiply.
-         with use_symmetric_memory(get_tp_group()) as sm:
+         dispatch_output = self.dispatcher.dispatch(
+             hidden_states=hidden_states, topk_output=topk_output
+         )

-             final_hidden_states = self.quant_method.apply(
-                 layer=self,
-                 x=hidden_states,
-                 topk_output=topk_output,
-                 moe_runner_config=self.moe_runner_config,
-             )
-             sm.tag(final_hidden_states)
+         # TODO: consider using symmetric memory
+         combine_input = self.quant_method.apply(
+             layer=self,
+             dispatch_output=dispatch_output,
+         )
+
+         final_hidden_states = self.dispatcher.combine(combine_input)

          final_hidden_states = final_hidden_states[
              ..., :origin_hidden_states_dim
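The hunk above replaces the direct `quant_method.apply(x, topk_output, moe_runner_config)` call with an explicit dispatch → apply → combine pipeline built around `StandardDispatcher`. A minimal sketch of that control flow follows; the stub dispatcher and quant method below only mirror the call pattern visible in this diff and are assumptions, not the real sglang classes.

```python
from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class DispatchOutput:
    # Stand-in for sglang's dispatch output type; fields mirror the diff's kwargs.
    hidden_states: torch.Tensor
    topk_output: Any


class PassthroughDispatcher:
    """Illustrative dispatcher: dispatch wraps the inputs, combine returns the result."""

    def dispatch(self, hidden_states, topk_output) -> DispatchOutput:
        return DispatchOutput(hidden_states=hidden_states, topk_output=topk_output)

    def combine(self, combine_input: torch.Tensor) -> torch.Tensor:
        return combine_input


class DummyQuantMethod:
    def apply(self, layer, dispatch_output: DispatchOutput) -> torch.Tensor:
        # Real quant methods run the fused MoE kernels here; this stub just echoes.
        return dispatch_output.hidden_states


dispatcher, quant_method = PassthroughDispatcher(), DummyQuantMethod()
hidden_states = torch.randn(4, 2048)
dispatch_output = dispatcher.dispatch(hidden_states=hidden_states, topk_output=None)
combine_input = quant_method.apply(layer=None, dispatch_output=dispatch_output)
final_hidden_states = dispatcher.combine(combine_input)
```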
@@ -923,12 +943,6 @@
              for shard_id in ["w1", "w2", "w3"]
          ]

-     def should_fuse_routed_scaling_factor_in_topk(self):
-         return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
-             isinstance(self.quant_method, Fp8MoEMethod)
-             and self.quant_method.use_cutlass_fused_experts_fp8
-         )
-

  class FlashInferFusedMoE(FusedMoE):
      def __init__(self, *args, **kwargs):
@@ -953,9 +967,9 @@ class FlashInferFusedMoE(FusedMoE):
          # Matrix multiply.
          final_hidden_states = self.quant_method.apply_with_router_logits(
              layer=self,
-             x=hidden_states,
-             topk_output=topk_output,
-             moe_runner_config=self.moe_runner_config,
+             dispatch_output=StandardDispatchOutput(
+                 hidden_states=hidden_states, topk_output=topk_output
+             ),
          )

          if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
@@ -1055,16 +1069,3 @@ class FlashInferFP4MoE(FusedMoE):
          )[0]

          return result
-
-
- def get_fused_moe_impl_class():
-     """Factory function to get the appropriate FusedMoE implementation class."""
-     if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
-         # Use FP4 variant when FP4 quantization is enabled
-         return FlashInferFP4MoE
-     elif should_use_flashinfer_trtllm_moe():
-         # Use regular FlashInfer variant for non-FP4 FlashInfer cases
-         return FlashInferFusedMoE
-     else:
-         # Default case
-         return FusedMoE
@@ -1,3 +1,4 @@
  from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+ from sglang.srt.layers.moe.moe_runner.runner import MoeRunner

- __all__ = ["MoeRunnerConfig"]
+ __all__ = ["MoeRunnerConfig", "MoeRunner"]
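The `moe_runner` package now re-exports `MoeRunner` alongside `MoeRunnerConfig`. A hedged sketch of constructing the config with the keyword arguments visible in the `FusedMoE.__init__` hunk above; the values are illustrative and the constructor signature is inferred from this diff rather than verified against the package.

```python
import torch

# MoeRunner is imported only to show the new re-export; quant methods build the
# runner themselves via create_moe_runner() in this release.
from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig

# Illustrative values; field names mirror the FusedMoE.__init__ call in this diff.
config = MoeRunnerConfig(
    num_experts=64,
    num_local_experts=8,
    hidden_size=4096,
    intermediate_size_per_partition=1408,
    layer_id=0,
    top_k=8,
    num_fused_shared_experts=0,
    params_dtype=torch.bfloat16,
    activation="silu",
    apply_router_weight_on_input=False,
    inplace=True,
    no_combine=False,
    routed_scaling_factor=1.0,
    gemm1_alpha=None,
    gemm1_clamp_limit=None,
)
```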