sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/quantization/fp8_utils.py
+++ b/sglang/srt/layers/quantization/fp8_utils.py
@@ -45,7 +45,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale, get_hip_quant
+    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
 
     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
 
@@ -248,11 +248,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
    )
 
-    # NOTE(alcanderian): Useless when scale is packed to int32
-    # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
-    #     _check_ue8m0("x_scale", x_scale)
-    #     _check_ue8m0("weight_scale", ws)
-
    output = w8a8_block_fp8_matmul_deepgemm(
        q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
    )
@@ -261,11 +256,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
    return output.to(dtype=output_dtype).view(*output_shape)
 
 
-def _check_ue8m0(name, x):
-    x_ceil = ceil_to_ue8m0(x)
-    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
-
-
 def aiter_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
@@ -652,25 +642,49 @@ def apply_fp8_linear(
            use_per_token_if_dynamic
            and not per_tensor_weights
            and not per_tensor_activations
-            and USE_ROWWISE_TORCH_SCALED_MM
+            and (USE_ROWWISE_TORCH_SCALED_MM or _use_aiter)
        ):
-            # For now validated on ROCm platform
-            # fp8 rowwise scaling in torch._scaled_mm is introduced in
-            # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
-            # and ROCm 6.3, which only exists in torch 2.7 and above.
-            # For CUDA platform please validate if the
-            # torch._scaled_mm support rowwise scaled GEMM
-            # Fused GEMM_DQ Rowwise GEMM
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale.t(),
-                bias=bias,
-            )
-            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
-
+            # Reaching this branch means dynamic per-token/per-channel quantization:
+            # per-token scales for the input matrix (one scale factor per row, i.e. per token),
+            # per-channel scales for the weight matrix (one scale factor per column, i.e. per channel).
+            if _use_aiter:
+                # gemm_a8w8_bpreshuffle(XQ, WQ, x_scale, w_scale, dtype)
+                # XQ -> input tensor, shape = (m, k)
+                # WQ -> weight tensor, shape = (n, k); the preshuffled layout gives better perf
+                # x_scale -> input scale tensor, shape = (m, 1)
+                # w_scale -> weight scale tensor, shape = (n, 1)
+                # dtype -> output dtype
+                output = gemm_a8w8_bpreshuffle(
+                    XQ=qinput,
+                    WQ=weight,
+                    x_scale=x_scale,
+                    w_scale=weight_scale,
+                    dtype=input.dtype,
+                )
+                if bias is not None:
+                    output += bias
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, [*input.shape[:-1], weight.shape[0]]
+                )
+            else:
+                # For now validated on the ROCm platform.
+                # fp8 rowwise scaling in torch._scaled_mm was introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, which only exists in torch 2.7 and above.
+                # On CUDA, please validate that torch._scaled_mm supports
+                # rowwise scaled GEMM.
+                # Fused GEMM_DQ rowwise GEMM
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale.t(),
+                    bias=bias,
+                )
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, output_shape
+                )
        else:
            # Fallback for channelwise case, where we use unfused DQ
            # due to limitations with scaled_mm
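Note: the comments in the new aiter branch describe a dynamic per-token/per-channel scheme: one scale per input row (token) and one scale per weight output channel. Below is a minimal, self-contained sketch of that scaling convention for reference only; it is not sglang's kernel path, and it assumes a torch build (>= 2.1) with float8 support.

import torch

# x is (m, k), weight is (n, k), x_scale is (m, 1), w_scale is (n, 1),
# matching the shapes listed in the gemm_a8w8_bpreshuffle comments above.
def quant_per_token_per_channel(x: torch.Tensor, w: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / finfo.max  # (m, 1)
    w_scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / finfo.max  # (n, 1)
    xq = (x / x_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    wq = (w / w_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return xq, wq, x_scale, w_scale

def reference_gemm(xq, wq, x_scale, w_scale):
    # Dequantize-and-matmul reference; a fused kernel applies the same scales.
    return (xq.float() @ wq.float().t()) * x_scale * w_scale.t()  # (m, n)

x, w = torch.randn(4, 64), torch.randn(8, 64)
xq, wq, xs, ws = quant_per_token_per_channel(x, w)
print((reference_gemm(xq, wq, xs, ws) - x @ w.t()).abs().max())  # small quantization error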
--- a/sglang/srt/layers/quantization/gptq.py
+++ b/sglang/srt/layers/quantization/gptq.py
@@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import (
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )
 
 from sglang.srt.utils import is_cuda
 
@@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
        from sglang.srt.layers.linear import set_weight_attrs
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
-        intermediate_size = extra_weight_attrs.pop("intermediate_size")
-
-        self.is_k_full = (not self.quant_config.desc_act) or (
-            intermediate_size_per_partition == intermediate_size
-        )
+        self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1
 
        if self.quant_config.group_size != -1:
            scales_size13 = hidden_size // self.quant_config.group_size
-            w2_scales_size = (
-                intermediate_size
-                if self.quant_config.desc_act
-                else intermediate_size_per_partition
-            )
+            if self.quant_config.desc_act:
+                w2_scales_size = intermediate_size_per_partition
+            else:
+                w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
            scales_size2 = w2_scales_size // self.quant_config.group_size
            strategy = FusedMoeWeightScaleSupported.GROUP.value
        else:
@@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
    def apply(
        self,
        layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
        # Delay the import to avoid circular dependency
 
        assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."
 
        # The input must currently be float16
@@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
 
        topk_weights, topk_ids, router_logits = topk_output
 
-        return fused_marlin_moe(
+        output = fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
@@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            num_bits=self.quant_config.weight_bits,
            is_k_full=self.is_k_full,
        ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
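The GPTQ Marlin change follows the interface refactor that recurs throughout this release: quantization methods now receive a dispatch output and return a combine input, and the runner config is captured once in create_moe_runner() instead of being passed to every apply() call. A rough sketch of the new calling convention, with the sglang types stubbed out so the snippet stands alone (anything beyond the attribute names visible in the diff is an assumption):

from dataclasses import dataclass
import torch

@dataclass
class StandardDispatchOutput:  # stand-in for the sglang dispatcher output
    hidden_states: torch.Tensor
    topk_output: tuple  # (topk_weights, topk_ids, router_logits)

@dataclass
class StandardCombineInput:  # stand-in for the sglang combine input
    hidden_states: torch.Tensor

@dataclass
class MoeRunnerConfig:  # stand-in; the real config carries more fields
    activation: str = "silu"

class ToyMoEMethod:
    # create_moe_runner() is called once per layer, so apply() no longer
    # receives moe_runner_config explicitly.
    def create_moe_runner(self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig):
        self.moe_runner_config = moe_runner_config

    def apply(self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput) -> StandardCombineInput:
        assert self.moe_runner_config.activation == "silu"
        x = dispatch_output.hidden_states
        # ... run the quantized expert GEMMs here ...
        return StandardCombineInput(hidden_states=x)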
--- a/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/sglang/srt/layers/quantization/modelopt_quant.py
@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -39,8 +43,10 @@ from sglang.srt.utils import is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -322,7 +328,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
@@ -338,7 +344,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
        w13_weight = ModelWeightParameter(
            data=torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
@@ -348,7 +357,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
        w2_weight = ModelWeightParameter(
            data=torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
@@ -414,28 +426,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
 
        # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
-        # where the first intermediate_size rows are w1, the next are w3
-        intermediate_size = layer.w13_weight.shape[1] // 2
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
        for expert_id in range(layer.w13_weight.shape[0]):
            start = 0
            for shard_id in range(2):  # w1 and w3
                # Dequantize using the original scale for this shard
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                    ],
                    layer.w13_weight_scale[expert_id][shard_id],
                )
                # Requantize using the combined max scale
                (
                    layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                    ],
                    _,
                ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
 
-                start += intermediate_size
+                start += intermediate_size_per_partition
 
        # Update the scale parameter to be per-expert instead of per-shard
        layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
@@ -457,29 +469,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
            layer.w2_input_scale.max(), requires_grad=False
        )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
    def apply(
        self,
        layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
            use_fp8_w8a8=True,
-            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
-            w1_scale=layer.w13_weight_scale,
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
+            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )
 
+        return self.runner.run(dispatch_output, quant_info)
+
 
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
@@ -517,6 +531,39 @@ class ModelOptFp4Config(QuantizationConfig):
    def get_config_filenames(cls) -> List[str]:
        return ["hf_quant_config.json"]
 
+    @staticmethod
+    def common_group_size(cfg: dict) -> int:
+        """Return the unique group_size across the config; raise if missing/mismatched."""
+        sizes = set()
+
+        # Top-level and 'quantization' block
+        v = cfg.get("group_size")
+        if isinstance(v, int):
+            sizes.add(v)
+        q = cfg.get("quantization")
+        if isinstance(q, dict):
+            v = q.get("group_size")
+            if isinstance(v, int):
+                sizes.add(v)
+
+        # config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
+        for g in (cfg.get("config_groups") or {}).values():
+            if isinstance(g, dict):
+                v = g.get("group_size")
+                if isinstance(v, int):
+                    sizes.add(v)
+                for sub in g.values():
+                    if isinstance(sub, dict):
+                        v = sub.get("group_size")
+                        if isinstance(v, int):
+                            sizes.add(v)
+
+        if not sizes:
+            raise ValueError("No group_size found in config.")
+        if len(sizes) > 1:
+            raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
+        return next(iter(sizes))
+
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
        # Handle two different config formats:
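For reference, two hypothetical config shapes that the new common_group_size() helper accepts (key names other than group_size, quantization, and config_groups are illustrative); both resolve to 16, and mixing different values raises ValueError:

flat_cfg = {"group_size": 16}
nested_cfg = {
    "config_groups": {
        "group_0": {
            "weights": {"num_bits": 4, "group_size": 16},
            "input_activations": {"num_bits": 4, "group_size": 16},
        }
    }
}
# ModelOptFp4Config.common_group_size(flat_cfg)    -> 16
# ModelOptFp4Config.common_group_size(nested_cfg)  -> 16
# A config mixing 16 and 32 raises ValueError("Inconsistent group_size values: [16, 32]")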
@@ -549,7 +596,7 @@ class ModelOptFp4Config(QuantizationConfig):
            else:
                kv_cache_quant_algo = "auto"
 
-            group_size = config.get("group_size")
+            group_size = ModelOptFp4Config.common_group_size(config)
            exclude_modules = config.get("ignore", [])
        else:
            # Fall back to nested format (hf_quant_config.json - legacy format)
@@ -559,7 +606,7 @@ class ModelOptFp4Config(QuantizationConfig):
                kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
                if not kv_cache_quant_algo:
                    kv_cache_quant_algo = "auto"
-                group_size = quant_config.get("group_size")
+                group_size = ModelOptFp4Config.common_group_size(config)
                exclude_modules = quant_config.get("exclude_modules", [])
            except (ValueError, KeyError):
                raise ValueError(
@@ -595,10 +642,22 @@ class ModelOptFp4Config(QuantizationConfig):
    def is_layer_excluded(self, prefix: str, exclude_modules: list):
        import regex as re
 
+        fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
+        prefix_split = prefix.split(".")
        for pattern in exclude_modules:
            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            pattern_split = pattern.split(".")
            if re.fullmatch(regex_str, prefix):
                return True
+            elif (
+                pattern_split[-1] in fused_patterns
+                and pattern_split[-1] in prefix_split[-1]
+            ):
+                # Check if the last part of the excluded pattern is contained in the
+                # last part of the prefix. This handles fused modules like
+                # fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa,
+                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
+                assert len(prefix_split) == 5 and len(pattern_split) == 5
+                return True
        return False
 
    def get_quant_method(
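To see why the fused-pattern branch is needed: a checkpoint's exclusion list may name kv_a_proj_with_mqa, while the loaded module is the fused fused_qkv_a_proj_with_mqa, which the wildcard regex alone no longer matches. A standalone, simplified rendering of the check (stdlib re instead of the regex package, length assert dropped; module names here are only an example):

import re

fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]

def is_layer_excluded(prefix: str, exclude_modules: list) -> bool:
    prefix_split = prefix.split(".")
    for pattern in exclude_modules:
        regex_str = pattern.replace(".", r"\.").replace("*", r".*")
        pattern_split = pattern.split(".")
        if re.fullmatch(regex_str, prefix):
            return True
        # A pattern ending in kv_a_proj_with_mqa should also exclude the fused
        # module fused_qkv_a_proj_with_mqa that contains it.
        if pattern_split[-1] in fused_patterns and pattern_split[-1] in prefix_split[-1]:
            return True
    return False

print(is_layer_excluded(
    "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa",
    ["model.layers.*.self_attn.kv_a_proj_with_mqa"],
))  # True: the fused module inherits the exclusion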
@@ -1203,8 +1262,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                layer.w13_weight_scale,
            )
 
-            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
-
        else:
            # CUTLASS processing - handle w13 and w2 separately
 
@@ -1221,7 +1278,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
            # Both flashinfer cutlass and regular cutlass use same processing for w2
-            logger.info_once("Applied weight processing for both w13 and w2")
 
        # Set up CUTLASS MoE parameters
        device = layer.w13_weight.device
@@ -1238,21 +1294,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
        # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
        return self.enable_flashinfer_cutlass_moe
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
    def apply(
        self,
        layer: FusedMoE,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
        assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."
 
+        moe_runner_config = self.moe_runner_config
+
        # Check if this is a FlashInferFP4MoE layer that should handle its own forward
        if hasattr(layer, "gemm1_weights_fp4_shuffled"):
            # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))
 
        if self.enable_flashinfer_cutlass_moe:
            assert (
@@ -1305,13 +1372,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                tp_rank=layer.moe_tp_rank,
                tune_max_num_tokens=next_power_of_2(x.shape[0]),
            )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
            if should_use_flashinfer_cutlass_moe_fp4_allgather():
                output, global_output = get_local_dp_buffer(), output
                get_tp_group().reduce_scatterv(
                    global_output, output=output, sizes=get_dp_global_num_tokens()
                )
-            return output
+            return StandardCombineInput(hidden_states=output)
 
        from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
@@ -1332,4 +1398,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
            apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
        ).to(x.dtype)
        # Scale by routed_scaling_factor is fused into select_experts.
-        return output
+
+        return StandardCombineInput(hidden_states=output)
--- a/sglang/srt/layers/quantization/moe_wna16.py
+++ b/sglang/srt/layers/quantization/moe_wna16.py
@@ -9,6 +9,8 @@ import torch
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 
 def get_weight_perm(num_bits: int):
@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
            layer.register_parameter(key, param)
            set_weight_attrs(param, extra_weight_attrs)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
    def apply(
        self,
        layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
        assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."
 
        weight_bits = self.quant_config.weight_bits
        has_zp = self.quant_config.has_zp
 
-        return fused_experts(
-            x,
-            layer.w13_qweight,
-            layer.w2_qweight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_qweight,
+            w2_weight=layer.w2_qweight,
            use_int4_w4a16=weight_bits == 4,
            use_int8_w8a16=weight_bits == 8,
-            w1_scale=layer.w13_scales,
+            w13_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
-            w1_zp=layer.w13_qzeros if has_zp else None,
+            w13_zp=layer.w13_qzeros if has_zp else None,
            w2_zp=layer.w2_qzeros if has_zp else None,
            block_shape=[0, layer.group_size],
        )
+        return self.runner.run(dispatch_output, quant_info)
 
    @staticmethod
    def get_weight_loader(layer, weight_loader):
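The ModelOptFp8 and WNA16 hunks share one more pattern: the direct fused_experts(...) call is replaced by a MoeRunner built once in create_moe_runner() plus a TritonMoeQuantInfo describing the layer's quantized weights, with several keywords renamed (w1_scale -> w13_scale, a1_scale -> a13_scale, w1_zp -> w13_zp). A simplified stand-in for that flow; the stub classes below are not the sglang implementations and carry only the fields visible in this diff:

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class TritonMoeQuantInfoStub:  # stand-in; the real class has more fields
    w13_weight: torch.Tensor
    w2_weight: torch.Tensor
    w13_scale: Optional[torch.Tensor] = None  # was w1_scale in fused_experts
    a13_scale: Optional[torch.Tensor] = None  # was a1_scale in fused_experts
    use_fp8_w8a8: bool = False

class TritonMoeRunnerStub:  # stand-in for MoeRunner(MoeRunnerBackend.TRITON, cfg)
    def run(self, dispatch_output, quant_info: TritonMoeQuantInfoStub):
        # The real runner launches the fused Triton MoE kernels using quant_info.
        return dispatch_output

# A quant method builds the runner once (create_moe_runner) and then, per forward
# pass, packs its layer weights into quant_info and calls runner.run(...).
runner = TritonMoeRunnerStub()
quant_info = TritonMoeQuantInfoStub(
    w13_weight=torch.empty(8, 2 * 128, 64),  # (num_experts, 2 * intermediate, hidden)
    w2_weight=torch.empty(8, 64, 128),
    use_fp8_w8a8=True,
)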