sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/gptq.py
@@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import (
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )
 
 from sglang.srt.utils import is_cuda
 
@@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.linear import set_weight_attrs
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
-        intermediate_size = extra_weight_attrs.pop("intermediate_size")
-
-        self.is_k_full = (not self.quant_config.desc_act) or (
-            intermediate_size_per_partition == intermediate_size
-        )
+        self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1
 
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            w2_scales_size = (
-                intermediate_size
-                if self.quant_config.desc_act
-                else intermediate_size_per_partition
-            )
+            if self.quant_config.desc_act:
+                w2_scales_size = intermediate_size_per_partition
+            else:
+                w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
             scales_size2 = w2_scales_size // self.quant_config.group_size
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
@@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_scales", marlin_w2_scales)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         # Delay the import to avoid circular dependency
 
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         # The input must currently be float16
@@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
 
         topk_weights, topk_ids, router_logits = topk_output
 
-        return fused_marlin_moe(
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             num_bits=self.quant_config.weight_bits,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
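Note: the gptq.py hunks above are one instance of the 0.5.3 MoE quant-method migration that repeats through the rest of this diff: apply() no longer receives (x, topk_output, moe_runner_config) but a single StandardDispatchOutput, returns a CombineInput, and the runner config is bound once via the new create_moe_runner() hook. A minimal sketch of the new calling convention, using stand-in dataclasses rather than the real sglang types (only the field names hidden_states and topk_output are taken from the hunks):

from dataclasses import dataclass
from types import SimpleNamespace

import torch


@dataclass
class DispatchOutput:  # illustrative stand-in for StandardDispatchOutput
    hidden_states: torch.Tensor
    topk_output: tuple  # (topk_weights, topk_ids, router_logits)


@dataclass
class CombineInput:  # illustrative stand-in for StandardCombineInput
    hidden_states: torch.Tensor


class SketchMoEMethod:
    def create_moe_runner(self, layer, moe_runner_config) -> None:
        # New hook: the config is bound once instead of threaded through apply().
        self.moe_runner_config = moe_runner_config

    def apply(self, layer, dispatch_output: DispatchOutput) -> CombineInput:
        x = dispatch_output.hidden_states
        topk_weights, topk_ids, router_logits = dispatch_output.topk_output
        assert self.moe_runner_config.activation == "silu"
        out = x  # a real method would run the fused expert kernels here
        return CombineInput(hidden_states=out)


method = SketchMoEMethod()
method.create_moe_runner(None, SimpleNamespace(activation="silu"))
result = method.apply(
    None,
    DispatchOutput(
        hidden_states=torch.randn(4, 8),
        topk_output=(torch.ones(4, 2), torch.zeros(4, 2, dtype=torch.long), None),
    ),
)
assert isinstance(result, CombineInput)

The same create_moe_runner/apply reshaping repeats in the modelopt_quant.py and moe_wna16.py hunks below.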
sglang/srt/layers/quantization/modelopt_quant.py
@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -35,12 +39,14 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -68,6 +74,10 @@ except ImportError:
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
 
+CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
+)
+
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
 
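Note: CUTEDSL_MOE_SCALAR_INPUT_SCALE is evaluated once at module import. To A/B the per-expert input-scale path it guards (see the process_weights_after_loading hunk further down), set the variable before sglang is imported; a sketch, assuming get_bool_env_var treats "false" as off:

import os

# Must happen before modelopt_quant is imported, since the flag is read
# at import time (see the assignment in the hunk above).
os.environ["SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE"] = "false"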
@@ -322,7 +332,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -338,7 +348,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -348,7 +361,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         )
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -414,28 +430,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values
 
         # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
-        # where the first intermediate_size rows are w1, the next are w3
-        intermediate_size = layer.w13_weight.shape[1] // 2
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
         for expert_id in range(layer.w13_weight.shape[0]):
             start = 0
             for shard_id in range(2):  # w1 and w3
                 # Dequantize using the original scale for this shard
                 dq_weight = per_tensor_dequantize(
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
                 # Requantize using the combined max scale
                 (
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     _,
                 ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
 
-                start += intermediate_size
+                start += intermediate_size_per_partition
 
         # Update the scale parameter to be per-expert instead of per-shard
         layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
@@ -457,29 +473,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
-            w1_scale=layer.w13_weight_scale,
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
 
+        return self.runner.run(dispatch_output, quant_info)
+
 
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
@@ -517,6 +535,39 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return ["hf_quant_config.json"]
 
+    @staticmethod
+    def common_group_size(cfg: dict) -> int:
+        """Return the unique group_size across the config; raise if missing/mismatched."""
+        sizes = set()
+
+        # Top-level and 'quantization' block
+        v = cfg.get("group_size")
+        if isinstance(v, int):
+            sizes.add(v)
+        q = cfg.get("quantization")
+        if isinstance(q, dict):
+            v = q.get("group_size")
+            if isinstance(v, int):
+                sizes.add(v)
+
+        # config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
+        for g in (cfg.get("config_groups") or {}).values():
+            if isinstance(g, dict):
+                v = g.get("group_size")
+                if isinstance(v, int):
+                    sizes.add(v)
+                for sub in g.values():
+                    if isinstance(sub, dict):
+                        v = sub.get("group_size")
+                        if isinstance(v, int):
+                            sizes.add(v)
+
+        if not sizes:
+            raise ValueError("No group_size found in config.")
+        if len(sizes) > 1:
+            raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
+        return next(iter(sizes))
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
         # Handle two different config formats:
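Note: common_group_size accepts both the flat config.json layout and the nested config_groups layout, and fails loudly on disagreement. A quick usage sketch with toy configs (assuming the wheel and its import-time dependencies are installed):

from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config

flat = {"quantization": {"group_size": 16}}
nested = {"config_groups": {"group_0": {"weights": {"group_size": 16}}}}

assert ModelOptFp4Config.common_group_size(flat) == 16
assert ModelOptFp4Config.common_group_size(nested) == 16

# Mismatched values raise instead of silently picking one.
try:
    ModelOptFp4Config.common_group_size(
        {"group_size": 16, "quantization": {"group_size": 32}}
    )
except ValueError as e:
    print(e)  # Inconsistent group_size values: [16, 32]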
@@ -549,7 +600,7 @@ class ModelOptFp4Config(QuantizationConfig):
             else:
                 kv_cache_quant_algo = "auto"
 
-            group_size = config.get("group_size")
+            group_size = ModelOptFp4Config.common_group_size(config)
             exclude_modules = config.get("ignore", [])
         else:
             # Fall back to nested format (hf_quant_config.json - legacy format)
@@ -559,7 +610,7 @@ class ModelOptFp4Config(QuantizationConfig):
                 kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
                 if not kv_cache_quant_algo:
                     kv_cache_quant_algo = "auto"
-                group_size = quant_config.get("group_size")
+                group_size = ModelOptFp4Config.common_group_size(config)
                 exclude_modules = quant_config.get("exclude_modules", [])
             except (ValueError, KeyError):
                 raise ValueError(
@@ -595,16 +646,21 @@ class ModelOptFp4Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str, exclude_modules: list):
         import regex as re
 
+        fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
+        prefix_split = prefix.split(".")
         for pattern in exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
                 return True
-
-            # Check if the last part of the excluded pattern is contained in the last part of the prefix
-            # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
-            pattern_last_part = pattern.split(".")[-1]
-            prefix_last_part = prefix.split(".")[-1]
-            if pattern_last_part in prefix_last_part:
+            elif (
+                pattern_split[-1] in fused_patterns
+                and pattern_split[-1] in prefix_split[-1]
+            ):
+                # Check if the last part of the excluded pattern is contained in the last part of the prefix
+                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
+                assert len(prefix_split) == 5 and len(pattern_split) == 5
                 return True
         return False
 
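Note: the exclusion logic first tries the glob pattern as an anchored regex, then falls back to a substring test on the last path component so fused modules stay excluded. The behavior reduces to the following (stdlib re shown for brevity; the module itself uses the regex package, which is compatible for these patterns):

import re


def glob_matches(pattern: str, prefix: str) -> bool:
    regex_str = pattern.replace(".", r"\.").replace("*", r".*")
    return re.fullmatch(regex_str, prefix) is not None


# Direct wildcard match.
assert glob_matches(
    "model.layers.*.self_attn.kv_a_proj_with_mqa",
    "model.layers.3.self_attn.kv_a_proj_with_mqa",
)

# The fused case the new elif handles: the anchored regex does NOT match the
# fused weight name, so the substring check on the last component is what
# keeps fused_qkv_a_proj_with_mqa excluded.
assert not glob_matches(
    "model.layers.*.self_attn.kv_a_proj_with_mqa",
    "model.layers.3.self_attn.fused_qkv_a_proj_with_mqa",
)
assert "kv_a_proj_with_mqa" in "fused_qkv_a_proj_with_mqa"

This is also why the new code restricts the fallback to an explicit fused_patterns allow-list and asserts the five-component prefix layout, rather than substring-matching every pattern.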
@@ -826,6 +882,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         """Access the global enable_flashinfer_cutlass_moe setting."""
         return get_moe_runner_backend().is_flashinfer_cutlass()
 
+    @property
+    def enable_flashinfer_cutedsl_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
+        """Access the global enable_flashinfer_cutedsl_moe setting."""
+        return get_moe_runner_backend().is_flashinfer_cutedsl()
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -937,15 +1000,17 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
+            data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w13_input_scale._sglang_require_global_experts = True
        layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
+            data=torch.empty(layer.num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w2_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
     def swizzle_blockscale(self, scale: torch.Tensor):
@@ -1128,6 +1193,33 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        elif self.enable_flashinfer_cutedsl_moe:
+            # All-expert-one-input-scale is mathematically different from default per-expert-input-scale
+            # Thus we allow users to switch the flag to do thorough testing
+            if CUTEDSL_MOE_SCALAR_INPUT_SCALE:
+                w13_input_scale = (
+                    layer.w13_input_scale.max()
+                    .to(torch.float32)
+                    .repeat(layer.w13_input_scale.shape[0])
+                )
+            else:
+                w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
+                    torch.float32
+                )
+
+            w2_input_scale = layer.w2_input_scale
+
+            def _slice_scale(w):
+                assert w.shape == (layer.num_experts,)
+                assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts
+                return w[
+                    layer.moe_ep_rank
+                    * layer.num_local_experts : (layer.moe_ep_rank + 1)
+                    * layer.num_local_experts
+                ]
+
+            w13_input_scale = _slice_scale(w13_input_scale)
+            w2_input_scale = _slice_scale(w2_input_scale)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
             w2_input_scale = layer.w2_input_scale
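Note: with the input scales now allocated for all num_experts (rather than num_local_experts), each expert-parallel rank takes a contiguous block after the reduction; _slice_scale is just the arithmetic below:

import torch

num_experts, ep_size = 8, 4           # toy sizes
num_local = num_experts // ep_size    # cf. layer.num_local_experts

scales = torch.arange(num_experts, dtype=torch.float32)  # one scale per global expert
for ep_rank in range(ep_size):
    local = scales[ep_rank * num_local : (ep_rank + 1) * num_local]
    print(ep_rank, local.tolist())
# 0 [0.0, 1.0]
# 1 [2.0, 3.0]
# 2 [4.0, 5.0]
# 3 [6.0, 7.0]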
@@ -1210,8 +1302,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 layer.w13_weight_scale,
             )
 
-            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
-
         else:
             # CUTLASS processing - handle w13 and w2 separately
 
@@ -1228,7 +1318,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
         # Both flashinfer cutlass and regular cutlass use same processing for w2
-        logger.info_once("Applied weight processing for both w13 and w2")
 
         # Set up CUTLASS MoE parameters
         device = layer.w13_weight.device
@@ -1245,21 +1334,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
         return self.enable_flashinfer_cutlass_moe
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
+        moe_runner_config = self.moe_runner_config
+
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))
 
         if self.enable_flashinfer_cutlass_moe:
             assert (
@@ -1312,13 +1412,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
@@ -1339,4 +1438,38 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-        return output
+        return StandardCombineInput(hidden_states=output)
+
+    def apply_without_routing_weights(
+        self,
+        layer: FusedMoE,
+        x: torch.Tensor,
+        masked_m: torch.Tensor,
+        moe_runner_config: MoeRunnerConfig,
+    ) -> torch.Tensor:
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        assert self.enable_flashinfer_cutedsl_moe, "only support flashinfer cutedsl moe"
+        assert (
+            not moe_runner_config.apply_router_weight_on_input
+        ), "apply_router_weight_on_input is not supported for Flashinfer"
+
+        from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
+            flashinfer_cutedsl_moe_masked,
+        )
+
+        out = flashinfer_cutedsl_moe_masked(
+            hidden_states=x,
+            input_global_scale=layer.w13_input_scale_quant,
+            w1=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alpha=layer.g1_alphas,
+            w2=layer.w2_weight,
+            a2_global_scale=layer.w2_input_scale_quant,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alpha=layer.g2_alphas,
+            masked_m=masked_m,
+        )
+        return out
sglang/srt/layers/quantization/moe_wna16.py
@@ -9,6 +9,8 @@ import torch
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 
 def get_weight_perm(num_bits: int):
@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.register_parameter(key, param)
             set_weight_attrs(param, extra_weight_attrs)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp
 
-        return fused_experts(
-            x,
-            layer.w13_qweight,
-            layer.w2_qweight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_qweight,
+            w2_weight=layer.w2_qweight,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
-            w1_scale=layer.w13_scales,
+            w13_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
-            w1_zp=layer.w13_qzeros if has_zp else None,
+            w13_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
         )
+        return self.runner.run(dispatch_output, quant_info)
 
     @staticmethod
     def get_weight_loader(layer, weight_loader):