sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -35,12 +39,15 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
+    from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
 
 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -68,6 +75,17 @@ except ImportError:
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
 
+CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
+)
+USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
+    "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
+)
+# TODO make it true by default when the DeepEP PR is merged
+CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false"
+)
+
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
 
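The three flags above are read once at import time through get_bool_env_var. As a rough illustration of the intended behaviour, here is a minimal sketch of an equivalent check, assuming the helper treats "true"/"1" (case-insensitive) as enabled and takes the default as a string; _read_bool_flag is a hypothetical stand-in, not code from the wheel:

import os


def _read_bool_flag(name: str, default: str = "false") -> bool:
    # Stand-in for sglang.srt.utils.get_bool_env_var (assumed behaviour, not verified here).
    return os.getenv(name, default).strip().lower() in ("true", "1")


# Example: opt into the CUTLASS backend for the FP4 GEMM path and leave the
# NVFP4 dispatch flag at its current default of "false".
os.environ["SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"] = "1"

assert _read_bool_flag("SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM") is True
assert _read_bool_flag("SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false") is False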
@@ -322,7 +340,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -338,7 +356,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -348,7 +369,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -414,28 +438,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values
 
         # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
-        # where the first intermediate_size rows are w1, the next are w3
-        intermediate_size = layer.w13_weight.shape[1] // 2
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
         for expert_id in range(layer.w13_weight.shape[0]):
             start = 0
             for shard_id in range(2):  # w1 and w3
                 # Dequantize using the original scale for this shard
                 dq_weight = per_tensor_dequantize(
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
                 # Requantize using the combined max scale
                 (
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     _,
                 ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
 
-                start += intermediate_size
+                start += intermediate_size_per_partition
 
         # Update the scale parameter to be per-expert instead of per-shard
         layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
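The loop above folds the two per-shard scales of each expert into one per-expert scale: each shard is dequantized with its original scale and requantized with the shared maximum, so w1 and w3 can later be decoded with a single w13 scale. A small numerical sketch of the same idea using plain torch ops (a rough stand-in for per_tensor_dequantize and scaled_fp8_quant; assumes a PyTorch build with float8_e4m3fn support):

import torch

fp8 = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8).max

# One expert, two shards (w1 and w3) quantized with different per-shard scales.
w1_q = torch.randn(4, 8).to(fp8)
w3_q = torch.randn(4, 8).to(fp8)
w1_scale, w3_scale = torch.tensor(0.02), torch.tensor(0.05)

# Combined per-expert scale, analogous to max_w13_scales above.
max_scale = torch.maximum(w1_scale, w3_scale)

requantized = []
for q, scale in ((w1_q, w1_scale), (w3_q, w3_scale)):
    dq = q.to(torch.float32) * scale                        # dequantize with the shard scale
    rq = (dq / max_scale).clamp(-fp8_max, fp8_max).to(fp8)  # requantize with the shared scale
    requantized.append(rq)

# Both shards now decode as tensor.float() * max_scale, i.e. one scale per expert.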
@@ -457,29 +481,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
-            w1_scale=layer.w13_weight_scale,
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
+            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
         )
 
+        return self.runner.run(dispatch_output, quant_info)
+
 
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
@@ -628,16 +654,21 @@ class ModelOptFp4Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str, exclude_modules: list):
         import regex as re
 
+        fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
+        prefix_split = prefix.split(".")
         for pattern in exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
                 return True
-
-            # Check if the last part of the excluded pattern is contained in the last part of the prefix
-            # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
-            pattern_last_part = pattern.split(".")[-1]
-            prefix_last_part = prefix.split(".")[-1]
-            if pattern_last_part in prefix_last_part:
+            elif (
+                pattern_split[-1] in fused_patterns
+                and pattern_split[-1] in prefix_split[-1]
+            ):
+                # Check if the last part of the excluded pattern is contained in the last part of the prefix
+                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
+                assert len(prefix_split) == 5 and len(pattern_split) == 5
                 return True
         return False
 
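The new branch restricts the substring fallback to the known fused attention sub-projections, so a pattern like model.layers.*.self_attn.kv_a_proj_with_mqa still excludes a fused weight such as fused_qkv_a_proj_with_mqa, while unrelated name collisions no longer match. A standalone sketch of the same matching logic, with hypothetical prefix and pattern values and the stdlib re module in place of regex (the length-5 assert from the real code is omitted to keep the sketch self-contained):

import re

FUSED_PATTERNS = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]


def is_excluded(prefix: str, exclude_modules: list) -> bool:
    prefix_split = prefix.split(".")
    for pattern in exclude_modules:
        regex_str = pattern.replace(".", r"\.").replace("*", r".*")
        pattern_split = pattern.split(".")
        if re.fullmatch(regex_str, prefix):
            return True  # exact wildcard match
        if (
            pattern_split[-1] in FUSED_PATTERNS
            and pattern_split[-1] in prefix_split[-1]
        ):
            return True  # fused-module fallback
    return False


# Plain module name: matched by the wildcard regex.
assert is_excluded(
    "model.layers.0.self_attn.q_a_proj", ["model.layers.*.self_attn.q_a_proj"]
)
# Fused weight: the regex does not match, but the fallback catches it.
assert is_excluded(
    "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa",
    ["model.layers.*.self_attn.kv_a_proj_with_mqa"],
)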
@@ -821,14 +852,25 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         if enable_flashinfer_fp4_gemm:
             w = layer.weight.T
             w_scale_interleaved = layer.weight_scale_interleaved.T
-            out = fp4_gemm(
-                x_fp4,
-                w,
-                x_scale_interleaved,
-                w_scale_interleaved,
-                layer.alpha,
-                output_dtype,
-            )
+            if USE_CUTLASS_BACKEND_FOR_FP4_GEMM:
+                out = fp4_gemm(
+                    x_fp4,
+                    w,
+                    x_scale_interleaved,
+                    w_scale_interleaved,
+                    layer.alpha,
+                    output_dtype,
+                    backend="cutlass",
+                )
+            else:
+                out = fp4_gemm(
+                    x_fp4,
+                    w,
+                    x_scale_interleaved,
+                    w_scale_interleaved,
+                    layer.alpha,
+                    output_dtype,
+                )
             if bias is not None:
                 out = out + bias
             return out.view(*output_shape)
@@ -859,6 +901,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         """Access the global enable_flashinfer_cutlass_moe setting."""
         return get_moe_runner_backend().is_flashinfer_cutlass()
 
+    @property
+    def enable_flashinfer_cutedsl_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
+        """Access the global enable_flashinfer_cutedsl_moe setting."""
+        return get_moe_runner_backend().is_flashinfer_cutedsl()
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -970,15 +1019,17 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
+            data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w13_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
+            data=torch.empty(layer.num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w2_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
     def swizzle_blockscale(self, scale: torch.Tensor):
@@ -1161,6 +1212,37 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        elif self.enable_flashinfer_cutedsl_moe:
+            # All-expert-one-input-scale is mathematically different from default per-expert-input-scale
+            # Thus we allow users to switch the flag to do thorough testing
+            if CUTEDSL_MOE_SCALAR_INPUT_SCALE:
+                w13_input_scale = (
+                    layer.w13_input_scale.max()
+                    .to(torch.float32)
+                    .repeat(layer.w13_input_scale.shape[0])
+                )
+            else:
+                w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
+                    torch.float32
+                )
+
+            w2_input_scale = layer.w2_input_scale
+
+            def _slice_scale(w):
+                assert w.shape == (layer.num_experts,)
+                assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts
+                return w[
+                    layer.moe_ep_rank
+                    * layer.num_local_experts : (layer.moe_ep_rank + 1)
+                    * layer.num_local_experts
+                ]
+
+            w13_input_scale = _slice_scale(w13_input_scale)
+            w2_input_scale = _slice_scale(w2_input_scale)
+
+            if CUTEDSL_MOE_NVFP4_DISPATCH:
+                assert torch.all(w13_input_scale == w13_input_scale[0])
+                w13_input_scale = w13_input_scale[0]
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
             w2_input_scale = layer.w2_input_scale
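Since the input scales are now allocated for all num_experts rather than only the local ones, the CuteDSL branch slices out the contiguous block owned by the current EP rank before using it. A small sketch of that slicing arithmetic with hypothetical sizes (8 experts over 4 EP ranks, 2 local experts each):

import torch

num_experts = 8
moe_ep_size = 4
num_local_experts = num_experts // moe_ep_size  # 2 experts per EP rank

# One per-expert scale for every global expert.
w13_input_scale = torch.arange(num_experts, dtype=torch.float32)


def slice_scale(w: torch.Tensor, ep_rank: int) -> torch.Tensor:
    assert w.shape == (num_experts,)
    return w[ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]


assert torch.equal(slice_scale(w13_input_scale, 0), torch.tensor([0.0, 1.0]))
assert torch.equal(slice_scale(w13_input_scale, 3), torch.tensor([6.0, 7.0]))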
@@ -1243,8 +1325,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 layer.w13_weight_scale,
             )
 
-            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
-
         else:
             # CUTLASS processing - handle w13 and w2 separately
 
@@ -1261,7 +1341,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
         # Both flashinfer cutlass and regular cutlass use same processing for w2
-        logger.info_once("Applied weight processing for both w13 and w2")
 
         # Set up CUTLASS MoE parameters
         device = layer.w13_weight.device
@@ -1278,21 +1357,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
         return self.enable_flashinfer_cutlass_moe
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
+        moe_runner_config = self.moe_runner_config
+
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))
 
         if self.enable_flashinfer_cutlass_moe:
             assert (
@@ -1345,13 +1435,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
            )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
@@ -1372,4 +1461,50 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-        return output
+        return StandardCombineInput(hidden_states=output)
+
+    def apply_without_routing_weights(
+        self,
+        layer: FusedMoE,
+        x: torch.Tensor,
+        masked_m: torch.Tensor,
+        moe_runner_config: MoeRunnerConfig,
+        down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
+    ) -> torch.Tensor:
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        assert self.enable_flashinfer_cutedsl_moe, "only support flashinfer cutedsl moe"
+        assert (
+            not moe_runner_config.apply_router_weight_on_input
+        ), "apply_router_weight_on_input is not supported for Flashinfer"
+
+        from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
+            flashinfer_cutedsl_moe_masked,
+        )
+
+        out = flashinfer_cutedsl_moe_masked(
+            hidden_states=x,
+            input_global_scale=(
+                None if CUTEDSL_MOE_NVFP4_DISPATCH else layer.w13_input_scale_quant
+            ),
+            w1=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alpha=layer.g1_alphas,
+            w2=layer.w2_weight,
+            a2_global_scale=layer.w2_input_scale_quant,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alpha=layer.g2_alphas,
+            masked_m=masked_m,
+            **(
+                dict(
+                    down_sm_count=down_gemm_overlap_args.num_sms,
+                    down_signals=down_gemm_overlap_args.signal,
+                    down_start_event=down_gemm_overlap_args.start_event,
+                )
+                if down_gemm_overlap_args is not None
+                else {}
+            ),
+        )
+        return out
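The flashinfer_cutedsl_moe_masked call forwards the overlap-related keyword arguments only when down_gemm_overlap_args is provided, by splatting either a populated dict or an empty one. The same idiom in isolation, with hypothetical names that mirror the shape of the call above:

from dataclasses import dataclass
from typing import Optional


@dataclass
class OverlapArgs:  # hypothetical stand-in for DownGemmOverlapArgs
    num_sms: int


def run_kernel(required: int, down_sm_count: Optional[int] = None) -> dict:
    # Echo what was actually passed, for illustration only.
    return {"required": required, "down_sm_count": down_sm_count}


def call(required: int, overlap: Optional[OverlapArgs]) -> dict:
    return run_kernel(
        required=required,
        **(dict(down_sm_count=overlap.num_sms) if overlap is not None else {}),
    )


assert call(1, None) == {"required": 1, "down_sm_count": None}
assert call(1, OverlapArgs(num_sms=8)) == {"required": 1, "down_sm_count": 8}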
@@ -9,6 +9,8 @@ import torch
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 
 def get_weight_perm(num_bits: int):
@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.register_parameter(key, param)
             set_weight_attrs(param, extra_weight_attrs)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp
 
-        return fused_experts(
-            x,
-            layer.w13_qweight,
-            layer.w2_qweight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_qweight,
+            w2_weight=layer.w2_qweight,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
-            w1_scale=layer.w13_scales,
+            w13_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
-            w1_zp=layer.w13_qzeros if has_zp else None,
+            w13_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
         )
+        return self.runner.run(dispatch_output, quant_info)
 
     @staticmethod
     def get_weight_loader(layer, weight_loader):
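Both files shown here move their quantized MoE methods onto the same two-step contract: create_moe_runner stores the MoeRunnerConfig and, for Triton-backed methods, builds a MoeRunner; apply then takes a StandardDispatchOutput and returns a CombineInput by pairing the dispatch output with a backend-specific quant-info object. A condensed sketch of that contract (the class name and weight attributes are illustrative, not code from the wheel; the imports assume the sglang 0.5.3 module layout shown in the hunks above):

import torch

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher import CombineInput, StandardDispatchOutput


class ExampleFp8MoEMethod:
    """Sketch of a quantization method following the refactored runner contract."""

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # Keep the config and build the Triton runner once, up front.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(
        self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput
    ) -> CombineInput:
        # Describe the quantized weights; the runner handles the fused kernel call.
        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            use_fp8_w8a8=True,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
        )
        return self.runner.run(dispatch_output, quant_info)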