sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to their public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
         CombineInput,
         StandardDispatchOutput,
     )
+    from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
 
 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -77,12 +78,62 @@ logger = logging.getLogger(__name__)
 CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
     "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
 )
+USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
+    "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM", "true"
+)
+# TODO make it true by default when the DeepEP PR is merged
+CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false"
+)
 
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
 
 
-class ModelOptFp8Config(QuantizationConfig):
+class ModelOptQuantConfig(QuantizationConfig):
+    def __init__(
+        self,
+        kv_cache_quant_algo: Optional[str],
+        exclude_modules: Optional[List[str]],
+        packed_modules_mapping: Optional[Dict[str, List[str]]],
+    ):
+        super().__init__()
+        self.packed_modules_mapping = packed_modules_mapping
+        self.exclude_modules = exclude_modules or []
+        self.kv_cache_quant_algo = kv_cache_quant_algo
+
+    def _get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+        *,
+        Linear: type[LinearMethodBase],
+        Moe: type[FusedMoEMethodBase],
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(
+                prefix, self.exclude_modules, self.packed_modules_mapping
+            ) or self.is_layer_excluded(prefix):
+                return UnquantizedLinearMethod()
+            return Linear(self)
+        elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
+            return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return Moe(self)
+        return None
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp8Config(ModelOptQuantConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
 
     def __init__(
@@ -90,22 +141,27 @@ class ModelOptFp8Config(QuantizationConfig):
         is_checkpoint_fp8_serialized: bool = False,
         kv_cache_quant_method: Optional[str] = None,
         exclude_modules: Optional[List[str]] = None,
+        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
         """
        Args:
            is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
        """
+        super().__init__(kv_cache_quant_method, exclude_modules, packed_modules_mapping)
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
-        self.kv_cache_quant_method = kv_cache_quant_method
-        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
             )
 
+    @classmethod
+    def override_quantization_method(cls, hf_quant_config, user_quant):
+        """Override quantization method based on the model's config."""
+        return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
+
     @classmethod
     def get_name(cls) -> str:
-        return "modelopt"
+        return "modelopt_fp8"
 
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -115,10 +171,6 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_min_capability(cls) -> int:
         return 89  # Minimum hardware capability (e.g., Hopper GPUs).
 
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["hf_quant_config.json"]
-
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
         # Handle two different config formats:
@@ -173,37 +225,27 @@ class ModelOptFp8Config(QuantizationConfig):
             is_checkpoint_fp8_serialized=True,
             kv_cache_quant_method=kv_cache_quant_method,
             exclude_modules=exclude_modules,
+            packed_modules_mapping=config.get("packed_modules_mapping"),
         )
 
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[QuantizeMethodBase]:
-
-        from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-
-        if self.exclude_modules and any(
+    def is_layer_excluded(self, prefix: str) -> bool:
+        if len(self.exclude_modules) == 0:
+            return False
+        return any(
             module in prefix
             or (
                 prefix.startswith("language_model.")
                 and module in prefix.removeprefix("language_model.")
             )
             for module in self.exclude_modules
-        ):
-            return None
-
-        if isinstance(layer, LinearBase):
-            return ModelOptFp8LinearMethod(self)
-        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
-            return ModelOptFp8KVCacheMethod(self)
-
-        if isinstance(layer, FusedMoE):
-            return ModelOptFp8MoEMethod(self)
-
-        return None
+        )
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
+        return self._get_quant_method(
+            layer, prefix, Linear=ModelOptFp8LinearMethod, Moe=ModelOptFp8MoEMethod
+        )
 
 
 class ModelOptFp8LinearMethod(LinearMethodBase):
@@ -499,7 +541,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         return self.runner.run(dispatch_output, quant_info)
 
 
-class ModelOptFp4Config(QuantizationConfig):
+class ModelOptFp4Config(ModelOptQuantConfig):
     """Config class for FP4."""
 
     def __init__(
@@ -508,7 +550,9 @@ class ModelOptFp4Config(QuantizationConfig):
         kv_cache_quant_algo: str = None,
         group_size: int = None,
         exclude_modules: List[str] = None,
+        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
+        super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)
         self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
         if is_checkpoint_nvfp4_serialized:
             logger.warning(
@@ -516,8 +560,11 @@ class ModelOptFp4Config(QuantizationConfig):
                 "format is experimental and subject to change."
             )
         self.group_size = group_size
-        self.kv_cache_quant_algo = kv_cache_quant_algo
-        self.exclude_modules = exclude_modules
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_config, user_quant):
+        """Override quantization method based on the model's config."""
+        return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
 
     @classmethod
     def get_name(cls) -> str:
@@ -531,10 +578,6 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_min_capability(cls) -> int:
         return 100
 
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["hf_quant_config.json"]
-
     @staticmethod
     def common_group_size(cfg: dict) -> int:
         """Return the unique group_size across the config; raise if missing/mismatched."""
@@ -600,7 +643,16 @@ class ModelOptFp4Config(QuantizationConfig):
             else:
                 kv_cache_quant_algo = "auto"
 
-            group_size = ModelOptFp4Config.common_group_size(config)
+            group_size = config.get("group_size")
+            # If group_size is not at top level, try to extract from config_groups
+            if group_size is None:
+                config_groups = config.get("config_groups", {})
+                if config_groups:
+                    # Get group_size from the first group's weights config
+                    first_group = next(iter(config_groups.values()), {})
+                    weights_config = first_group.get("weights", {})
+                    group_size = weights_config.get("group_size")
+
             exclude_modules = config.get("ignore", [])
         else:
             # Fall back to nested format (hf_quant_config.json - legacy format)
@@ -626,29 +678,30 @@ class ModelOptFp4Config(QuantizationConfig):
             )
         is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
 
-        if not (group_size and kv_cache_quant_algo) or exclude_modules is None:
+        if group_size is None or exclude_modules is None:
             logger.warning(
                 f"group_size: {group_size},"
                 f"kv_cache_quant_algo: {kv_cache_quant_algo},"
                 f"exclude_modules: {exclude_modules}"
             )
             raise ValueError(
-                "NVFP4 quantization requires group size and "
-                "kv_cache_quant_algo specified in the quantization config"
+                "NVFP4 quantization requires group_size and exclude_modules "
+                "specified in the quantization config"
             )
         return cls(
             is_checkpoint_nvfp4_serialized,
             kv_cache_quant_algo,
             group_size,
            exclude_modules,
+            config.get("packed_modules_mapping"),
         )
 
-    def is_layer_excluded(self, prefix: str, exclude_modules: list):
+    def is_layer_excluded(self, prefix: str):
         import regex as re
 
         fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
         prefix_split = prefix.split(".")
-        for pattern in exclude_modules:
+        for pattern in self.exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
             pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
@@ -664,30 +717,13 @@ class ModelOptFp4Config(QuantizationConfig):
                     return True
         return False
 
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[QuantizeMethodBase]:
-        from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-        from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
-
-        if isinstance(layer, LinearBase):
-            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
-                prefix, self.exclude_modules
-            ):
-                return UnquantizedLinearMethod()
-            return ModelOptFp4LinearMethod(self)
-        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
-            return ModelOptFp8KVCacheMethod(self)
-        elif isinstance(layer, FlashInferFP4MoE):
-            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
-            return ModelOptNvFp4FusedMoEMethod(self)
-        elif isinstance(layer, FusedMoE):
-            return ModelOptNvFp4FusedMoEMethod(self)
-        return None
-
-    def get_scaled_act_names(self) -> List[str]:
-        return []
+    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
+        return self._get_quant_method(
+            layer,
+            prefix,
+            Linear=ModelOptFp4LinearMethod,
+            Moe=ModelOptNvFp4FusedMoEMethod,  # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
+        )
 
 
 class ModelOptFp4LinearMethod(LinearMethodBase):
@@ -851,6 +887,7 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
             w_scale_interleaved,
             layer.alpha,
             output_dtype,
+            **(dict(backend="cutlass") if USE_CUTLASS_BACKEND_FOR_FP4_GEMM else dict()),
         )
         if bias is not None:
             out = out + bias
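Note (illustrative, not part of the package): the from_config hunk above now falls back to config_groups when group_size is not a top-level key. A minimal sketch of that lookup against a hypothetical flat-format quantization config:

    def resolve_group_size(config: dict):
        group_size = config.get("group_size")
        if group_size is None:
            # Fall back to the first entry in config_groups, mirroring the change above.
            config_groups = config.get("config_groups", {})
            if config_groups:
                first_group = next(iter(config_groups.values()), {})
                group_size = first_group.get("weights", {}).get("group_size")
        return group_size

    example = {  # hypothetical config, for illustration only
        "quant_algo": "NVFP4",
        "config_groups": {"group_0": {"weights": {"group_size": 16}}},
        "ignore": ["lm_head"],
    }
    assert resolve_group_size(example) == 16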
@@ -1050,19 +1087,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         intermediate_size,
         num_experts,
     ):
-        from flashinfer import (
-            RoutingMethodType,
-            e2m1_and_ufp8sf_scale_to_float,
-            fp4_quantize,
-            next_positive_power_of_2,
-            nvfp4_block_scale_interleave,
-            reorder_rows_for_gated_act_gemm,
-            shuffle_matrix_a,
-            shuffle_matrix_sf_a,
-        )
+        from flashinfer import nvfp4_block_scale_interleave
         from flashinfer.fused_moe.core import (
-            _maybe_get_cached_w2_permute_indices,
             _maybe_get_cached_w3_w1_permute_indices,
+            get_w2_permute_indices_with_cache,
         )
 
         """Prepare quantized weights for kernel (done offline with weights)."""
@@ -1123,7 +1151,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 )
             )
 
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_weights_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1134,7 +1162,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 .contiguous()
             )
 
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_scales_linear_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1220,6 +1248,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
             w13_input_scale = _slice_scale(w13_input_scale)
             w2_input_scale = _slice_scale(w2_input_scale)
+
+            if CUTEDSL_MOE_NVFP4_DISPATCH:
+                assert torch.all(w13_input_scale == w13_input_scale[0])
+                w13_input_scale = w13_input_scale[0]
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
             w2_input_scale = layer.w2_input_scale
@@ -1240,6 +1272,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )
 
+        layer.dispatcher.set_quant_config(
+            {"input_global_scale": layer.w13_input_scale_quant}
+        )
+
         # Validate weight scales
         for name, weight_scale in [
             ("w13", layer.w13_weight_scale),
@@ -1343,6 +1379,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         self,
         layer: FusedMoE,
         dispatch_output: StandardDispatchOutput,
+        forward_shared_experts=None,
+        alt_stream=None,
     ) -> CombineInput:
 
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
@@ -1414,9 +1452,19 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             )[0]
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
+
+                if forward_shared_experts is not None:
+                    alt_stream.wait_stream(torch.cuda.current_stream())
+                    with torch.cuda.stream(alt_stream):
+                        forward_shared_experts()
+
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
+
+                if forward_shared_experts is not None:
+                    torch.cuda.current_stream().wait_stream(alt_stream)
+
             return StandardCombineInput(hidden_states=output)
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -1446,6 +1494,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         x: torch.Tensor,
         masked_m: torch.Tensor,
         moe_runner_config: MoeRunnerConfig,
+        down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
     ) -> torch.Tensor:
         assert (
             moe_runner_config.activation == "silu"
@@ -1462,7 +1511,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         out = flashinfer_cutedsl_moe_masked(
             hidden_states=x,
-            input_global_scale=layer.w13_input_scale_quant,
+            input_global_scale=(
+                None if CUTEDSL_MOE_NVFP4_DISPATCH else layer.w13_input_scale_quant
+            ),
             w1=layer.w13_weight,
             w1_blockscale=layer.w13_blockscale_swizzled,
             w1_alpha=layer.g1_alphas,
@@ -1471,5 +1522,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             w2_blockscale=layer.w2_blockscale_swizzled,
             w2_alpha=layer.g2_alphas,
             masked_m=masked_m,
+            **(
+                dict(
+                    down_sm_count=down_gemm_overlap_args.num_sms,
+                    down_signals=down_gemm_overlap_args.signal,
+                    down_start_event=down_gemm_overlap_args.start_event,
+                )
+                if down_gemm_overlap_args is not None
+                else {}
+            ),
         )
         return out
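Note (illustrative, not part of the package): the forward_shared_experts/alt_stream additions above follow the usual CUDA side-stream overlap pattern. A generic sketch, assuming arbitrary callables and an optional CUDA device:

    import torch

    def run_overlapped(main_work, side_work):
        if not torch.cuda.is_available():
            side_work()
            main_work()
            return
        alt_stream = torch.cuda.Stream()
        # The side stream must not start before work already queued on the current stream.
        alt_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(alt_stream):
            side_work()   # e.g. the shared-experts forward
        main_work()       # e.g. the reduce-scatter on the current stream
        # Wait for the side stream before anything that consumes its results.
        torch.cuda.current_stream().wait_stream(alt_stream)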
@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     direct_register_custom_op,
     is_cuda,
@@ -41,7 +41,6 @@ from sglang.srt.utils import (
     is_triton_kernels_available,
     log_info_on_rank0,
     mxfp_supported,
-    next_power_of_2,
     round_up,
     set_weight_attrs,
 )
@@ -265,9 +264,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
-        self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
-            "flashinfer_mxfp4_moe_precision"
-        ]
+        self.flashinfer_mxfp4_moe_precision = (
+            get_global_server_args().flashinfer_mxfp4_moe_precision
+        )
 
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -597,30 +596,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
         torch.cuda.empty_cache()
 
-    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
-        # Number of tokens in the input tensor.
-        num_tokens = x.shape[0]
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # - 1.0 means perfect expert distribution.
-        # - > 1.0 means some experts have more
-        #   tokens than the perfect distribution.
-        # - < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert
-        # assuming perfect distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile
-        # as it's the range supported by the kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-
-        return tile_tokens_dim
-
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
@@ -696,7 +671,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
             layer.num_local_experts,  # local num experts
             None,
-            self._get_tile_tokens_dim(x, top_k),
+            None,  # tile_tokens_dim
             1,  # routing_method_type, renormalize
             True,  # do finalize
         )[0]
@@ -731,8 +706,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         quant_info = TritonMoeQuantInfo(
             w13_weight=layer.w13_weight,
             w2_weight=layer.w2_weight,
-            w13_weight_bias=layer.w13_weight_bias,
-            w2_weight_bias=layer.w2_weight_bias,
+            b13=getattr(layer, "w13_weight_bias", None),
+            b2=getattr(layer, "w2_weight_bias", None),
         )
         return self.runner.run(dispatch_output, quant_info)
 
@@ -843,10 +818,18 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
             topk_weights = topk_weights.to(
                 torch.float32
            )  # aiter's moe_sorting requires topk_weights to be FP32
+
+        if hasattr(torch, "float4_e2m1fn_x2"):
+            w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
+            w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
+        else:
+            w13_weight = layer.w13_weight
+            w2_weight = layer.w2_weight
+
         output = fused_moe(
             x,
-            layer.w13_weight,
-            layer.w2_weight,
+            w13_weight,
+            w2_weight,
             topk_weights,
             topk_ids,
             quant_type=QuantType.per_1x32,
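Note (illustrative, not part of the package): both the Mxfp4 hunk above and the Quark MoE hunk below reinterpret packed FP4 weight bytes as torch.float4_e2m1fn_x2 when the installed PyTorch exposes that dtype. A standalone sketch of the same guard, assuming the weight is stored as a byte tensor as in those hunks:

    import torch

    def as_packed_fp4(weight: torch.Tensor) -> torch.Tensor:
        # Newer PyTorch builds define a packed FP4 dtype (two e2m1 values per byte);
        # older builds simply keep the raw byte tensor.
        if hasattr(torch, "float4_e2m1fn_x2"):
            return weight.view(torch.float4_e2m1fn_x2)
        return weight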
@@ -2,7 +2,7 @@
 
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import regex as re
 import torch
@@ -65,7 +65,9 @@ class QuarkConfig(QuantizationConfig):
         if should_ignore_layer(
             prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
         ):
-            return UnquantizedLinearMethod()
+            if isinstance(layer, LinearBase):
+                return UnquantizedLinearMethod()
+            return None
 
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
@@ -3,16 +3,16 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import TYPE_CHECKING, Any
 
 import torch
-from aiter import ActivationType, QuantType, biased_grouped_topk
+from aiter import ActivationType, QuantType
 from aiter.fused_moe import fused_moe
 from aiter.utility.fp4_utils import e8m0_shuffle
 
 from sglang.srt.layers.moe import MoeRunnerConfig
 from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
+from sglang.srt.utils import is_hip, set_weight_attrs
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
@@ -23,6 +23,8 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+_is_hip = is_hip()
+
 __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
 
 OCP_MX_BLOCK_SIZE = 32
@@ -182,11 +184,22 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         topk_output = dispatch_output.topk_output
         moe_runner_config = self.moe_runner_config
         topk_weights, topk_ids, _ = topk_output
+        if _is_hip:
+            topk_weights = topk_weights.to(
+                torch.float32
+            )  # aiter's moe_sorting requires topk_weights to be FP32
+
+        if hasattr(torch, "float4_e2m1fn_x2"):
+            w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
+            w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
+        else:
+            w13_weight = layer.w13_weight
+            w2_weight = layer.w2_weight
 
         output = fused_moe(
             x,
-            layer.w13_weight,
-            layer.w2_weight,
+            w13_weight,
+            w2_weight,
             topk_weights,
             topk_ids,
             quant_type=QuantType.per_1x32,
@@ -2,20 +2,13 @@
 
 from typing import Any, Callable, Optional
 
-import aiter
 import torch
-import torch.nn.functional as F
-from aiter.ops.gemm_op_a4w4 import gemm_a4w4
-from aiter.ops.shuffle import shuffle_weight
 from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
 from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
 from aiter.ops.triton.quant import dynamic_mxfp4_quant
-from aiter.utility import dtypes
-from aiter.utility.fp4_utils import e8m0_shuffle
 
 from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
 from sglang.srt.layers.quantization.quark.schemes import QuarkScheme
-from sglang.srt.utils import get_bool_env_var
 
 __all__ = ["QuarkW4A4MXFP4"]
 
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import importlib.util
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -31,8 +30,6 @@ if TYPE_CHECKING:
         StandardDispatchOutput,
     )
 
-has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
-
 
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_hip = is_hip()
@@ -143,7 +140,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and has_triton_kernels:
+        if torch.cuda.is_available() and use_triton_kernels:
             from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                 triton_kernel_moe_forward as _tk_forward,
             )
@@ -11,7 +11,6 @@ import numpy
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
 
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig