sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/mxfp4.py

@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List, Optional
  import torch
  from torch.nn.parameter import Parameter

+ from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+ from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
  from sglang.srt.layers.moe.utils import get_moe_runner_backend
  from sglang.srt.layers.quantization.base_config import (
      FusedMoEMethodBase,
@@ -59,8 +61,10 @@ if is_flashinfer_available():
  logger = logging.getLogger(__name__)

  if TYPE_CHECKING:
-     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-     from sglang.srt.layers.moe.topk import TopKOutput
+     from sglang.srt.layers.moe.token_dispatcher import (
+         CombineInput,
+         StandardDispatchOutput,
+     )

  _is_hip = is_hip()

@@ -283,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
          layer: torch.nn.Module,
          num_experts: int,
          hidden_size: int,
-         intermediate_size: int,
+         intermediate_size_per_partition: int,
          params_dtype: torch.dtype,
          with_bias: bool = False,
          **extra_weight_attrs,
@@ -296,26 +300,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):

          # pad the intermediate size to be a multiple of 2 * mxfp4_block
          # for to hold non-uniform sharded tensor as well as swizzling
-         intermediate_size_per_partition_after_pad = intermediate_size
+         intermediate_size_per_partition_after_pad = intermediate_size_per_partition
          if _is_sm100_supported:
              if self.use_flashinfer:
                  intermediate_size_per_partition_after_pad = round_up(
-                     intermediate_size, 256
+                     intermediate_size_per_partition, 256
                  )
                  hidden_size = round_up(hidden_size, 256)
              else:
                  intermediate_size_per_partition_after_pad = round_up(
-                     intermediate_size, 64
+                     intermediate_size_per_partition, 64
                  )
          elif has_triton_kernels:
              # TODO: this is a hack to make
              # intermediate_size_per_partition_after_pad the same as the
              # per_rank_intermediate_size during weight loading
              intermediate_size_per_partition_after_pad = round_up(
-                 intermediate_size, mxfp4_block
+                 intermediate_size_per_partition, mxfp4_block
              )

-         self.intermediate_size = intermediate_size_per_partition_after_pad
+         self.intermediate_size_per_partition = intermediate_size_per_partition_after_pad

          self.hidden_size = hidden_size
          # Fused gate_up_proj (column parallel)
@@ -410,31 +414,35 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
              assert (
                  layer.w13_weight.dim() == 3
                  and layer.w13_weight.shape[0] == self.num_experts
-                 and layer.w13_weight.shape[1] == self.intermediate_size * 2
+                 and layer.w13_weight.shape[1]
+                 == self.intermediate_size_per_partition * 2
                  and layer.w13_weight.shape[2] == self.hidden_size // 2
              )
              assert (
                  layer.w13_weight_scale.dim() == 3
                  and layer.w13_weight_scale.shape[0] == self.num_experts
-                 and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
+                 and layer.w13_weight_scale.shape[1]
+                 == self.intermediate_size_per_partition * 2
                  and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
              )
              assert (
                  layer.w2_weight.dim() == 3
                  and layer.w2_weight.shape[0] == self.num_experts
                  and layer.w2_weight.shape[1] == self.hidden_size
-                 and layer.w2_weight.shape[2] == self.intermediate_size // 2
+                 and layer.w2_weight.shape[2]
+                 == self.intermediate_size_per_partition // 2
              )
              assert (
                  layer.w2_weight_scale.dim() == 3
                  and layer.w2_weight_scale.shape[1] == self.hidden_size
                  and layer.w2_weight_scale.shape[2]
-                 == self.intermediate_size // sf_block_size
+                 == self.intermediate_size_per_partition // sf_block_size
              )
              assert (
                  layer.w13_weight_bias.dim() == 2
                  and layer.w13_weight_bias.shape[0] == self.num_experts
-                 and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
+                 and layer.w13_weight_bias.shape[1]
+                 == self.intermediate_size_per_partition * 2
              )
              assert (
                  layer.w2_weight_bias.dim() == 2
@@ -511,7 +519,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                  torch.stack(gemm1_scales_mxfp4_shuffled)
                  .reshape(
                      self.num_experts,
-                     2 * self.intermediate_size,
+                     2 * self.intermediate_size_per_partition,
                      self.hidden_size // sf_block_size,
                  )
                  .view(torch.float8_e4m3fn)
@@ -523,7 +531,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                  .reshape(
                      self.num_experts,
                      self.hidden_size,
-                     self.intermediate_size // sf_block_size,
+                     self.intermediate_size_per_partition // sf_block_size,
                  )
                  .view(torch.float8_e4m3fn)
              )
@@ -613,16 +621,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):

          return tile_tokens_dim

+     def create_moe_runner(
+         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+     ):
+         self.moe_runner_config = moe_runner_config
+         self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
      def apply(
          self,
          layer: torch.nn.Module,
-         x: torch.Tensor,
-         topk_output: TopKOutput,
-         moe_runner_config: MoeRunnerConfig,
-     ) -> torch.Tensor:
+         dispatch_output: StandardDispatchOutput,
+     ) -> CombineInput:

+         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
          from sglang.srt.layers.moe.topk import TopKOutputChecker

+         x = dispatch_output.hidden_states
+         topk_output = dispatch_output.topk_output
+
+         moe_runner_config = self.moe_runner_config
+
          if self.use_flashinfer:
              # When bf16 mode is enabled, we don't need to quantize the input,
              # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
@@ -674,7 +692,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                  top_k,
                  None, # n_group # TODO: support n_group
                  None, # topk_group # TODO: support topk_group
-                 self.intermediate_size, # padded to multiple of 256
+                 self.intermediate_size_per_partition, # padded to multiple of 256
                  layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
                  layer.num_local_experts, # local num experts
                  None,
@@ -682,14 +700,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                  1, # routing_method_type, renormalize
                  True, # do finalize
              )[0]
-             return trtllm_gen_output
+             return StandardCombineInput(hidden_states=trtllm_gen_output)

          if self.use_triton_kernels:
              assert (
                  layer.moe_ep_size == 1
              ), "Expert parallel is not supported when using triton kernels"
              if self.with_bias:
-                 return self.triton_kernel_moe_with_bias_forward(
+                 output = self.triton_kernel_moe_with_bias_forward(
                      hidden_states=x,
                      w1=self.w13_weight_triton_tensor,
                      w1_pcg=self.w13_precision_config,
@@ -701,25 +719,22 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                      moe_runner_config=moe_runner_config,
                  )
              else:
-                 return self.triton_kernel_moe_forward(
+                 output = self.triton_kernel_moe_forward(
                      hidden_states=x,
                      w1=layer.w13_weight,
                      w2=layer.w2_weight,
                      topk_output=topk_output,
                      moe_runner_config=moe_runner_config,
                  )
+             return StandardCombineInput(hidden_states=output)
          else:
-             from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-             return fused_experts(
-                 hidden_states=x,
-                 w1=layer.w13_weight,
-                 w2=layer.w2_weight,
-                 topk_output=topk_output,
-                 moe_runner_config=moe_runner_config,
-                 b1=layer.w13_weight_bias,
-                 b2=layer.w2_weight_bias,
+             quant_info = TritonMoeQuantInfo(
+                 w13_weight=layer.w13_weight,
+                 w2_weight=layer.w2_weight,
+                 w13_weight_bias=layer.w13_weight_bias,
+                 w2_weight_bias=layer.w2_weight_bias,
              )
+             return self.runner.run(dispatch_output, quant_info)


  class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
@@ -798,7 +813,7 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):

          return w, mx_scales

-     def process_weights_after_loading(self, layer: Module) -> None:
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
          w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
          w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)

@@ -808,19 +823,27 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
          layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
          layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)

+     def create_moe_runner(
+         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+     ):
+         self.moe_runner_config = moe_runner_config
+
      def apply(
          self,
          layer: torch.nn.Module,
-         x: torch.Tensor,
-         topk_output: TopKOutput,
-         moe_runner_config: MoeRunnerConfig,
-     ) -> torch.Tensor:
+         dispatch_output: StandardDispatchOutput,
+     ) -> CombineInput:
+         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+         x = dispatch_output.hidden_states
+         topk_output = dispatch_output.topk_output
+
          topk_weights, topk_ids, _ = topk_output
          if _is_hip:
              topk_weights = topk_weights.to(
                  torch.float32
              ) # aiter's moe_sorting requires topk_weights to be FP32
-         output = fused_moe(
-         return fused_moe(
+         output = fused_moe(
              x,
              layer.w13_weight,
              layer.w2_weight,
@@ -831,8 +854,9 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
              w2_scale=layer.w2_weight_scale,
              activation=(
                  ActivationType.Silu
-                 if moe_runner_config.activation == "silu"
+                 if self.moe_runner_config.activation == "silu"
                  else ActivationType.Gelu
              ),
              doweight_stage1=False,
          )
+         return StandardCombineInput(hidden_states=output)
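The hunks above capture the shape of the new fused-MoE quantization interface: apply() no longer takes x, topk_output, and moe_runner_config separately but a single StandardDispatchOutput, it returns a CombineInput, and the runner configuration is installed once through the new create_moe_runner() hook. The following is a minimal sketch of that contract, assuming only the names visible in this diff; MyMoEMethod is hypothetical and the weight-creation hooks are omitted.

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase

if TYPE_CHECKING:
    # Same TYPE_CHECKING import pattern as the hunks above.
    from sglang.srt.layers.moe.token_dispatcher import (
        CombineInput,
        StandardDispatchOutput,
    )


class MyMoEMethod(FusedMoEMethodBase):
    # Hypothetical method; create_weights/process_weights_after_loading omitted.

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # Called once at setup; apply() no longer receives the config per call.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(
        self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput
    ) -> CombineInput:
        # Hidden states and top-k routing arrive bundled in the dispatch output;
        # the Triton runner reads the layer's weights via TritonMoeQuantInfo and
        # returns a combine input for the token dispatcher.
        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
        )
        return self.runner.run(dispatch_output, quant_info)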
sglang/srt/layers/quantization/quark/quark_moe.py

@@ -10,8 +10,17 @@ from aiter import ActivationType, QuantType, biased_grouped_topk
  from aiter.fused_moe import fused_moe
  from aiter.utility.fp4_utils import e8m0_shuffle

+ from sglang.srt.layers.moe import MoeRunnerConfig
+ from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
  from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs

+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.token_dispatcher import (
+         CombineInput,
+         StandardDispatchOutput,
+     )
+     from sglang.srt.layers.quantization.quark.quark import QuarkConfig
+
  logger = logging.getLogger(__name__)

  __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
@@ -19,31 +28,17 @@ __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
  OCP_MX_BLOCK_SIZE = 32

  if TYPE_CHECKING:
-     from sglang.srt.layers.moe.topk import TopKOutput
-
-
- class QuarkMoEMethod:
-     def __new__(cls, *args, **kwargs):
-         from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-
-         if not hasattr(cls, "_initialized"):
-             original_init = cls.__init__
-             new_cls = type(
-                 cls.__name__,
-                 (FusedMoEMethodBase,),
-                 {
-                     "__init__": original_init,
-                     **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                 },
-             )
-             obj = super(new_cls, new_cls).__new__(new_cls)
-             obj.__init__(*args, **kwargs)
-             return obj
-         return super().__new__(cls)
+     from sglang.srt.layers.quantization import QuarkConfig
+
+
+ class QuarkMoEMethod(FusedMoEMethodBase):
+
+     def __init__(self, quant_config: QuarkConfig):
+         self.quant_config = quant_config

      @staticmethod
      def get_moe_method(
-         quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821
+         quant_config: QuarkConfig, # type: ignore # noqa E501 # noqa F821
          module: torch.nn.Module,
          layer_name: str,
      ) -> "QuarkMoEMethod":
@@ -170,16 +165,25 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
          # layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
          layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)

+     def create_moe_runner(
+         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+     ):
+         self.moe_runner_config = moe_runner_config
+
      def apply(
          self,
          layer: torch.nn.Module,
-         x: torch.Tensor,
-         topk_output: TopKOutput,
-         moe_runner_config: MoeRunnerConfig,
-     ) -> torch.Tensor:
+         dispatch_output: StandardDispatchOutput,
+     ) -> CombineInput:
+
+         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+         x = dispatch_output.hidden_states
+         topk_output = dispatch_output.topk_output
+         moe_runner_config = self.moe_runner_config
          topk_weights, topk_ids, _ = topk_output

-         return fused_moe(
+         output = fused_moe(
              x,
              layer.w13_weight,
              layer.w2_weight,
@@ -195,3 +199,4 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
              ),
              doweight_stage1=False,
          )
+         return StandardCombineInput(hidden_states=output)
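Besides the same apply()/create_moe_runner() migration, this file drops the __new__ trick that rebuilt QuarkMoEMethod at instantiation time via type(..., (FusedMoEMethodBase,), ...); the class now inherits FusedMoEMethodBase directly and stores its config in a plain __init__. A hedged sketch of what that buys callers follows; build_quark_moe_method is a hypothetical helper, and only get_moe_method's signature is taken from the diff.

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.layers.quantization.quark.quark_moe import QuarkMoEMethod

if TYPE_CHECKING:
    from sglang.srt.layers.quantization.quark.quark import QuarkConfig


def build_quark_moe_method(
    quant_config: QuarkConfig, module: torch.nn.Module, layer_name: str
) -> FusedMoEMethodBase:
    # get_moe_method picks the concrete scheme (e.g. QuarkW4A4MXFp4MoEMethod)
    # for this layer; the helper itself is illustrative.
    method = QuarkMoEMethod.get_moe_method(quant_config, module, layer_name)
    # Previously this isinstance check only held because __new__ rebuilt the class
    # on the fly; with direct inheritance it holds statically.
    assert isinstance(method, FusedMoEMethodBase)
    return method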
sglang/srt/layers/quantization/unquant.py

@@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter

  from sglang.srt.custom_op import CustomOp
  from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+ from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+ from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
  from sglang.srt.layers.quantization.base_config import (
      FusedMoEMethodBase,
      LinearMethodBase,
@@ -24,8 +26,10 @@ from sglang.srt.utils import (
  )

  if TYPE_CHECKING:
-     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-     from sglang.srt.layers.moe.topk import TopKOutput
+     from sglang.srt.layers.moe.token_dispatcher import (
+         CombineInput,
+         StandardDispatchOutput,
+     )

  has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None

@@ -155,7 +159,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
          layer: torch.nn.Module,
          num_experts: int,
          hidden_size: int,
-         intermediate_size: int,
+         intermediate_size_per_partition: int,
          params_dtype: torch.dtype,
          with_bias: bool = False,
          **extra_weight_attrs,
@@ -163,7 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
          self.with_bias = with_bias

          # Fused gate_up_proj (column parallel)
-         w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
+         w13_weight_n, w13_weight_k = 2 * intermediate_size_per_partition, hidden_size
          if self.use_triton_kernels:
              w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
          w13_weight = torch.nn.Parameter(
@@ -175,7 +179,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

          if self.with_bias:
              w13_weight_bias = torch.nn.Parameter(
-                 torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
+                 torch.empty(
+                     num_experts,
+                     2 * intermediate_size_per_partition,
+                     dtype=torch.float32,
+                 ),
                  requires_grad=False,
              )
              layer.register_parameter("w13_weight_bias", w13_weight_bias)
@@ -184,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
          # down_proj (row parallel)
          w2_weight_n, w2_weight_k = (
              hidden_size,
-             intermediate_size,
+             intermediate_size_per_partition,
          )
          if self.use_triton_kernels:
              w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
@@ -222,33 +230,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

          return

+     def create_moe_runner(
+         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+     ):
+         self.moe_runner_config = moe_runner_config
+         self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
      def apply(
          self,
          layer: torch.nn.Module,
-         x: torch.Tensor,
-         topk_output: TopKOutput,
-         moe_runner_config: MoeRunnerConfig,
-     ) -> torch.Tensor:
+         dispatch_output: StandardDispatchOutput,
+     ) -> CombineInput:

          return self.forward(
-             x=x,
              layer=layer,
-             topk_output=topk_output,
-             moe_runner_config=moe_runner_config,
+             dispatch_output=dispatch_output,
          )

      def forward_cuda(
          self,
          layer: torch.nn.Module,
-         x: torch.Tensor,
-         topk_output: TopKOutput,
-         moe_runner_config: MoeRunnerConfig,
-     ) -> torch.Tensor:
+         dispatch_output: StandardDispatchOutput,
+     ) -> CombineInput:
+
+         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+         x = dispatch_output.hidden_states
+         topk_output = dispatch_output.topk_output
+
+         moe_runner_config = self.moe_runner_config

          if self.use_triton_kernels:
              if self.with_bias:
                  assert self.triton_kernel_moe_with_bias_forward is not None
-                 return self.triton_kernel_moe_with_bias_forward(
+                 output = self.triton_kernel_moe_with_bias_forward(
                      hidden_states=x,
                      w1=layer.w13_weight,
                      w2=layer.w2_weight,
@@ -261,13 +276,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                  )
              else:
                  assert self.triton_kernel_moe_forward is not None
-                 return self.triton_kernel_moe_forward(
+                 output = self.triton_kernel_moe_forward(
                      hidden_states=x,
                      w1=layer.w13_weight,
                      w2=layer.w2_weight,
                      topk_output=topk_output,
                      moe_runner_config=moe_runner_config,
                  )
+             return StandardCombineInput(hidden_states=output)
          else:
              if _use_aiter:
                  assert not moe_runner_config.no_combine, "unsupported"
@@ -284,7 +300,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                      topk_weights = torch.ones_like(
                          topk_weights, dtype=torch.float32
                      ) # topk_weights must be FP32 (float32)
-                 return fused_moe(
+                 output = fused_moe(
                      x,
                      layer.w13_weight,
                      layer.w2_weight,
@@ -296,28 +312,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                          else ActivationType.Gelu
                      ),
                  )
+                 return StandardCombineInput(hidden_states=output)
              else:
-                 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-                     fused_experts,
-                 )

-                 return fused_experts(
-                     hidden_states=x,
-                     w1=layer.w13_weight,
-                     w2=layer.w2_weight,
-                     b1=getattr(layer, "w13_weight_bias", None),
+                 quant_info = TritonMoeQuantInfo(
+                     w13_weight=layer.w13_weight,
+                     w2_weight=layer.w2_weight,
+                     b13=getattr(layer, "w13_weight_bias", None),
                      b2=getattr(layer, "w2_weight_bias", None),
-                     topk_output=topk_output,
-                     moe_runner_config=moe_runner_config,
                  )
+                 return self.runner.run(dispatch_output, quant_info)

      def forward_cpu(
          self,
          layer: torch.nn.Module,
-         x: torch.Tensor,
-         topk_output: TopKOutput,
-         moe_runner_config: MoeRunnerConfig,
-     ) -> torch.Tensor:
+         dispatch_output: StandardDispatchOutput,
+     ) -> CombineInput:
+
+         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+         x = dispatch_output.hidden_states
+         topk_output = dispatch_output.topk_output
+
+         moe_runner_config = self.moe_runner_config
+
          assert (
              moe_runner_config.activation == "silu"
          ), f"activation = {moe_runner_config.activation} is not supported."
@@ -332,7 +350,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
              x, topk_weights = apply_topk_weights_cpu(
                  moe_runner_config.apply_router_weight_on_input, topk_weights, x
              )
-             return torch.ops.sgl_kernel.fused_experts_cpu(
+             output = torch.ops.sgl_kernel.fused_experts_cpu(
                  x,
                  layer.w13_weight,
                  layer.w2_weight,
@@ -348,33 +366,103 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                  None, # a2_scale
                  True, # is_vnni
              )
+             return StandardCombineInput(hidden_states=output)
          else:
              from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

-             return moe_forward_native(
+             output = moe_forward_native(
                  layer,
                  x,
                  topk_output,
                  moe_runner_config,
              )
+             return StandardCombineInput(hidden_states=output)

      def forward_npu(
          self,
          layer: torch.nn.Module,
-         x: torch.Tensor,
-         topk_output: TopKOutput,
-         moe_runner_config: MoeRunnerConfig,
-     ) -> torch.Tensor:
-         from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
+         dispatch_output: StandardDispatchOutput,
+     ) -> CombineInput:
+
+         import torch_npu
+
+         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+         x = dispatch_output.hidden_states
+         topk_weights, topk_ids, _ = dispatch_output.topk_output
+
+         original_dtype = x.dtype
+         num_tokens = x.shape[0]
+         topk_weights = topk_weights.to(x.dtype)
+         topk_ids = topk_ids.to(torch.int32)
+         num_experts = layer.num_experts
+         top_k = layer.top_k
+         row_idx_len = num_tokens * top_k
+         row_idx = (
+             torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device)
+             .view(top_k, -1)
+             .permute(1, 0)
+             .contiguous()
+         )

-         return moe_forward_native(
-             layer,
-             x,
-             topk_output,
-             moe_runner_config,
+         hidden_states, expanded_row_idx, expanded_expert_idx = (
+             torch_npu.npu_moe_init_routing(
+                 x, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens
+             )
+         )
+
+         expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+             expanded_expert_idx, num_experts
          )

-     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
+         expert_tokens = expert_tokens.to(torch.int64)
+         if layer.w13_weight.shape[-1] == layer.hidden_size:
+             w13 = layer.w13_weight.transpose(1, 2)
+             w2 = layer.w2_weight.transpose(1, 2)
+
+         # gmm1: gate_up_proj
+         hidden_states = torch_npu.npu_grouped_matmul(
+             x=[hidden_states],
+             weight=[w13],
+             split_item=2,
+             group_list_type=0,
+             group_type=0,
+             group_list=expert_tokens,
+             output_dtype=original_dtype,
+         )[0]
+
+         # act_fn:
+         if self.moe_runner_config.activation == "silu":
+             hidden_states = torch_npu.npu_swiglu(hidden_states)
+         else:
+             from sglang.srt.layers.activation import GeluAndMul
+
+             hidden_states = GeluAndMul()(hidden_states)
+
+         # gmm2: down_proj
+         hidden_states = torch_npu.npu_grouped_matmul(
+             x=[hidden_states],
+             weight=[w2],
+             split_item=2,
+             group_list_type=0,
+             group_type=0,
+             group_list=expert_tokens,
+             output_dtype=original_dtype,
+         )[0]
+
+         final_hidden_states = torch_npu.npu_moe_finalize_routing(
+             hidden_states,
+             skip1=None,
+             skip2=None,
+             bias=None,
+             scales=topk_weights,
+             expanded_src_to_dst_row=expanded_row_idx,
+             export_for_source_row=topk_ids,
+         )
+
+         return StandardCombineInput(hidden_states=final_hidden_states)
+
+     def forward_tpu(self, *args, **kwargs) -> CombineInput:
          raise NotImplementedError("The TPU backend currently does not support MoE.")

      forward_native = forward_cpu
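Taken together, these hunks put every MoE quantization method behind the same dispatch/combine flow. The sketch below shows roughly how a calling layer would drive one of these methods after the change, assuming the StandardDispatchOutput and StandardCombineInput containers seen above can be constructed and read through the hidden_states/topk_output fields their usage implies; run_moe_method and its arguments are illustrative stand-ins, not code from this release.

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.srt.layers.moe.token_dispatcher import (
    StandardCombineInput,
    StandardDispatchOutput,
)

if TYPE_CHECKING:
    from sglang.srt.layers.moe import MoeRunnerConfig
    from sglang.srt.layers.moe.topk import TopKOutput
    from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase


def run_moe_method(
    quant_method: FusedMoEMethodBase,
    layer: torch.nn.Module,
    moe_runner_config: MoeRunnerConfig,
    hidden_states: torch.Tensor,
    topk_output: TopKOutput,
) -> torch.Tensor:
    # One-time setup: the method keeps the runner config (and, for Triton-backed
    # methods, a MoeRunner) instead of receiving it on every forward pass.
    quant_method.create_moe_runner(layer, moe_runner_config)

    # Per-forward: bundle the routed inputs, run the experts, unwrap the result.
    dispatch_output = StandardDispatchOutput(
        hidden_states=hidden_states, topk_output=topk_output
    )
    combine_input = quant_method.apply(layer, dispatch_output)
    assert isinstance(combine_input, StandardCombineInput)
    return combine_input.hidden_states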