sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py

@@ -11,53 +11,41 @@ from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
     ENABLE_JIT_DEEPGEMM,
 )
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_bool_env_var
 
 logger = logging.getLogger(__name__)
 
 if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
+    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
 
-    if DEEPGEMM_BLACKWELL:
-        from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
-        from deep_gemm import (
-            fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-        )
-        from deep_gemm import (
-            m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-        )
-    else:
-        from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
-        from deep_gemm import get_col_major_tma_aligned_tensor
-        from deep_gemm import (
-            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-        )
-        from deep_gemm import (
-            m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-        )
+_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
 
 
+# TODO maybe rename these functions
 def grouped_gemm_nt_f8f8bf16_masked(
     lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs: Tuple[torch.Tensor, torch.Tensor],
     out: torch.Tensor,
     masked_m: torch.Tensor,
     expected_m: int,
-    recipe=None,
 ):
     num_groups, _, k = lhs[0].shape
     _, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(
         expected_m, n, k, num_groups, kernel_type
     ):
-        _grouped_gemm_nt_f8f8bf16_masked_raw(
+        deep_gemm.fp8_m_grouped_gemm_nt_masked(
             lhs,
             rhs,
             out,
             masked_m,
             expected_m,
-            **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})
         )
 
 
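For context, grouped_gemm_nt_f8f8bf16_masked now calls deep_gemm.fp8_m_grouped_gemm_nt_masked directly instead of a wrapper alias picked at import time. The masked grouped GEMM runs one matmul per expert group, with masked_m giving the number of valid rows in each group. Below is a minimal pure-torch sketch of those semantics, with shapes taken from the unpacking in the hunk above; fp8 packing and scale handling are omitted, and the helper name is ours:

import torch

def masked_grouped_gemm_reference(
    lhs: torch.Tensor,       # (num_groups, m_max, k)
    rhs: torch.Tensor,       # (num_groups, n, k)
    masked_m: torch.Tensor,  # (num_groups,) valid rows per group
) -> torch.Tensor:
    num_groups, m_max, k = lhs.shape
    _, n, _ = rhs.shape
    out = torch.zeros(num_groups, m_max, n, dtype=lhs.dtype)
    for g in range(num_groups):
        m = int(masked_m[g])
        # Only the first masked_m[g] rows of each group are valid;
        # the rest are padding and stay zero.
        out[g, :m] = lhs[g, :m] @ rhs[g].T
    return out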
@@ -71,8 +59,11 @@ def grouped_gemm_nt_f8f8bf16_contig(
     num_groups, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
-        _grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)
+        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)
 
 
 def gemm_nt_f8f8bf16(
@@ -85,8 +76,11 @@ def gemm_nt_f8f8bf16(
     num_groups = 1
     kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
-        _gemm_nt_f8f8bf16_raw(
+        deep_gemm.fp8_gemm_nt(
             lhs,
             rhs,
             out,
@@ -108,3 +102,18 @@ def configure_deep_gemm_num_sms(num_sms):
         yield
     finally:
         deep_gemm.set_num_sms(original_num_sms)
+
+
+def _sanity_check_input(x_fp8: Tuple[torch.Tensor, torch.Tensor]):
+    if not _SANITY_CHECK:
+        return
+
+    x, x_scale = x_fp8
+
+    if x_scale.dtype == torch.int:
+        return
+
+    from sglang.srt.layers.quantization.fp8_utils import ceil_to_ue8m0
+
+    x_scale_ceil = ceil_to_ue8m0(x_scale)
+    assert torch.all(x_scale == x_scale_ceil), f"{x_scale=} {x_scale_ceil=}"
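The new _sanity_check_input, gated by the SGLANG_DEEPGEMM_SANITY_CHECK environment variable, asserts that every fp8 scale factor is already UE8M0-rounded, i.e. an exact power of two; the check is skipped when the scales are packed as int32. A self-contained sketch of the property being checked, with a hypothetical stand-in for ceil_to_ue8m0 (the real helper lives in sglang.srt.layers.quantization.fp8_utils and may be implemented differently):

import torch

def ceil_to_ue8m0_sketch(x: torch.Tensor) -> torch.Tensor:
    # UE8M0 has an 8-bit exponent and no mantissa, so only powers of two
    # are representable; round each positive scale up to the next one.
    return torch.exp2(torch.ceil(torch.log2(x)))

scales = torch.tensor([0.25, 0.5, 1.0, 2.0])  # already powers of two
assert torch.all(scales == ceil_to_ue8m0_sketch(scales))  # check passes

bad = torch.tensor([0.3])  # not a power of two
assert not torch.all(bad == ceil_to_ue8m0_sketch(bad))  # check would fail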
sglang/srt/layers/quantization/fp8.py

@@ -30,6 +30,9 @@ except ImportError:
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
+from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
@@ -64,7 +67,6 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -72,6 +74,8 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_sm90_supported,
+    is_sm100_supported,
     log_info_on_rank0,
     next_power_of_2,
     print_warning_once,
@@ -80,7 +84,11 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        DispatchOutput,
+        StandardDispatchOutput,
+    )
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 
@@ -344,6 +352,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
+            layer.weight_scale_inv = torch.nn.Parameter(
+                layer.weight_scale_inv.data, requires_grad=False
+            )
             return
         else:
             weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
@@ -526,7 +537,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer: Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -542,18 +553,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
             # Required by column parallel or enabling merged weights
-            if intermediate_size % block_n != 0:
+            if intermediate_size_per_partition % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "
-                    f"{intermediate_size} is not divisible by "
+                    f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_n = {block_n}."
                 )
             if tp_size > 1:
                 # Required by row parallel
-                if intermediate_size % block_k != 0:
+                if intermediate_size_per_partition % block_k != 0:
                     raise ValueError(
                         f"The input_size of down's weight = "
-                        f"{intermediate_size} is not divisible by "
+                        f"{intermediate_size_per_partition} is not divisible by "
                         f"weight quantization block_k = {block_k}."
                     )
 
@@ -563,7 +574,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w13_weight = torch.nn.Parameter(
                 torch.empty(
                     num_experts,
-                    2 * intermediate_size,
+                    2 * intermediate_size_per_partition,
                     hidden_size // 8,
                     dtype=params_dtype,
                 ),
@@ -571,20 +582,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             w2_weight = torch.nn.Parameter(
                 torch.empty(
-                    num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition // 8,
+                    dtype=params_dtype,
                 ),
                 requires_grad=False,
             )
         else:
             w13_weight = torch.nn.Parameter(
                 torch.empty(
-                    num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    hidden_size,
+                    dtype=params_dtype,
                 ),
                 requires_grad=False,
             )
             w2_weight = torch.nn.Parameter(
                 torch.empty(
-                    num_experts, hidden_size, intermediate_size, dtype=params_dtype
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition,
+                    dtype=params_dtype,
                 ),
                 requires_grad=False,
             )
@@ -600,7 +620,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w13_weight_scale = torch.nn.Parameter(
                 torch.ones(
                     num_experts,
-                    2 * ((intermediate_size + block_n - 1) // block_n),
+                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                     (hidden_size + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
@@ -610,7 +630,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 torch.ones(
                     num_experts,
                     (hidden_size + block_n - 1) // block_n,
-                    (intermediate_size + block_k - 1) // block_k,
+                    (intermediate_size_per_partition + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
                 requires_grad=False,
@@ -618,11 +638,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
-            if (
-                get_bool_env_var("SGLANG_CUTLASS_MOE")
-                and self.cutlass_fp8_supported
-                and (is_sm100_supported() or is_sm90_supported())
-            ):
+            if self.use_cutlass_fused_experts_fp8:
                 self.ab_strides1 = torch.full(
                     (num_experts,),
                     hidden_size,
@@ -631,13 +647,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 )
                 self.c_strides1 = torch.full(
                     (num_experts,),
-                    2 * intermediate_size,
+                    2 * intermediate_size_per_partition,
                     device=w13_weight.device,
                     dtype=torch.int64,
                 )
                 self.ab_strides2 = torch.full(
                     (num_experts,),
-                    intermediate_size,
+                    intermediate_size_per_partition,
                     device=w2_weight.device,
                     dtype=torch.int64,
                 )
@@ -690,7 +706,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if _is_hip:  # _use_aiter: TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
-                torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
+                torch.ones(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    dtype=torch.float32,
+                ),
                 requires_grad=False,
             )
             w2_weight_scale1 = torch.nn.Parameter(
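The rename from intermediate_size to intermediate_size_per_partition makes explicit that the weight-creation path above sees each tensor-parallel rank's slice, not the full dimension. A worked example of the ceiling-division scale shapes it allocates (the concrete sizes are illustrative, not taken from any model):

num_experts = 8
hidden_size = 4096
intermediate_size_per_partition = 1536  # this rank's slice
block_n = block_k = 128

w13_scale_shape = (
    num_experts,
    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),  # 2 * 12 = 24
    (hidden_size + block_k - 1) // block_k,                            # 32
)
w2_scale_shape = (
    num_experts,
    (hidden_size + block_n - 1) // block_n,                            # 32
    (intermediate_size_per_partition + block_k - 1) // block_k,        # 12
)
assert w13_scale_shape == (8, 24, 32)
assert w2_scale_shape == (8, 32, 12)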
@@ -983,14 +1003,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             torch.cuda.empty_cache()
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        dispatch_output: DispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
 
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
@@ -1000,7 +1029,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
 
-            return torch.ops.sgl_kernel.fused_experts_cpu(
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -1016,6 +1045,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 None,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)
 
         if _is_hip:
             ret = self.maybe_apply_hip_fused_experts(
@@ -1026,7 +1056,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 moe_runner_config.no_combine,
            )
             if ret is not None:
-                return ret
+                return StandardCombineInput(hidden_states=ret)
 
         if self.use_cutlass_fused_experts_fp8:
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
@@ -1055,17 +1085,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.problem_sizes2,
                 use_fp8_blockscale=True,
             )
-            # Scale by routed_scaling_factor is fused into select_experts.
-            return output
-        # Expert fusion with FP8 quantization
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+            return StandardCombineInput(hidden_states=output)
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            w1_scale=(
+            w13_scale=(
                 layer.w13_weight_scale_inv
                 if self.block_quant
                 else layer.w13_weight_scale
@@ -1073,20 +1099,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w2_scale=(
                 layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
             ),
-            a1_scale=layer.w13_input_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
         )
+        return self.runner.run(dispatch_output, quant_info)
 
     def apply_with_router_logits(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-        activation = moe_runner_config.activation
-        routed_scaling_factor = moe_runner_config.routed_scaling_factor
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        activation = self.moe_runner_config.activation
+        routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
 
         from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
 
@@ -1107,10 +1135,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             and topk_config.topk_group is not None
         ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"
 
-        if topk_config.correction_bias is None:
-            correction_bias = topk_config.correction_bias.to(x.dtype)
-        else:
-            correction_bias = None
+        correction_bias = (
+            None
+            if topk_config.correction_bias is None
+            else topk_config.correction_bias.to(x.dtype)
+        )
+
         return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
             routing_bias=correction_bias,
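These hunks are part of a refactor visible throughout the release: quantized MoE methods no longer take (x, topk_output, moe_runner_config) on every call. The config is stored once via create_moe_runner, and apply() receives a single dispatch output and returns a combine input. A sketch of the new calling convention, with hypothetical stand-ins for the dispatcher types (the real classes live under sglang.srt.layers.moe.token_dispatcher):

from dataclasses import dataclass
import torch

@dataclass
class DispatchOutputSketch:   # stand-in for StandardDispatchOutput
    hidden_states: torch.Tensor
    topk_output: tuple        # (topk_weights, topk_ids, router_logits)

@dataclass
class CombineInputSketch:     # stand-in for StandardCombineInput
    hidden_states: torch.Tensor

class QuantMethodSketch:
    def create_moe_runner(self, layer, moe_runner_config):
        # Stored once; apply() no longer receives the config per call.
        self.moe_runner_config = moe_runner_config

    def apply(self, layer, dispatch_output: DispatchOutputSketch) -> CombineInputSketch:
        x = dispatch_output.hidden_states
        topk_weights, topk_ids, _ = dispatch_output.topk_output
        output = x  # placeholder for the fused-experts kernel result
        return CombineInputSketch(hidden_states=output)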
sglang/srt/layers/quantization/fp8_kernel.py

@@ -298,7 +298,7 @@ def _per_token_group_quant_8bit_raw(
     )
 
     if scale_ue8m0:
-        from deep_gemm.utils.layout import transform_sf_into_required_layout
+        from deep_gemm import transform_sf_into_required_layout
 
         assert group_size == 128
         x_s = transform_sf_into_required_layout(
@@ -338,7 +338,7 @@ def _per_token_group_quant_8bit_fuse_silu_and_mul(
     #     scale_ue8m0=scale_ue8m0,
     # )
 
-    from deep_gemm.utils.layout import transform_sf_into_required_layout
+    from deep_gemm import transform_sf_into_required_layout
 
     from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
 
sglang/srt/layers/quantization/fp8_utils.py

@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.utils import is_sm100_supported
 
 try:
     from vllm import _custom_ops as ops
@@ -45,7 +45,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale, get_hip_quant
+    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
 
     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
 
@@ -248,11 +248,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )
 
-    # NOTE(alcanderian): Useless when scale is packed to int32
-    # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
-    #     _check_ue8m0("x_scale", x_scale)
-    #     _check_ue8m0("weight_scale", ws)
-
     output = w8a8_block_fp8_matmul_deepgemm(
         q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
     )
@@ -261,11 +256,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
     return output.to(dtype=output_dtype).view(*output_shape)
 
 
-def _check_ue8m0(name, x):
-    x_ceil = ceil_to_ue8m0(x)
-    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
-
-
 def aiter_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -459,7 +449,7 @@ def _requant_weight_ue8m0(
         import deep_gemm.utils.layout
 
         sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
-        sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf)
+        sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
         return sf
 
     out_s = _transform_scale(out_s, mn=out_w.shape[-2])
@@ -652,25 +642,49 @@ def apply_fp8_linear(
             use_per_token_if_dynamic
             and not per_tensor_weights
             and not per_tensor_activations
-            and USE_ROWWISE_TORCH_SCALED_MM
+            and (USE_ROWWISE_TORCH_SCALED_MM or _use_aiter)
         ):
-            # For now validated on ROCm platform
-            # fp8 rowwise scaling in torch._scaled_mm is introduced in
-            # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
-            # and ROCm 6.3, which only exists in torch 2.7 and above.
-            # For CUDA platform please validate if the
-            # torch._scaled_mm support rowwise scaled GEMM
-            # Fused GEMM_DQ Rowwise GEMM
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale.t(),
-                bias=bias,
-            )
-            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
-
+            # into this sector means use dynamic per-token-per-channel quant
+            # per-token scale quant for input matrix, every row(one token) have one scale factor
+            # per-channel scale quant for weight matrix, every col(one channel) have one scale factor
+            if _use_aiter:
+                # gemm_a8w8_bpreshuffle(XQ, WQ, x_scale, w_scale, dtype)
+                # XQ -> input tensor, shape = (m, k)
+                # WQ -> weight tensor, shape = (n, k), with preshuffe get better perf
+                # x_scale -> input scale tensor, shape = (m, 1)
+                # w_scale -> weight scale tensor, shape = (n ,1)
+                # dtype -> output dtype
+                output = gemm_a8w8_bpreshuffle(
+                    XQ=qinput,
+                    WQ=weight,
+                    x_scale=x_scale,
+                    w_scale=weight_scale,
+                    dtype=input.dtype,
+                )
+                if bias is not None:
+                    output += bias
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, [*input.shape[:-1], weight.shape[0]]
+                )
+            else:
+                # For now validated on ROCm platform
+                # fp8 rowwise scaling in torch._scaled_mm is introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, which only exists in torch 2.7 and above.
+                # For CUDA platform please validate if the
+                # torch._scaled_mm support rowwise scaled GEMM
+                # Fused GEMM_DQ Rowwise GEMM
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale.t(),
+                    bias=bias,
+                )
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, output_shape
+                )
         else:
             # Fallback for channelwise case, where we use unfused DQ
             # due to limitations with scaled_mm
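Both branches above compute the same row-wise scaled GEMM: one scale per input token (shape (m, 1)) and one per weight output channel (shape (n, 1)), folded into a single fused kernel by aiter's gemm_a8w8_bpreshuffle or by torch._scaled_mm. An unfused reference of the math, with plain float tensors standing in for the fp8 inputs:

import torch

m, k, n = 4, 64, 32
qinput = torch.randn(m, k)   # quantized activations
weight = torch.randn(n, k)   # quantized weights, row-major (n, k)
x_scale = torch.rand(m, 1)   # per-token scales
w_scale = torch.rand(n, 1)   # per-channel scales

# Dequantize-then-matmul; the fused kernels fold the scales into the GEMM:
# out[i, j] = x_scale[i] * w_scale[j] * sum_k qinput[i, k] * weight[j, k]
output = (qinput * x_scale) @ (weight * w_scale).T
assert output.shape == (m, n)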
sglang/srt/layers/quantization/gptq.py

@@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import (
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )
 
 from sglang.srt.utils import is_cuda
 
@@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.linear import set_weight_attrs
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
-        intermediate_size = extra_weight_attrs.pop("intermediate_size")
-
-        self.is_k_full = (not self.quant_config.desc_act) or (
-            intermediate_size_per_partition == intermediate_size
-        )
+        self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1
 
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            w2_scales_size = (
-                intermediate_size
-                if self.quant_config.desc_act
-                else intermediate_size_per_partition
-            )
+            if self.quant_config.desc_act:
+                w2_scales_size = intermediate_size_per_partition
+            else:
+                w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
             scales_size2 = w2_scales_size // self.quant_config.group_size
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
@@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_scales", marlin_w2_scales)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         # Delay the import to avoid circular dependency
 
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         # The input must currently be float16
@@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
 
         topk_weights, topk_ids, router_logits = topk_output
 
-        return fused_marlin_moe(
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             num_bits=self.quant_config.weight_bits,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
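The is_k_full bookkeeping above now derives directly from the layer's MoE tensor-parallel degree instead of comparing a popped intermediate_size attribute against the per-partition size. Restated as a small predicate (the values in the checks are illustrative):

def is_k_full(desc_act: bool, moe_tp_size: int) -> bool:
    # With activation reordering (desc_act), Marlin needs the full k
    # dimension on each rank, which only holds without weight sharding.
    return (not desc_act) or moe_tp_size == 1

assert is_k_full(desc_act=False, moe_tp_size=4)
assert is_k_full(desc_act=True, moe_tp_size=1)
assert not is_k_full(desc_act=True, moe_tp_size=4)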