sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
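The version bump itself lives in sglang/version.py (+1 -1). A minimal post-upgrade sanity check, assuming the package continues to export __version__ at the top level from sglang/version.py, looks like this:

    # Shell step (one line): pip install --upgrade "sglang==0.5.4.post1"
    # Then confirm the interpreter picks up the new wheel:
    import sglang

    # version.py changed +1 -1 in this release, so a fresh install should
    # report the new release string.
    assert sglang.__version__ == "0.5.4.post1", sglang.__version__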
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -0,0 +1,339 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import Callable, Optional
+
+import torch
+from compressed_tensors.quantization import ActivationOrdering
+
+# yapf conflicts with isort for this block
+# yapf: disable
+from sglang.srt.layers.parameter import (
+    BasevLLMParameter,
+    ChannelQuantScaleParameter,
+    GroupQuantScaleParameter,
+    PackedColumnParameter,
+    PackedvLLMParameter,
+    RowvLLMParameter,
+    permute_param_layout_,
+)
+from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme,
+)
+from sglang.srt.layers.quantization.marlin_utils import (
+    MarlinLinearLayerConfig,
+    apply_gptq_marlin_linear,
+    check_marlin_supports_shape,
+    marlin_is_k_full,
+    marlin_make_empty_g_idx,
+    marlin_make_workspace,
+    marlin_permute_scales,
+    marlin_repeat_scales_on_all_ranks,
+    marlin_sort_g_idx,
+    marlin_zero_points,
+)
+from sglang.srt.layers.quantization.utils import (
+    get_scalar_types,
+    replace_parameter,
+    unpack_cols,
+)
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_repack
+
+
+ScalarType, scalar_types = get_scalar_types()
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["CompressedTensorsWNA16"]
+WNA16_SUPPORTED_TYPES_MAP = {
+    4: scalar_types.uint4b8,
+    8: scalar_types.uint8b128
+}
+WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
+WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
+
+
+class CompressedTensorsWNA16(CompressedTensorsScheme):
+    _kernel_backends_being_used: set[str] = set()
+
+    def __init__(self,
+                 strategy: str,
+                 num_bits: int,
+                 group_size: Optional[int] = None,
+                 symmetric: Optional[bool] = True,
+                 actorder: Optional[ActivationOrdering] = None):
+
+        self.pack_factor = 32 // num_bits
+        self.strategy = strategy
+        self.symmetric = symmetric
+        self.group_size = -1 if group_size is None else group_size
+        self.has_g_idx = actorder == ActivationOrdering.GROUP
+
+        if self.group_size == -1 and self.strategy != "channel":
+            raise ValueError("Marlin kernels require group quantization or "
+                             "channelwise quantization, but found no group "
+                             "size and strategy is not channelwise.")
+
+        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
+            raise ValueError(
+                f"Unsupported num_bits = {num_bits}. "
+                f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
+
+        self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
+                           if not self.symmetric else
+                           WNA16_SUPPORTED_TYPES_MAP[num_bits])
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # ampere and up
+        return 80
+
+    def create_weights(self, layer: torch.nn.Module, output_size: int,
+                       input_size: int, output_partition_sizes: list[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        output_size_per_partition = sum(output_partition_sizes)
+
+        self.kernel_config = MarlinLinearLayerConfig(
+            full_weight_shape=(input_size, output_size),
+            partition_weight_shape=(
+                input_size_per_partition,
+                output_size_per_partition,
+            ),
+            weight_type=self.quant_type,
+            act_type=params_dtype,
+            group_size=self.group_size,
+            zero_points=not self.symmetric,
+            has_g_idx=self.has_g_idx
+        )
+
+        # If group_size is -1, we are in channelwise case.
+        group_size = self.group_size if self.group_size != -1 else input_size
+        row_parallel = (input_size != input_size_per_partition)
+        partition_scales = not marlin_repeat_scales_on_all_ranks(
+            self.has_g_idx, self.group_size, row_parallel)
+
+        scales_and_zp_size = input_size // group_size
+
+        if partition_scales:
+            assert input_size_per_partition % group_size == 0
+            scales_and_zp_size = input_size_per_partition // group_size
+
+        weight = PackedvLLMParameter(input_dim=1,
+                                     output_dim=0,
+                                     weight_loader=weight_loader,
+                                     packed_factor=self.pack_factor,
+                                     packed_dim=1,
+                                     data=torch.empty(
+                                         output_size_per_partition,
+                                         input_size_per_partition //
+                                         self.pack_factor,
+                                         dtype=torch.int32,
+                                     ))
+
+        weight_scale_args = {
+            "weight_loader":
+            weight_loader,
+            "data":
+            torch.empty(
+                output_size_per_partition,
+                scales_and_zp_size,
+                dtype=params_dtype,
+            )
+        }
+
+        zeros_args = {
+            "weight_loader":
+            weight_loader,
+            "data":
+            torch.zeros(
+                output_size_per_partition // self.pack_factor,
+                scales_and_zp_size,
+                dtype=torch.int32,
+            )
+        }
+
+        if not partition_scales:
+            weight_scale = ChannelQuantScaleParameter(output_dim=0,
+                                                      **weight_scale_args)
+
+            if not self.symmetric:
+                qzeros = PackedColumnParameter(output_dim=0,
+                                               packed_dim=0,
+                                               packed_factor=self.pack_factor,
+                                               **zeros_args)
+        else:
+            weight_scale = GroupQuantScaleParameter(output_dim=0,
+                                                    input_dim=1,
+                                                    **weight_scale_args)
+            if not self.symmetric:
+                qzeros = PackedvLLMParameter(input_dim=1,
+                                             output_dim=0,
+                                             packed_dim=0,
+                                             packed_factor=self.pack_factor,
+                                             **zeros_args)
+
+        # A 2D array defining the original shape of the weights
+        # before packing
+        weight_shape = BasevLLMParameter(data=torch.empty(2,
+                                                          dtype=torch.int64),
+                                         weight_loader=weight_loader)
+
+        layer.register_parameter("weight_packed", weight)
+        layer.register_parameter("weight_scale", weight_scale)
+        layer.register_parameter("weight_shape", weight_shape)
+
+        if not self.symmetric:
+            layer.register_parameter("weight_zero_point", qzeros)
+
+        # group index (for activation reordering)
+        if self.has_g_idx:
+            weight_g_idx = RowvLLMParameter(data=torch.empty(
+                input_size_per_partition,
+                dtype=torch.int32,
+            ),
+                                            input_dim=0,
+                                            weight_loader=weight_loader)
+            layer.register_parameter("weight_g_idx", weight_g_idx)
+
+    # Checkpoints are serialized in compressed-tensors format, which is
+    # different from the format the kernel may want. Handle repacking here.
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Default names since marlin requires empty parameters for these,
+        # TODO: remove this requirement from marlin (allow optional tensors)
+        self.w_q_name = "weight_packed"
+        self.w_s_name = "weight_scale"
+        self.w_zp_name = "weight_zero_point"
+        self.w_gidx_name = "weight_g_idx"
+
+        device = getattr(layer, self.w_q_name).device
+        c = self.kernel_config
+
+        check_marlin_supports_shape(
+            c.partition_weight_shape[1],  # out_features
+            c.partition_weight_shape[0],  # in_features
+            c.full_weight_shape[0],  # in_features
+            c.group_size,
+        )
+
+        row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
+        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
+
+        # Allocate marlin workspace.
+        self.workspace = marlin_make_workspace(device)
+
+        def _transform_param(
+            layer: torch.nn.Module, name: Optional[str], fn: Callable
+        ) -> None:
+            if name is not None and getattr(layer, name, None) is not None:
+
+                old_param = getattr(layer, name)
+                new_param = fn(old_param)
+                # replace the parameter with torch.nn.Parameter for TorchDynamo
+                # compatibility
+                replace_parameter(
+                    layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
+                )
+
+        def transform_w_q(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+            x.data = gptq_marlin_repack(
+                x.data.contiguous(),
+                perm=layer.g_idx_sort_indices,
+                size_k=c.partition_weight_shape[0],
+                size_n=c.partition_weight_shape[1],
+                num_bits=c.weight_type.size_bits,
+            )
+            return x
+
+        def transform_w_s(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1)
+            x.data = marlin_permute_scales(
+                x.data.contiguous(),
+                size_k=c.partition_weight_shape[0],
+                size_n=c.partition_weight_shape[1],
+                group_size=c.group_size,
+            )
+            return x
+
+        if c.has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
+                getattr(layer, self.w_gidx_name)
+            )
+            _transform_param(layer, self.w_gidx_name, lambda _: g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+        else:
+            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
+        if c.zero_points:
+            grouped_k = (
+                c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
+            )
+            _transform_param(
+                layer,
+                self.w_zp_name,
+                lambda x: marlin_zero_points(
+                    unpack_cols(
+                        x.t(),
+                        c.weight_type.size_bits,
+                        grouped_k,
+                        c.partition_weight_shape[1],
+                    ),
+                    size_k=grouped_k,
+                    size_n=c.partition_weight_shape[1],
+                    num_bits=c.weight_type.size_bits,
+                ),
+            )
+        else:
+            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
+        _transform_param(layer, self.w_q_name, transform_w_q)
+        _transform_param(layer, self.w_s_name, transform_w_s)
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        c = self.kernel_config
+
+        def _get_weight_params(
+            layer: torch.nn.Module,
+        ) -> tuple[
+            torch.Tensor,  # w_q
+            torch.Tensor,  # w_s
+            Optional[torch.Tensor],  # w_zp,
+            Optional[torch.Tensor],  # w_gidx
+        ]:
+            return (
+                getattr(layer, self.w_q_name),
+                getattr(layer, self.w_s_name),
+                getattr(layer, self.w_zp_name or "", None),
+                getattr(layer, self.w_gidx_name or "", None),
+            )
+
+        w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)
+
+        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
+        # None for marlin
+        return apply_gptq_marlin_linear(
+            input=x,
+            weight=w_q,
+            weight_scale=w_s,
+            weight_zp=w_zp,  # type: ignore
+            g_idx=w_gidx,  # type: ignore
+            g_idx_sort_indices=layer.g_idx_sort_indices,
+            workspace=self.workspace,
+            wtype=c.weight_type,
+            input_size_per_partition=c.partition_weight_shape[0],
+            output_size_per_partition=c.partition_weight_shape[1],
+            is_k_full=self.is_k_full,
+            bias=bias,
+        )
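The 339 added lines above are the new CompressedTensorsWNA16 Marlin scheme, matching sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py (+339 -0) in the file list. Its storage layout follows directly from pack_factor = 32 // num_bits: each int32 word holds pack_factor quantized values along the input dimension. A small illustration of that arithmetic (the partition sizes below are hypothetical, not from the diff):

    import torch

    # WNA16_SUPPORTED_BITS allows 4 or 8; with 4-bit weights one int32 word
    # packs 32 // 4 = 8 values, so the input dimension shrinks by 8x.
    num_bits = 4
    pack_factor = 32 // num_bits
    output_size_per_partition = 4096   # hypothetical
    input_size_per_partition = 11008   # hypothetical

    weight_packed = torch.empty(
        output_size_per_partition,
        input_size_per_partition // pack_factor,  # packed input dim
        dtype=torch.int32,
    )
    print(weight_packed.shape)  # torch.Size([4096, 1376])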
sglang/srt/layers/quantization/fp8.py
@@ -31,8 +31,9 @@ except ImportError:
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
-from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
+from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
@@ -525,12 +526,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
         self.cutlass_fp8_supported = cutlass_fp8_supported()
-        self.use_cutlass_fused_experts_fp8 = (
-            get_bool_env_var("SGLANG_CUTLASS_MOE")
-            and self.cutlass_fp8_supported
-            and self.block_quant
-            and (is_sm100_supported() or is_sm90_supported())
-        )
 
     def create_weights(
         self,
@@ -638,58 +633,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
-            if self.use_cutlass_fused_experts_fp8:
-                self.ab_strides1 = torch.full(
-                    (num_experts,),
-                    hidden_size,
-                    device=w13_weight.device,
-                    dtype=torch.int64,
-                )
-                self.c_strides1 = torch.full(
-                    (num_experts,),
-                    2 * intermediate_size_per_partition,
-                    device=w13_weight.device,
-                    dtype=torch.int64,
-                )
-                self.ab_strides2 = torch.full(
-                    (num_experts,),
-                    intermediate_size_per_partition,
-                    device=w2_weight.device,
-                    dtype=torch.int64,
-                )
-                self.c_strides2 = torch.full(
-                    (num_experts,),
-                    hidden_size,
-                    device=w2_weight.device,
-                    dtype=torch.int64,
-                )
-                self.workspace = torch.empty(
-                    90000, device=w13_weight.device, dtype=torch.uint8
-                )
-                self.a_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.b_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.out_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.a_scales_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.b_scales_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.expert_offsets = torch.empty(
-                    num_experts + 1, device=w13_weight.device, dtype=torch.int32
-                )
-                self.problem_sizes1 = torch.empty(
-                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
-                )
-                self.problem_sizes2 = torch.empty(
-                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
-                )
+            if self._should_use_cutlass_fused_experts():
+                self._ensure_cutlass_buffers_initialized(layer)
 
         else:
             # Allocate 2 scales for w1 and w3 respectively.
@@ -1006,8 +951,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
+
+        from sglang.srt.layers import deep_gemm_wrapper
+        from sglang.srt.layers.moe.utils import (
+            get_moe_a2a_backend,
+            get_moe_runner_backend,
+        )
+
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        moe_runner_backend = get_moe_runner_backend()
+
+        if moe_runner_backend.is_auto():
+            if (
+                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                and get_moe_a2a_backend().is_deepep()
+            ):
+                moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
+            else:
+                moe_runner_backend = MoeRunnerBackend.TRITON
+        if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton():
+            self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
+        else:
+            # TODO(cwan): refactor other backends
+            pass
 
     def apply(
         self,
@@ -1018,13 +984,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
 
         x = dispatch_output.hidden_states
-        topk_output = dispatch_output.topk_output
         moe_runner_config = self.moe_runner_config
 
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids, _ = dispatch_output.topk_output
             x, topk_weights = apply_topk_weights_cpu(
                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
@@ -1051,17 +1016,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         ret = self.maybe_apply_hip_fused_experts(
             layer,
             x,
-            topk_output,
+            dispatch_output.topk_output,
             moe_runner_config.activation,
             moe_runner_config.no_combine,
         )
         if ret is not None:
             return StandardCombineInput(hidden_states=ret)
 
-        if self.use_cutlass_fused_experts_fp8:
+        if self._should_use_cutlass_fused_experts():
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 
-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids, _ = dispatch_output.topk_output
             output = cutlass_fused_experts_fp8(
                 x,
                 layer.w13_weight.transpose(1, 2),
@@ -1087,24 +1052,130 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             return StandardCombineInput(hidden_states=output)
 
-        quant_info = TritonMoeQuantInfo(
-            w13_weight=layer.w13_weight,
-            w2_weight=layer.w2_weight,
-            use_fp8_w8a8=True,
-            w13_scale=(
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            ),
-            w2_scale=(
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            ),
-            a13_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-        )
+        if self.runner.runner_backend.is_deep_gemm():
+
+            w13_weight = layer.w13_weight
+            w2_weight = layer.w2_weight
+
+            if self.block_quant:
+                block_shape = self.quant_config.weight_block_size
+                w13_scale = layer.w13_weight_scale_inv
+                w2_scale = layer.w2_weight_scale_inv
+            else:
+                # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+                scale_block_size = 128
+                block_shape = [scale_block_size, scale_block_size]
+                w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1
+                w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1
+                w13_scale = (
+                    layer.w13_weight_scale.unsqueeze(1)
+                    .repeat_interleave(w13_scale_n, dim=1)
+                    .unsqueeze(2)
+                    .repeat_interleave(w13_scale_k, dim=2)
+                )
+                w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1
+                w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1
+                w2_scale = (
+                    layer.w2_weight_scale.unsqueeze(1)
+                    .repeat_interleave(w2_scale_n, dim=1)
+                    .unsqueeze(2)
+                    .repeat_interleave(w2_scale_k, dim=2)
+                )
+            quant_info = DeepGemmMoeQuantInfo(
+                w13_weight=w13_weight,
+                w2_weight=w2_weight,
+                use_fp8=True,
+                w13_scale=w13_scale,
+                w2_scale=w2_scale,
+                block_shape=block_shape,
+            )
+        elif self.runner.runner_backend.is_triton():
+            quant_info = TritonMoeQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                use_fp8_w8a8=True,
+                w13_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a13_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )
+        else:
+            raise NotImplementedError(
+                "Unsupported runner backend: %s" % self.runner.runner_backend
+            )
+
         return self.runner.run(dispatch_output, quant_info)
 
+    def _should_use_cutlass_fused_experts(self) -> bool:
+        """Decide whether to use Cutlass FP8 fused-experts path based on moe runner backend,
+        with env var override via `SGLANG_CUTLASS_MOE`.
+        """
+        backend = get_moe_runner_backend()
+        env_force = get_bool_env_var("SGLANG_CUTLASS_MOE")
+        # TODO: remove env var in the future, it should be handled by moe runner backend
+        if env_force:
+            return True
+        return (
+            backend.is_flashinfer_cutlass()
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and (is_sm100_supported() or is_sm90_supported())
+        )
+
+    def _ensure_cutlass_buffers_initialized(self, layer: Module) -> None:
+        if getattr(self, "_cutlass_buffers_ready", False):
+            return
+
+        device = layer.w13_weight.device
+        num_experts = layer.w13_weight.shape[0]
+        hidden_size = layer.w2_weight.shape[1]
+        intermediate_size_per_partition = layer.intermediate_size_per_partition
+
+        self.ab_strides1 = torch.full(
+            (num_experts,), hidden_size, device=device, dtype=torch.int64
+        )
+        self.c_strides1 = torch.full(
+            (num_experts,),
+            2 * intermediate_size_per_partition,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.ab_strides2 = torch.full(
+            (num_experts,),
+            intermediate_size_per_partition,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.c_strides2 = torch.full(
+            (num_experts,), hidden_size, device=device, dtype=torch.int64
+        )
+        self.workspace = torch.empty(90000, device=device, dtype=torch.uint8)
+        self.a_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.b_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.out_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.a_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.b_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.expert_offsets = torch.empty(
+            num_experts + 1, device=device, dtype=torch.int32
+        )
+        self.problem_sizes1 = torch.empty(
+            num_experts, 3, device=device, dtype=torch.int32
+        )
+        self.problem_sizes2 = torch.empty(
+            num_experts, 3, device=device, dtype=torch.int32
+        )
+
+        self._cutlass_buffers_ready = True
+
     def apply_with_router_logits(
         self,
         layer: torch.nn.Module,
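The Fp8MoEMethod hunks above (matching sglang/srt/layers/quantization/fp8.py, +152 -81) move two decisions out of __init__: create_moe_runner now auto-selects DeepGEMM over Triton when JIT DeepGEMM is enabled and the all-to-all backend is DeepEP, and the Cutlass fused-experts path is gated by _should_use_cutlass_fused_experts() with an SGLANG_CUTLASS_MOE environment override. Below is a standalone sketch of that selection order, with plain booleans standing in for the real get_moe_runner_backend() / deep_gemm_wrapper / get_moe_a2a_backend() queries; the helper name is illustrative, not from the diff:

    # Sketch of the new backend resolution in create_moe_runner; booleans
    # replace the config/env queries that the real code performs.
    def resolve_moe_runner_backend(
        backend: str,                # "auto", "deep_gemm", "triton", ...
        jit_deepgemm_enabled: bool,  # stands in for deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
        a2a_is_deepep: bool,         # stands in for get_moe_a2a_backend().is_deepep()
    ) -> str:
        if backend == "auto":
            if jit_deepgemm_enabled and a2a_is_deepep:
                return "deep_gemm"
            return "triton"
        return backend  # an explicit user choice wins

    assert resolve_moe_runner_backend("auto", True, True) == "deep_gemm"
    assert resolve_moe_runner_backend("auto", True, False) == "triton"
    assert resolve_moe_runner_backend("triton", True, True) == "triton"

At apply time the resolved backend then determines which quant info is built: DeepGemmMoeQuantInfo (repeating per-tensor scales into 128x128 blocks when the checkpoint is not block-quantized) or TritonMoeQuantInfo.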