sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/modelopt_quant.py

@@ -79,7 +79,7 @@ CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
     "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
 )
 USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
-    "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
+    "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM", "true"
 )
 # TODO make it true by default when the DeepEP PR is merged
 CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
@@ -90,7 +90,50 @@ CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
 ACTIVATION_SCHEMES = ["static"]
 
 
-class ModelOptFp8Config(QuantizationConfig):
+class ModelOptQuantConfig(QuantizationConfig):
+    def __init__(
+        self,
+        kv_cache_quant_algo: Optional[str],
+        exclude_modules: Optional[List[str]],
+        packed_modules_mapping: Optional[Dict[str, List[str]]],
+    ):
+        super().__init__()
+        self.packed_modules_mapping = packed_modules_mapping
+        self.exclude_modules = exclude_modules or []
+        self.kv_cache_quant_algo = kv_cache_quant_algo
+
+    def _get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+        *,
+        Linear: type[LinearMethodBase],
+        Moe: type[FusedMoEMethodBase],
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(
+                prefix, self.exclude_modules, self.packed_modules_mapping
+            ) or self.is_layer_excluded(prefix):
+                return UnquantizedLinearMethod()
+            return Linear(self)
+        elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
+            return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return Moe(self)
+        return None
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp8Config(ModelOptQuantConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
 
     def __init__(
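Note: the new ModelOptQuantConfig base class centralizes layer dispatch, so each subclass only names its linear and MoE method classes. A minimal sketch of how a subclass would hook into _get_quant_method (MyModelOptConfig, MyLinearMethod and MyMoEMethod are hypothetical names, not part of this diff):

# Hypothetical subclass, illustrating the dispatch pattern introduced above.
class MyModelOptConfig(ModelOptQuantConfig):
    def get_quant_method(self, layer, prefix):
        # LinearBase -> MyLinearMethod, FusedMoE -> MyMoEMethod,
        # RadixAttention -> ModelOptFp8KVCacheMethod when kv_cache_quant_algo is set,
        # skipped/excluded linear layers -> UnquantizedLinearMethod.
        return self._get_quant_method(
            layer, prefix, Linear=MyLinearMethod, Moe=MyMoEMethod
        )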
@@ -98,22 +141,27 @@ class ModelOptFp8Config(QuantizationConfig):
         is_checkpoint_fp8_serialized: bool = False,
         kv_cache_quant_method: Optional[str] = None,
         exclude_modules: Optional[List[str]] = None,
+        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
         """
         Args:
             is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
         """
+        super().__init__(kv_cache_quant_method, exclude_modules, packed_modules_mapping)
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
-        self.kv_cache_quant_method = kv_cache_quant_method
-        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
             )
 
+    @classmethod
+    def override_quantization_method(cls, hf_quant_config, user_quant):
+        """Override quantization method based on the model's config."""
+        return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
+
     @classmethod
     def get_name(cls) -> str:
-        return "modelopt"
+        return "modelopt_fp8"
 
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -123,10 +171,6 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_min_capability(cls) -> int:
         return 89  # Minimum hardware capability (e.g., Hopper GPUs).
 
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["hf_quant_config.json"]
-
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
         # Handle two different config formats:
@@ -181,37 +225,27 @@ class ModelOptFp8Config(QuantizationConfig):
             is_checkpoint_fp8_serialized=True,
             kv_cache_quant_method=kv_cache_quant_method,
             exclude_modules=exclude_modules,
+            packed_modules_mapping=config.get("packed_modules_mapping"),
         )
 
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[QuantizeMethodBase]:
-
-        from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-
-        if self.exclude_modules and any(
+    def is_layer_excluded(self, prefix: str) -> bool:
+        if len(self.exclude_modules) == 0:
+            return False
+        return any(
             module in prefix
             or (
                 prefix.startswith("language_model.")
                 and module in prefix.removeprefix("language_model.")
             )
             for module in self.exclude_modules
-        ):
-            return None
-
-        if isinstance(layer, LinearBase):
-            return ModelOptFp8LinearMethod(self)
-        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
-            return ModelOptFp8KVCacheMethod(self)
-
-        if isinstance(layer, FusedMoE):
-            return ModelOptFp8MoEMethod(self)
-
-        return None
+        )
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
+        return self._get_quant_method(
+            layer, prefix, Linear=ModelOptFp8LinearMethod, Moe=ModelOptFp8MoEMethod
+        )
 
 
 class ModelOptFp8LinearMethod(LinearMethodBase):
@@ -507,7 +541,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         return self.runner.run(dispatch_output, quant_info)
 
 
-class ModelOptFp4Config(QuantizationConfig):
+class ModelOptFp4Config(ModelOptQuantConfig):
     """Config class for FP4."""
 
     def __init__(
@@ -516,7 +550,9 @@ class ModelOptFp4Config(QuantizationConfig):
         kv_cache_quant_algo: str = None,
         group_size: int = None,
         exclude_modules: List[str] = None,
+        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
+        super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)
         self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
         if is_checkpoint_nvfp4_serialized:
             logger.warning(
@@ -524,8 +560,11 @@ class ModelOptFp4Config(QuantizationConfig):
                 "format is experimental and subject to change."
             )
         self.group_size = group_size
-        self.kv_cache_quant_algo = kv_cache_quant_algo
-        self.exclude_modules = exclude_modules
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_config, user_quant):
+        """Override quantization method based on the model's config."""
+        return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
 
     @classmethod
     def get_name(cls) -> str:
@@ -539,10 +578,6 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_min_capability(cls) -> int:
         return 100
 
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["hf_quant_config.json"]
-
     @staticmethod
     def common_group_size(cfg: dict) -> int:
         """Return the unique group_size across the config; raise if missing/mismatched."""
@@ -608,7 +643,16 @@ class ModelOptFp4Config(QuantizationConfig):
             else:
                 kv_cache_quant_algo = "auto"
 
-            group_size = ModelOptFp4Config.common_group_size(config)
+            group_size = config.get("group_size")
+            # If group_size is not at top level, try to extract from config_groups
+            if group_size is None:
+                config_groups = config.get("config_groups", {})
+                if config_groups:
+                    # Get group_size from the first group's weights config
+                    first_group = next(iter(config_groups.values()), {})
+                    weights_config = first_group.get("weights", {})
+                    group_size = weights_config.get("group_size")
+
             exclude_modules = config.get("ignore", [])
         else:
             # Fall back to nested format (hf_quant_config.json - legacy format)
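Note: the replacement logic reads group_size from the top level first and then falls back to the first entry under config_groups. A standalone sketch of the same lookup against an illustrative config dict (the shape below is a made-up example, not taken from this diff):

# Illustrative flat-format config; real checkpoints may differ.
config = {
    "config_groups": {
        "group_0": {"weights": {"num_bits": 4, "group_size": 16}},
    },
    "ignore": ["lm_head"],
}
group_size = config.get("group_size")
if group_size is None:
    # Fall back to the first group's weights section.
    first_group = next(iter(config.get("config_groups", {}).values()), {})
    group_size = first_group.get("weights", {}).get("group_size")
assert group_size == 16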
@@ -634,29 +678,30 @@ class ModelOptFp4Config(QuantizationConfig):
             )
             is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
 
-        if not (group_size and kv_cache_quant_algo) or exclude_modules is None:
+        if group_size is None or exclude_modules is None:
             logger.warning(
                 f"group_size: {group_size},"
                 f"kv_cache_quant_algo: {kv_cache_quant_algo},"
                 f"exclude_modules: {exclude_modules}"
             )
             raise ValueError(
-                "NVFP4 quantization requires group size and "
-                "kv_cache_quant_algo specified in the quantization config"
+                "NVFP4 quantization requires group_size and exclude_modules "
+                "specified in the quantization config"
             )
         return cls(
             is_checkpoint_nvfp4_serialized,
             kv_cache_quant_algo,
             group_size,
             exclude_modules,
+            config.get("packed_modules_mapping"),
         )
 
-    def is_layer_excluded(self, prefix: str, exclude_modules: list):
+    def is_layer_excluded(self, prefix: str):
         import regex as re
 
         fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
         prefix_split = prefix.split(".")
-        for pattern in exclude_modules:
+        for pattern in self.exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
             pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
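Note: exclusion entries are treated as simple glob patterns and compiled to regular expressions before matching layer prefixes. A standalone illustration of that conversion (the pattern and prefix are made-up examples; the diff uses the regex package, but the stdlib re module behaves the same here):

import re

pattern = "model.layers.*.mlp"    # made-up exclusion pattern
prefix = "model.layers.7.mlp"     # made-up layer prefix
# Escape literal dots, turn '*' into '.*', then require a full match.
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
assert re.fullmatch(regex_str, prefix) is not None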
@@ -672,30 +717,13 @@ class ModelOptFp4Config(QuantizationConfig):
                 return True
         return False
 
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[QuantizeMethodBase]:
-        from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-        from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
-
-        if isinstance(layer, LinearBase):
-            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
-                prefix, self.exclude_modules
-            ):
-                return UnquantizedLinearMethod()
-            return ModelOptFp4LinearMethod(self)
-        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
-            return ModelOptFp8KVCacheMethod(self)
-        elif isinstance(layer, FlashInferFP4MoE):
-            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
-            return ModelOptNvFp4FusedMoEMethod(self)
-        elif isinstance(layer, FusedMoE):
-            return ModelOptNvFp4FusedMoEMethod(self)
-        return None
-
-    def get_scaled_act_names(self) -> List[str]:
-        return []
+    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
+        return self._get_quant_method(
+            layer,
+            prefix,
+            Linear=ModelOptFp4LinearMethod,
+            Moe=ModelOptNvFp4FusedMoEMethod,  # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
+        )
 
 
 class ModelOptFp4LinearMethod(LinearMethodBase):
@@ -852,25 +880,15 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         if enable_flashinfer_fp4_gemm:
             w = layer.weight.T
             w_scale_interleaved = layer.weight_scale_interleaved.T
-            if USE_CUTLASS_BACKEND_FOR_FP4_GEMM:
-                out = fp4_gemm(
-                    x_fp4,
-                    w,
-                    x_scale_interleaved,
-                    w_scale_interleaved,
-                    layer.alpha,
-                    output_dtype,
-                    backend="cutlass",
-                )
-            else:
-                out = fp4_gemm(
-                    x_fp4,
-                    w,
-                    x_scale_interleaved,
-                    w_scale_interleaved,
-                    layer.alpha,
-                    output_dtype,
-                )
+            out = fp4_gemm(
+                x_fp4,
+                w,
+                x_scale_interleaved,
+                w_scale_interleaved,
+                layer.alpha,
+                output_dtype,
+                **(dict(backend="cutlass") if USE_CUTLASS_BACKEND_FOR_FP4_GEMM else dict()),
+            )
             if bias is not None:
                 out = out + bias
             return out.view(*output_shape)
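Note: the two nearly identical fp4_gemm calls above are collapsed by splatting an optional keyword argument. A standalone illustration of the pattern (gemm below is a stand-in function, not the real fp4_gemm):

# Stand-in function to show the conditional-kwarg pattern used above.
def gemm(a, b, backend="triton"):
    return f"gemm on {backend}"

USE_CUTLASS = True
# Passes backend="cutlass" only when the flag is set; otherwise the default applies.
out = gemm(1, 2, **(dict(backend="cutlass") if USE_CUTLASS else dict()))
assert out == "gemm on cutlass"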
@@ -1069,19 +1087,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         intermediate_size,
         num_experts,
     ):
-        from flashinfer import (
-            RoutingMethodType,
-            e2m1_and_ufp8sf_scale_to_float,
-            fp4_quantize,
-            next_positive_power_of_2,
-            nvfp4_block_scale_interleave,
-            reorder_rows_for_gated_act_gemm,
-            shuffle_matrix_a,
-            shuffle_matrix_sf_a,
-        )
+        from flashinfer import nvfp4_block_scale_interleave
         from flashinfer.fused_moe.core import (
-            _maybe_get_cached_w2_permute_indices,
             _maybe_get_cached_w3_w1_permute_indices,
+            get_w2_permute_indices_with_cache,
         )
 
         """Prepare quantized weights for kernel (done offline with weights)."""
@@ -1142,7 +1151,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 )
             )
 
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_weights_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1153,7 +1162,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 .contiguous()
             )
 
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_scales_linear_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1263,6 +1272,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )
 
+        layer.dispatcher.set_quant_config(
+            {"input_global_scale": layer.w13_input_scale_quant}
+        )
+
         # Validate weight scales
         for name, weight_scale in [
             ("w13", layer.w13_weight_scale),
@@ -1366,6 +1379,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         self,
         layer: FusedMoE,
         dispatch_output: StandardDispatchOutput,
+        forward_shared_experts=None,
+        alt_stream=None,
     ) -> CombineInput:
 
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
@@ -1437,9 +1452,19 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
            )[0]
            if should_use_flashinfer_cutlass_moe_fp4_allgather():
                output, global_output = get_local_dp_buffer(), output
+
+                if forward_shared_experts is not None:
+                    alt_stream.wait_stream(torch.cuda.current_stream())
+                    with torch.cuda.stream(alt_stream):
+                        forward_shared_experts()
+
                get_tp_group().reduce_scatterv(
                    global_output, output=output, sizes=get_dp_global_num_tokens()
                )
+
+                if forward_shared_experts is not None:
+                    torch.cuda.current_stream().wait_stream(alt_stream)
+
                return StandardCombineInput(hidden_states=output)
 
        from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
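Note: the added forward_shared_experts/alt_stream arguments overlap the shared-expert computation with the reduce-scatter on a side CUDA stream. A minimal sketch of the same wait_stream/stream/wait_stream pattern (the tensors and work functions below are placeholders; it requires a CUDA device to actually run):

import torch

# Placeholder for the shared-expert work that gets moved to the side stream.
def shared_experts():
    a = torch.randn(256, 256, device="cuda")
    return a @ a

alt_stream = torch.cuda.Stream()
x = torch.randn(256, 256, device="cuda")

# Side stream waits for whatever is already queued on the current stream...
alt_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(alt_stream):
    shared_out = shared_experts()  # overlaps with the work below
y = x.sum()  # e.g. the communication kernel stays on the default stream
# ...and the default stream waits before shared_out is consumed again.
torch.cuda.current_stream().wait_stream(alt_stream)
print(shared_out.shape, y.item())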

sglang/srt/layers/quantization/mxfp4.py

@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     direct_register_custom_op,
     is_cuda,
@@ -41,7 +41,6 @@ from sglang.srt.utils import (
     is_triton_kernels_available,
     log_info_on_rank0,
     mxfp_supported,
-    next_power_of_2,
     round_up,
     set_weight_attrs,
 )
@@ -262,25 +261,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         self.prefix = prefix
         self.topk_indices_dtype = None
-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
-        self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
-            "flashinfer_mxfp4_moe_precision"
-        ]
-
-        self.triton_kernel_moe_forward = None
-        self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and has_triton_kernels:
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_forward as _tk_forward,
-            )
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
-            )
-
-            self.triton_kernel_moe_forward = _tk_forward
-            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
+        self.flashinfer_mxfp4_moe_precision = (
+            get_global_server_args().flashinfer_mxfp4_moe_precision
+        )
 
     def create_weights(
         self,
@@ -597,35 +583,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
         torch.cuda.empty_cache()
 
-    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
-        # Number of tokens in the input tensor.
-        num_tokens = x.shape[0]
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # - 1.0 means perfect expert distribution.
-        # - > 1.0 means some experts have more
-        #   tokens than the perfect distribution.
-        # - < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert
-        # assuming perfect distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile
-        # as it's the range supported by the kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-
-        return tile_tokens_dim
-
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        backend = (
+            MoeRunnerBackend.TRITON_KERNELS
+            if self.use_triton_kernels
+            else MoeRunnerBackend.TRITON
+        )
+        self.runner = MoeRunner(backend, moe_runner_config)
 
     def apply(
         self,
@@ -696,37 +663,37 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
                 None,
-                self._get_tile_tokens_dim(x, top_k),
+                None,  # tile_tokens_dim
                 1,  # routing_method_type, renormalize
                 True,  # do finalize
             )[0]
             return StandardCombineInput(hidden_states=trtllm_gen_output)
 
-        if self.use_triton_kernels:
+        backend = self.runner.runner_backend
+        if backend.is_triton_kernels():
+            from sglang.srt.layers.moe.moe_runner.triton_kernels import (
+                TritonKernelsQuantInfo,
+            )
+
             assert (
                 layer.moe_ep_size == 1
             ), "Expert parallel is not supported when using triton kernels"
-            if self.with_bias:
-                output = self.triton_kernel_moe_with_bias_forward(
-                    hidden_states=x,
-                    w1=self.w13_weight_triton_tensor,
-                    w1_pcg=self.w13_precision_config,
-                    w2=self.w2_weight_triton_tensor,
-                    w2_pcg=self.w2_precision_config,
-                    b1=layer.w13_weight_bias,
-                    b2=layer.w2_weight_bias,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            else:
-                output = self.triton_kernel_moe_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            return StandardCombineInput(hidden_states=output)
+            quant_info = TritonKernelsQuantInfo(
+                w13_weight=(
+                    self.w13_weight_triton_tensor
+                    if self.w13_weight_triton_tensor is not None
+                    else layer.w13_weight
+                ),
+                w2_weight=(
+                    self.w2_weight_triton_tensor
+                    if self.w2_weight_triton_tensor is not None
+                    else layer.w2_weight
+                ),
+                w13_bias=getattr(layer, "w13_weight_bias", None),
+                w2_bias=getattr(layer, "w2_weight_bias", None),
+                w13_precision_config=getattr(self, "w13_precision_config", None),
+                w2_precision_config=getattr(self, "w2_precision_config", None),
+            )
         else:
             quant_info = TritonMoeQuantInfo(
                 w13_weight=layer.w13_weight,
@@ -734,7 +701,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 b13=getattr(layer, "w13_weight_bias", None),
                 b2=getattr(layer, "w2_weight_bias", None),
             )
-            return self.runner.run(dispatch_output, quant_info)
+        return self.runner.run(dispatch_output, quant_info)
 
 
 class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):

sglang/srt/layers/quantization/quark/quark.py

@@ -2,7 +2,7 @@
 
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import regex as re
 import torch
@@ -65,7 +65,9 @@ class QuarkConfig(QuantizationConfig):
         if should_ignore_layer(
             prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
         ):
-            return UnquantizedLinearMethod()
+            if isinstance(layer, LinearBase):
+                return UnquantizedLinearMethod()
+            return None
 
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)

sglang/srt/layers/quantization/quark/quark_moe.py

@@ -3,16 +3,16 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import TYPE_CHECKING, Any
 
 import torch
-from aiter import ActivationType, QuantType, biased_grouped_topk
+from aiter import ActivationType, QuantType
 from aiter.fused_moe import fused_moe
 from aiter.utility.fp4_utils import e8m0_shuffle
 
 from sglang.srt.layers.moe import MoeRunnerConfig
 from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-from sglang.srt.utils import get_bool_env_var, is_hip, mxfp_supported, set_weight_attrs
+from sglang.srt.utils import is_hip, set_weight_attrs
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (

sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py

@@ -2,20 +2,13 @@
 
 from typing import Any, Callable, Optional
 
-import aiter
 import torch
-import torch.nn.functional as F
-from aiter.ops.gemm_op_a4w4 import gemm_a4w4
-from aiter.ops.shuffle import shuffle_weight
 from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
 from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
 from aiter.ops.triton.quant import dynamic_mxfp4_quant
-from aiter.utility import dtypes
-from aiter.utility.fp4_utils import e8m0_shuffle
 
 from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
 from sglang.srt.layers.quantization.quark.schemes import QuarkScheme
-from sglang.srt.utils import get_bool_env_var
 
 __all__ = ["QuarkW4A4MXFP4"]