sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -199,7 +199,6 @@ class GPTQConfig(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[LinearMethodBase]:
         # Delay the import to avoid circular dependency
-        from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, FusedMoE):
@@ -12,7 +12,15 @@ from sglang.srt.utils import get_device_name, is_cuda
 
 _is_cuda = is_cuda()
 if _is_cuda:
-    from sgl_kernel import sgl_per_token_group_quant_int8
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+
+        enable_sgl_per_token_group_quant_8bit = True
+    except ImportError:
+        from sgl_kernel import sgl_per_token_group_quant_int8
+
+        enable_sgl_per_token_group_quant_8bit = False
 
 logger = logging.getLogger(__name__)
 
@@ -187,6 +195,7 @@ def sglang_per_token_group_quant_int8(
     group_size: int,
     eps: float = 1e-10,
     dtype: torch.dtype = torch.int8,
+    enable_v2: Optional[bool] = None,
 ):
     assert (
         x.shape[-1] % group_size == 0
@@ -204,7 +213,14 @@ def sglang_per_token_group_quant_int8(
         dtype=torch.float32,
     )
 
-    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    # Temporary
+    if enable_sgl_per_token_group_quant_8bit:
+        sgl_per_token_group_quant_8bit(
+            x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
+        )
+    else:
+        assert not enable_v2
+        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
 
     return x_q, x_s
 
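A minimal usage sketch of the quantization entry point modified above (assumed to live in sglang/srt/layers/quantization/int8_kernel.py, per the file list); the input shape, dtype, and device below are illustrative only:

import torch

from sglang.srt.layers.quantization.int8_kernel import sglang_per_token_group_quant_int8

# Hypothetical input: quantize a bf16 activation tensor in groups of 128 along the last dim.
x = torch.randn(4, 1024, dtype=torch.bfloat16, device="cuda")
x_q, x_s = sglang_per_token_group_quant_int8(x, group_size=128)  # int8 values, float32 group scales

# enable_v2 may only be set when the newer sgl_per_token_group_quant_8bit kernel was
# importable at module load time; with the legacy int8-only kernel it must stay unset.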
@@ -4,6 +4,7 @@
 from __future__ import annotations
 
 import logging
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional
 
 import numpy
@@ -57,6 +58,17 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 USE_FP32_REDUCE_DEFAULT = True
 
 
+@dataclass
+class MarlinLinearLayerConfig:
+    full_weight_shape: tuple[int, int]  # [in, out]
+    partition_weight_shape: tuple[int, int]
+    weight_type: ScalarType
+    act_type: torch.dtype
+    group_size: int
+    zero_points: bool
+    has_g_idx: bool
+
+
 # For binary size and compile time, we don't support the same types for with and
 # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
 # TODO: we may want to move this into the C++ so its closer to the actual impl
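A hedged sketch of how the new MarlinLinearLayerConfig might be filled in for a 4-bit column-parallel layer sharded over 4 ranks; the shapes, group size, and the weight_type placeholder are assumptions for illustration, not values taken from the diff:

import torch

from sglang.srt.layers.quantization.marlin_utils import MarlinLinearLayerConfig

tp_size = 4
in_features, out_features = 4096, 11008

cfg = MarlinLinearLayerConfig(
    full_weight_shape=(in_features, out_features),                  # unsharded [in, out]
    partition_weight_shape=(in_features, out_features // tp_size),  # this rank's shard
    weight_type=None,        # placeholder; a real config carries a 4-bit ScalarType here
    act_type=torch.bfloat16,
    group_size=128,
    zero_points=False,       # GPTQ-style symmetric weights
    has_g_idx=False,         # no activation reordering
)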
@@ -79,7 +79,7 @@ CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
     "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
 )
 USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
-    "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
+    "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM", "true"
 )
 # TODO make it true by default when the DeepEP PR is merged
 CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
@@ -90,7 +90,50 @@ CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
 ACTIVATION_SCHEMES = ["static"]
 
 
-class ModelOptFp8Config(QuantizationConfig):
+class ModelOptQuantConfig(QuantizationConfig):
+    def __init__(
+        self,
+        kv_cache_quant_algo: Optional[str],
+        exclude_modules: Optional[List[str]],
+        packed_modules_mapping: Optional[Dict[str, List[str]]],
+    ):
+        super().__init__()
+        self.packed_modules_mapping = packed_modules_mapping
+        self.exclude_modules = exclude_modules or []
+        self.kv_cache_quant_algo = kv_cache_quant_algo
+
+    def _get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+        *,
+        Linear: type[LinearMethodBase],
+        Moe: type[FusedMoEMethodBase],
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(
+                prefix, self.exclude_modules, self.packed_modules_mapping
+            ) or self.is_layer_excluded(prefix):
+                return UnquantizedLinearMethod()
+            return Linear(self)
+        elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
+            return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return Moe(self)
+        return None
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp8Config(ModelOptQuantConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
 
     def __init__(
@@ -98,22 +141,27 @@ class ModelOptFp8Config(QuantizationConfig):
         is_checkpoint_fp8_serialized: bool = False,
         kv_cache_quant_method: Optional[str] = None,
         exclude_modules: Optional[List[str]] = None,
+        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
         """
         Args:
             is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
         """
+        super().__init__(kv_cache_quant_method, exclude_modules, packed_modules_mapping)
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
-        self.kv_cache_quant_method = kv_cache_quant_method
-        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
             )
 
+    @classmethod
+    def override_quantization_method(cls, hf_quant_config, user_quant):
+        """Override quantization method based on the model's config."""
+        return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
+
     @classmethod
     def get_name(cls) -> str:
-        return "modelopt"
+        return "modelopt_fp8"
 
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -123,10 +171,6 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_min_capability(cls) -> int:
         return 89  # Minimum hardware capability (e.g., Hopper GPUs).
 
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["hf_quant_config.json"]
-
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
         # Handle two different config formats:
@@ -181,37 +225,27 @@ class ModelOptFp8Config(QuantizationConfig):
             is_checkpoint_fp8_serialized=True,
             kv_cache_quant_method=kv_cache_quant_method,
             exclude_modules=exclude_modules,
+            packed_modules_mapping=config.get("packed_modules_mapping"),
         )
 
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[QuantizeMethodBase]:
-
-        from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-
-        if self.exclude_modules and any(
+    def is_layer_excluded(self, prefix: str) -> bool:
+        if len(self.exclude_modules) == 0:
+            return False
+        return any(
             module in prefix
             or (
                 prefix.startswith("language_model.")
                 and module in prefix.removeprefix("language_model.")
             )
             for module in self.exclude_modules
-        ):
-            return None
-
-        if isinstance(layer, LinearBase):
-            return ModelOptFp8LinearMethod(self)
-        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
-            return ModelOptFp8KVCacheMethod(self)
-
-        if isinstance(layer, FusedMoE):
-            return ModelOptFp8MoEMethod(self)
-
-        return None
+        )
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
+        return self._get_quant_method(
+            layer, prefix, Linear=ModelOptFp8LinearMethod, Moe=ModelOptFp8MoEMethod
+        )
 
 
 class ModelOptFp8LinearMethod(LinearMethodBase):
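To make the substring-based exclusion rule concrete, here is a standalone restatement of the logic now held in ModelOptFp8Config.is_layer_excluded, exercised with made-up layer prefixes (a sketch, not the class itself):

def is_excluded(prefix: str, exclude_modules: list) -> bool:
    if len(exclude_modules) == 0:
        return False
    return any(
        module in prefix
        or (
            prefix.startswith("language_model.")
            and module in prefix.removeprefix("language_model.")
        )
        for module in exclude_modules
    )

assert is_excluded("language_model.lm_head", ["lm_head"])
assert is_excluded("model.layers.0.mlp.gate_proj", ["mlp"])
assert not is_excluded("model.layers.0.self_attn.q_proj", ["lm_head"])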
@@ -507,7 +541,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         return self.runner.run(dispatch_output, quant_info)
 
 
-class ModelOptFp4Config(QuantizationConfig):
+class ModelOptFp4Config(ModelOptQuantConfig):
     """Config class for FP4."""
 
     def __init__(
@@ -516,7 +550,9 @@ class ModelOptFp4Config(QuantizationConfig):
         kv_cache_quant_algo: str = None,
         group_size: int = None,
         exclude_modules: List[str] = None,
+        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
+        super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)
         self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
         if is_checkpoint_nvfp4_serialized:
             logger.warning(
@@ -524,8 +560,11 @@ class ModelOptFp4Config(QuantizationConfig):
                 "format is experimental and subject to change."
             )
         self.group_size = group_size
-        self.kv_cache_quant_algo = kv_cache_quant_algo
-        self.exclude_modules = exclude_modules
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_config, user_quant):
+        """Override quantization method based on the model's config."""
+        return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
 
     @classmethod
     def get_name(cls) -> str:
@@ -539,10 +578,6 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_min_capability(cls) -> int:
         return 100
 
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["hf_quant_config.json"]
-
     @staticmethod
     def common_group_size(cfg: dict) -> int:
         """Return the unique group_size across the config; raise if missing/mismatched."""
@@ -608,7 +643,16 @@ class ModelOptFp4Config(QuantizationConfig):
             else:
                 kv_cache_quant_algo = "auto"
 
-            group_size = ModelOptFp4Config.common_group_size(config)
+            group_size = config.get("group_size")
+            # If group_size is not at top level, try to extract from config_groups
+            if group_size is None:
+                config_groups = config.get("config_groups", {})
+                if config_groups:
+                    # Get group_size from the first group's weights config
+                    first_group = next(iter(config_groups.values()), {})
+                    weights_config = first_group.get("weights", {})
+                    group_size = weights_config.get("group_size")
+
             exclude_modules = config.get("ignore", [])
         else:
             # Fall back to nested format (hf_quant_config.json - legacy format)
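A small sketch of the two places from_config now looks for group_size: the top level of a flat quantization config, or the weights entry of the first config_groups group. The sample dict is illustrative only (a compressed-tensors-style layout), not copied from a real checkpoint:

sample_config = {
    "config_groups": {
        "group_0": {"weights": {"num_bits": 4, "group_size": 16}},
    },
    "ignore": ["lm_head"],
}

group_size = sample_config.get("group_size")
if group_size is None:
    config_groups = sample_config.get("config_groups", {})
    if config_groups:
        first_group = next(iter(config_groups.values()), {})
        group_size = first_group.get("weights", {}).get("group_size")

assert group_size == 16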
@@ -634,29 +678,30 @@ class ModelOptFp4Config(QuantizationConfig):
             )
         is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
 
-        if not (group_size and kv_cache_quant_algo) or exclude_modules is None:
+        if group_size is None or exclude_modules is None:
            logger.warning(
                 f"group_size: {group_size},"
                 f"kv_cache_quant_algo: {kv_cache_quant_algo},"
                 f"exclude_modules: {exclude_modules}"
             )
             raise ValueError(
-                "NVFP4 quantization requires group size and "
-                "kv_cache_quant_algo specified in the quantization config"
+                "NVFP4 quantization requires group_size and exclude_modules "
+                "specified in the quantization config"
             )
         return cls(
             is_checkpoint_nvfp4_serialized,
             kv_cache_quant_algo,
             group_size,
             exclude_modules,
+            config.get("packed_modules_mapping"),
         )
 
-    def is_layer_excluded(self, prefix: str, exclude_modules: list):
+    def is_layer_excluded(self, prefix: str):
         import regex as re
 
         fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
         prefix_split = prefix.split(".")
-        for pattern in exclude_modules:
+        for pattern in self.exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
             pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
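The wildcard handling in the reworked is_layer_excluded can be illustrated in isolation: "." is escaped, "*" becomes ".*", and the whole layer prefix must match. The pattern and prefix strings below are examples only; the real code uses the regex package, but stdlib re behaves identically for these patterns:

import re


def wildcard_match(pattern: str, prefix: str) -> bool:
    # Escape literal dots, turn "*" into a greedy wildcard, then require a full match.
    regex_str = pattern.replace(".", r"\.").replace("*", r".*")
    return re.fullmatch(regex_str, prefix) is not None


assert wildcard_match("model.layers.*.self_attn.kv_b_proj", "model.layers.3.self_attn.kv_b_proj")
assert not wildcard_match("model.layers.*.mlp", "model.layers.3.self_attn.q_proj")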
@@ -672,30 +717,13 @@ class ModelOptFp4Config(QuantizationConfig):
                     return True
         return False
 
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[QuantizeMethodBase]:
-        from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-        from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
-
-        if isinstance(layer, LinearBase):
-            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
-                prefix, self.exclude_modules
-            ):
-                return UnquantizedLinearMethod()
-            return ModelOptFp4LinearMethod(self)
-        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
-            return ModelOptFp8KVCacheMethod(self)
-        elif isinstance(layer, FlashInferFP4MoE):
-            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
-            return ModelOptNvFp4FusedMoEMethod(self)
-        elif isinstance(layer, FusedMoE):
-            return ModelOptNvFp4FusedMoEMethod(self)
-        return None
-
-    def get_scaled_act_names(self) -> List[str]:
-        return []
+    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
+        return self._get_quant_method(
+            layer,
+            prefix,
+            Linear=ModelOptFp4LinearMethod,
+            Moe=ModelOptNvFp4FusedMoEMethod,  # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
+        )
 
 
 class ModelOptFp4LinearMethod(LinearMethodBase):
@@ -852,25 +880,15 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         if enable_flashinfer_fp4_gemm:
             w = layer.weight.T
             w_scale_interleaved = layer.weight_scale_interleaved.T
-            if USE_CUTLASS_BACKEND_FOR_FP4_GEMM:
-                out = fp4_gemm(
-                    x_fp4,
-                    w,
-                    x_scale_interleaved,
-                    w_scale_interleaved,
-                    layer.alpha,
-                    output_dtype,
-                    backend="cutlass",
-                )
-            else:
-                out = fp4_gemm(
-                    x_fp4,
-                    w,
-                    x_scale_interleaved,
-                    w_scale_interleaved,
-                    layer.alpha,
-                    output_dtype,
-                )
+            out = fp4_gemm(
+                x_fp4,
+                w,
+                x_scale_interleaved,
+                w_scale_interleaved,
+                layer.alpha,
+                output_dtype,
+                **(dict(backend="cutlass") if USE_CUTLASS_BACKEND_FOR_FP4_GEMM else dict()),
+            )
             if bias is not None:
                 out = out + bias
             return out.view(*output_shape)
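The consolidated call above relies on conditional keyword forwarding: backend="cutlass" is passed only when the environment flag is set. A generic, self-contained illustration of the idiom (the gemm function here is a stand-in, not the fp4_gemm API):

def gemm(a, b, backend: str = "auto") -> str:
    return f"gemm({a}, {b}, backend={backend})"


USE_CUTLASS = True
out = gemm("x_fp4", "w", **(dict(backend="cutlass") if USE_CUTLASS else dict()))
assert out == "gemm(x_fp4, w, backend=cutlass)"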
@@ -1069,19 +1087,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         intermediate_size,
         num_experts,
     ):
-        from flashinfer import (
-            RoutingMethodType,
-            e2m1_and_ufp8sf_scale_to_float,
-            fp4_quantize,
-            next_positive_power_of_2,
-            nvfp4_block_scale_interleave,
-            reorder_rows_for_gated_act_gemm,
-            shuffle_matrix_a,
-            shuffle_matrix_sf_a,
-        )
+        from flashinfer import nvfp4_block_scale_interleave
         from flashinfer.fused_moe.core import (
-            _maybe_get_cached_w2_permute_indices,
             _maybe_get_cached_w3_w1_permute_indices,
+            get_w2_permute_indices_with_cache,
         )
 
         """Prepare quantized weights for kernel (done offline with weights)."""
@@ -1142,7 +1151,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 )
             )
 
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_weights_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1153,7 +1162,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 .contiguous()
             )
 
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_scales_linear_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1263,6 +1272,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )
 
+        layer.dispatcher.set_quant_config(
+            {"input_global_scale": layer.w13_input_scale_quant}
+        )
+
         # Validate weight scales
         for name, weight_scale in [
             ("w13", layer.w13_weight_scale),
@@ -1366,6 +1379,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         self,
         layer: FusedMoE,
         dispatch_output: StandardDispatchOutput,
+        forward_shared_experts=None,
+        alt_stream=None,
     ) -> CombineInput:
 
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
@@ -1437,9 +1452,19 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )[0]
         if should_use_flashinfer_cutlass_moe_fp4_allgather():
             output, global_output = get_local_dp_buffer(), output
+
+            if forward_shared_experts is not None:
+                alt_stream.wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(alt_stream):
+                    forward_shared_experts()
+
             get_tp_group().reduce_scatterv(
                 global_output, output=output, sizes=get_dp_global_num_tokens()
             )
+
+            if forward_shared_experts is not None:
+                torch.cuda.current_stream().wait_stream(alt_stream)
+
             return StandardCombineInput(hidden_states=output)
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
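The new forward_shared_experts/alt_stream arguments implement a stream-overlap pattern: shared-expert work runs on a side stream while the main stream performs the reduce-scatter, and the main stream waits for the side stream before returning. A minimal sketch of that pattern, with placeholder callables rather than the SGLang API:

import torch


def overlapped_combine(do_reduce_scatter, forward_shared_experts=None, alt_stream=None):
    if forward_shared_experts is not None:
        # The side stream must first see all work already queued on the main stream.
        alt_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(alt_stream):
            forward_shared_experts()

    do_reduce_scatter()  # communication stays on the main stream

    if forward_shared_experts is not None:
        # Re-join: the main stream may not proceed until the shared-expert work is done.
        torch.cuda.current_stream().wait_stream(alt_stream)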
@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     direct_register_custom_op,
     is_cuda,
@@ -41,7 +41,6 @@ from sglang.srt.utils import (
     is_triton_kernels_available,
     log_info_on_rank0,
     mxfp_supported,
-    next_power_of_2,
     round_up,
     set_weight_attrs,
 )
@@ -265,9 +264,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
-        self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
-            "flashinfer_mxfp4_moe_precision"
-        ]
+        self.flashinfer_mxfp4_moe_precision = (
+            get_global_server_args().flashinfer_mxfp4_moe_precision
+        )
 
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -597,30 +596,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
         torch.cuda.empty_cache()
 
-    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
-        # Number of tokens in the input tensor.
-        num_tokens = x.shape[0]
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # - 1.0 means perfect expert distribution.
-        # - > 1.0 means some experts have more
-        # tokens than the perfect distribution.
-        # - < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert
-        # assuming perfect distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile
-        # as it's the range supported by the kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-
-        return tile_tokens_dim
-
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
@@ -696,7 +671,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
                 None,
-                self._get_tile_tokens_dim(x, top_k),
+                None,  # tile_tokens_dim
                 1,  # routing_method_type, renormalize
                 True,  # do finalize
             )[0]
@@ -2,7 +2,7 @@
 
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import regex as re
 import torch
@@ -65,7 +65,9 @@ class QuarkConfig(QuantizationConfig):
         if should_ignore_layer(
             prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
         ):
-            return UnquantizedLinearMethod()
+            if isinstance(layer, LinearBase):
+                return UnquantizedLinearMethod()
+            return None
 
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
@@ -3,16 +3,16 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import TYPE_CHECKING, Any
 
 import torch
-from aiter import ActivationType, QuantType, biased_grouped_topk
+from aiter import ActivationType, QuantType
 from aiter.fused_moe import fused_moe
 from aiter.utility.fp4_utils import e8m0_shuffle
 
 from sglang.srt.layers.moe import MoeRunnerConfig
 from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-from sglang.srt.utils import get_bool_env_var, is_hip, mxfp_supported, set_weight_attrs
+from sglang.srt.utils import is_hip, set_weight_attrs
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
@@ -2,20 +2,13 @@
 
 from typing import Any, Callable, Optional
 
-import aiter
 import torch
-import torch.nn.functional as F
-from aiter.ops.gemm_op_a4w4 import gemm_a4w4
-from aiter.ops.shuffle import shuffle_weight
 from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
 from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
 from aiter.ops.triton.quant import dynamic_mxfp4_quant
-from aiter.utility import dtypes
-from aiter.utility.fp4_utils import e8m0_shuffle
 
 from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
 from sglang.srt.layers.quantization.quark.schemes import QuarkScheme
-from sglang.srt.utils import get_bool_env_var
 
 __all__ = ["QuarkW4A4MXFP4"]
 
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import importlib.util
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -31,8 +30,6 @@ if TYPE_CHECKING:
         StandardDispatchOutput,
     )
 
-has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
-
 
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_hip = is_hip()
@@ -143,7 +140,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         self.triton_kernel_moe_forward = None
        self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and has_triton_kernels:
+        if torch.cuda.is_available() and use_triton_kernels:
             from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                 triton_kernel_moe_forward as _tk_forward,
             )
@@ -11,7 +11,6 @@ import numpy
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
 
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig