sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
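
The hunks below are from sglang/srt/layers/quantization/modelopt_quant.py (entry 126 in the list above); the final three hunks (from `@@ -9,6 +9,8 @@ import torch` onward) are from sglang/srt/layers/quantization/moe_wna16.py (entry 127), whose MoeWNA16Method receives the same MoeRunner migration.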
@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -39,8 +43,10 @@ from sglang.srt.utils import is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -322,7 +328,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -338,7 +344,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -348,7 +357,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -414,28 +426,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values
 
         # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
-        # where the first intermediate_size rows are w1, the next are w3
-        intermediate_size = layer.w13_weight.shape[1] // 2
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
         for expert_id in range(layer.w13_weight.shape[0]):
             start = 0
             for shard_id in range(2):  # w1 and w3
                 # Dequantize using the original scale for this shard
                 dq_weight = per_tensor_dequantize(
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
                 # Requantize using the combined max scale
                 (
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     _,
                 ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
 
-                start += intermediate_size
+                start += intermediate_size_per_partition
 
         # Update the scale parameter to be per-expert instead of per-shard
         layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
@@ -457,29 +469,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
-            w1_scale=layer.w13_weight_scale,
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
 
+        return self.runner.run(dispatch_output, quant_info)
+
 
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
@@ -517,6 +531,39 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return ["hf_quant_config.json"]
 
+    @staticmethod
+    def common_group_size(cfg: dict) -> int:
+        """Return the unique group_size across the config; raise if missing/mismatched."""
+        sizes = set()
+
+        # Top-level and 'quantization' block
+        v = cfg.get("group_size")
+        if isinstance(v, int):
+            sizes.add(v)
+        q = cfg.get("quantization")
+        if isinstance(q, dict):
+            v = q.get("group_size")
+            if isinstance(v, int):
+                sizes.add(v)
+
+        # config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
+        for g in (cfg.get("config_groups") or {}).values():
+            if isinstance(g, dict):
+                v = g.get("group_size")
+                if isinstance(v, int):
+                    sizes.add(v)
+                for sub in g.values():
+                    if isinstance(sub, dict):
+                        v = sub.get("group_size")
+                        if isinstance(v, int):
+                            sizes.add(v)
+
+        if not sizes:
+            raise ValueError("No group_size found in config.")
+        if len(sizes) > 1:
+            raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
+        return next(iter(sizes))
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
         # Handle two different config formats:
@@ -549,7 +596,7 @@ class ModelOptFp4Config(QuantizationConfig):
             else:
                 kv_cache_quant_algo = "auto"
 
-            group_size = config.get("group_size")
+            group_size = ModelOptFp4Config.common_group_size(config)
             exclude_modules = config.get("ignore", [])
         else:
             # Fall back to nested format (hf_quant_config.json - legacy format)
@@ -559,7 +606,7 @@ class ModelOptFp4Config(QuantizationConfig):
                 kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
                 if not kv_cache_quant_algo:
                     kv_cache_quant_algo = "auto"
-                group_size = quant_config.get("group_size")
+                group_size = ModelOptFp4Config.common_group_size(config)
                 exclude_modules = quant_config.get("exclude_modules", [])
             except (ValueError, KeyError):
                 raise ValueError(
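
Both config formats now resolve group_size through common_group_size, which scans the top level, the quantization block, and every config_groups entry (including nested weights/input_activations dicts) and insists all discovered values agree. Illustrative inputs (the dict shapes are hypothetical examples of what the scanner accepts):

flat_cfg = {"quant_algo": "NVFP4", "group_size": 16}
nested_cfg = {"quantization": {"quant_algo": "NVFP4", "group_size": 16}}
grouped_cfg = {
    "config_groups": {
        "group_0": {
            "weights": {"group_size": 16},
            "input_activations": {"group_size": 16},
        }
    }
}

for cfg in (flat_cfg, nested_cfg, grouped_cfg):
    assert ModelOptFp4Config.common_group_size(cfg) == 16

# Conflicting values raise rather than silently picking one:
try:
    ModelOptFp4Config.common_group_size(
        {"group_size": 16, "quantization": {"group_size": 32}}
    )
except ValueError as e:
    print(e)  # Inconsistent group_size values: [16, 32]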
@@ -595,10 +642,22 @@ class ModelOptFp4Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str, exclude_modules: list):
         import regex as re
 
+        fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
+        prefix_split = prefix.split(".")
         for pattern in exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
                 return True
+            elif (
+                pattern_split[-1] in fused_patterns
+                and pattern_split[-1] in prefix_split[-1]
+            ):
+                # Check if the last part of the excluded pattern is contained in the last part of the prefix
+                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
+                assert len(prefix_split) == 5 and len(pattern_split) == 5
+                return True
         return False
 
     def get_quant_method(
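
The new branch handles checkpoints where attention projections are fused into a single weight: a plain regex fullmatch on, say, kv_a_proj_with_mqa can never match the fused module name, so the check falls back to a substring test on the last path component. A standalone reproduction of the logic, with hypothetical example values:

fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]

pattern = "model.layers.*.self_attn.kv_a_proj_with_mqa"  # from exclude_modules
prefix = "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa"

pattern_split = pattern.split(".")
prefix_split = prefix.split(".")

# The regex path fails: the fused weight name is not "kv_a_proj_with_mqa".
# The fallback succeeds: "kv_a_proj_with_mqa" is a substring of
# "fused_qkv_a_proj_with_mqa", and both paths have the expected five parts.
assert pattern_split[-1] in fused_patterns
assert pattern_split[-1] in prefix_split[-1]
assert len(prefix_split) == 5 and len(pattern_split) == 5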
@@ -1203,8 +1262,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 layer.w13_weight_scale,
             )
 
-            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
-
         else:
             # CUTLASS processing - handle w13 and w2 separately
 
@@ -1221,7 +1278,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
             # Both flashinfer cutlass and regular cutlass use same processing for w2
-            logger.info_once("Applied weight processing for both w13 and w2")
 
         # Set up CUTLASS MoE parameters
         device = layer.w13_weight.device
@@ -1238,21 +1294,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
         return self.enable_flashinfer_cutlass_moe
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
+        moe_runner_config = self.moe_runner_config
+
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))
 
         if self.enable_flashinfer_cutlass_moe:
             assert (
@@ -1305,13 +1372,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
@@ -1332,4 +1398,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-        return output
+
+        return StandardCombineInput(hidden_states=output)
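
Unlike the FP8 method, ModelOptNvFp4FusedMoEMethod.create_moe_runner stores only the config, since its apply dispatches to FlashInfer/CUTLASS kernels itself rather than to a Triton MoeRunner; what the hunks above unify is the return type, with every path now wrapping its tensor in StandardCombineInput. A caller-side sketch (the helper is hypothetical, and it assumes the hidden_states constructor kwarg is also exposed as an attribute):

import torch

from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput


def run_experts(method, layer, dispatch_output) -> torch.Tensor:
    # Every apply() above now returns a CombineInput, so the expert
    # output tensor is recovered via .hidden_states.
    combine_input = method.apply(layer, dispatch_output)
    assert isinstance(combine_input, StandardCombineInput)
    return combine_input.hidden_states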
@@ -9,6 +9,8 @@ import torch
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 
 def get_weight_perm(num_bits: int):
@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.register_parameter(key, param)
             set_weight_attrs(param, extra_weight_attrs)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp
 
-        return fused_experts(
-            x,
-            layer.w13_qweight,
-            layer.w2_qweight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_qweight,
+            w2_weight=layer.w2_qweight,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
-            w1_scale=layer.w13_scales,
+            w13_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
-            w1_zp=layer.w13_qzeros if has_zp else None,
+            w13_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
         )
+        return self.runner.run(dispatch_output, quant_info)
 
     @staticmethod
     def get_weight_loader(layer, weight_loader):