sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w4afp8.py

@@ -17,12 +17,19 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_npu, set_weight_attrs
+
+_is_npu = is_npu()
+if not _is_npu:
+    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe import MoeRunnerConfig
     from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-    from sglang.srt.layers.moe.topk import StandardTopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -133,7 +140,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer: EPMoE,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -145,7 +152,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size * 2,
+                intermediate_size_per_partition * 2,
                 hidden_size // 2,
                 dtype=torch.int8,
             ),
@@ -159,7 +166,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             torch.empty(
                 num_experts,
                 hidden_size,
-                intermediate_size // 2,
+                intermediate_size_per_partition // 2,
                 dtype=torch.int8,
             ),
             requires_grad=False,
@@ -173,7 +180,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
-                2 * intermediate_size,
+                2 * intermediate_size_per_partition,
                 hidden_size // self.quant_config.group_size,
                 dtype=torch.float32,
             ),
@@ -186,7 +193,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             torch.zeros(
                 num_experts,
                 hidden_size,
-                intermediate_size // self.quant_config.group_size,
+                intermediate_size_per_partition // self.quant_config.group_size,
                 dtype=torch.float32,
             ),
             requires_grad=False,
@@ -220,13 +227,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         )
         self.c_strides1 = torch.full(
             (num_experts, 3),
-            2 * intermediate_size,
+            2 * intermediate_size_per_partition,
             device=device,
             dtype=torch.int64,
         )
         self.a_strides2 = torch.full(
             (num_experts, 3),
-            intermediate_size,
+            intermediate_size_per_partition,
             device=device,
             dtype=torch.int64,
         )
@@ -282,16 +289,22 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         )
         layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: EPMoE,
-        x: torch.Tensor,
-        topk_output: StandardTopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
 
-        # TODO(ch-wan): move it out of this class
         from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
 
         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
@@ -328,6 +341,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
-        if moe_runner_config.routed_scaling_factor is not None:
-            output *= moe_runner_config.routed_scaling_factor
-        return output
+        if self.moe_runner_config.routed_scaling_factor is not None:
+            output *= self.moe_runner_config.routed_scaling_factor
+        return StandardCombineInput(hidden_states=output)
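Note on the pattern above, which recurs in the w8a8_fp8.py and w8a8_int8.py hunks below: `apply()` no longer takes `x`, `topk_output`, and `moe_runner_config` as separate arguments. A new `create_moe_runner()` hook captures the config once, `apply()` consumes a bundled `StandardDispatchOutput`, and the result is wrapped in a `StandardCombineInput`. A minimal sketch of the new contract, assuming only the types visible in this diff; `MyQuantMoEMethod` and `my_moe_kernel` are hypothetical stand-ins, not sglang APIs:

# Minimal sketch of the reworked MoE quant-method contract, assuming the
# dispatch/combine types shown in this diff. MyQuantMoEMethod and
# my_moe_kernel are hypothetical stand-ins, not sglang APIs.
import torch

from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.token_dispatcher import (
    CombineInput,
    StandardCombineInput,
    StandardDispatchOutput,
)


def my_moe_kernel(x, topk_weights, topk_ids):
    # Hypothetical placeholder for a fused expert kernel such as
    # cutlass_w4a8_moe; here it just passes hidden states through.
    return x


class MyQuantMoEMethod:
    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # The runner config is captured once here instead of being
        # threaded through every apply() call as in 0.5.2rc2.
        self.moe_runner_config = moe_runner_config

    def apply(
        self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput
    ) -> CombineInput:
        # Inputs now arrive pre-bundled from the token dispatcher.
        x = dispatch_output.hidden_states
        topk_weights, topk_ids, _ = dispatch_output.topk_output
        output = my_moe_kernel(x, topk_weights, topk_ids)
        if self.moe_runner_config.routed_scaling_factor is not None:
            output *= self.moe_runner_config.routed_scaling_factor
        return StandardCombineInput(hidden_states=output)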
sglang/srt/layers/quantization/w8a8_fp8.py

@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -26,8 +28,10 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import StandardTopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 _is_fp8_fnuz = is_fp8_fnuz()
 
@@ -209,7 +213,7 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -218,7 +222,10 @@
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=fp8_dtype,
             ),
             requires_grad=False,
         )
@@ -226,14 +233,21 @@
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=fp8_dtype,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(
@@ -266,25 +280,26 @@
             layer.w2_weight_scale.data, requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: StandardTopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
 
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
             per_channel_quant=True,
-            w1_scale=(layer.w13_weight_scale),
-            w2_scale=(layer.w2_weight_scale),
-            a1_scale=layer.w13_input_scale,
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)
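The fp8 MoE path above stops calling `fused_experts` directly: weights and scales now travel in a `TritonMoeQuantInfo` bundle handed to a shared `MoeRunner`, and several fields are renamed along the way (`w1_scale` becomes `w13_scale`, `a1_scale` becomes `a13_scale`). A condensed sketch of that delegation, assuming a layer that already carries the weight and scale attributes created above; the class name is hypothetical:

# Sketch of the MoeRunner delegation introduced in this release; the
# runner/backend/quant-info names mirror the diff, MyFp8MoEMethod is a
# hypothetical stand-in.
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo


class MyFp8MoEMethod:
    def create_moe_runner(self, layer, moe_runner_config: MoeRunnerConfig):
        # One runner per method instance, bound to the Triton backend.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(self, layer, dispatch_output):
        # Note the renamed fields relative to 0.5.2rc2:
        # w13_scale (was w1_scale) and a13_scale (was a1_scale).
        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            use_fp8_w8a8=True,
            per_channel_quant=True,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )
        return self.runner.run(dispatch_output, quant_info)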
sglang/srt/layers/quantization/w8a8_int8.py

@@ -24,6 +24,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
@@ -49,8 +51,10 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -339,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
-            return
-
-        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
 
     def create_weights(
@@ -417,7 +420,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -428,7 +431,10 @@
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
@@ -436,14 +442,21 @@
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(
@@ -472,10 +485,9 @@
                 _is_cpu_amx_available
             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-            return
-
-        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
-        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        else:
+            layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+            layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
             layer.w13_weight_scale.data, requires_grad=False
         )
@@ -483,23 +495,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale.data, requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
 
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                moe_runner_config.apply_router_weight_on_input, topk_weights, x
+                self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
-            return torch.ops.sgl_kernel.fused_experts_cpu(
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -515,20 +534,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 layer.w2_input_scale, # a2_scale
                 True, # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)
 
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_int8_w8a8=True,
             per_channel_quant=True,
-            w1_scale=(layer.w13_weight_scale),
-            w2_scale=(layer.w2_weight_scale),
-            a1_scale=layer.w13_input_scale,
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)
 
 
 class NPU_W8A8LinearMethodImpl:
@@ -900,7 +918,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
@@ -914,21 +932,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         # weight
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
         w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
         # scale
         w13_weight_scale = torch.nn.Parameter(
-            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
@@ -941,7 +969,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
         # offset
         w13_weight_offset = torch.nn.Parameter(
-            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_offset", w13_weight_offset)
@@ -973,18 +1003,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer,
-        x,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
 
         topk_weights, topk_ids, _ = topk_output
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = topk_weights.to(x.dtype)
-        return npu_fused_experts(
+        output = npu_fused_experts(
             hidden_states=x,
             w13=layer.w13_weight,
             w13_scale=layer.w13_weight_scale,
@@ -994,3 +1031,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             topk_ids=topk_ids,
             top_k=topk_ids.shape[1],
         )
+        return StandardCombineInput(hidden_states=output)
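Seen from the calling side, all three files now follow the same two-phase flow: bind a runner once via `create_moe_runner()`, then on each forward pass feed a dispatch bundle through `apply()` and unwrap the combine input. A hedged sketch of that flow; the `StandardDispatchOutput` constructor fields are inferred from the attribute accesses in these hunks, and the surrounding wiring is illustrative, not sglang's actual FusedMoE code:

# Hedged caller-side sketch; forward_moe and its wiring are illustrative.
# StandardDispatchOutput construction is inferred from the attribute
# accesses above (hidden_states, topk_output) and may differ in detail.
from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput


def forward_moe(layer, quant_method, moe_runner_config, hidden_states, topk_output):
    # Phase 1 (done once, at layer construction in the real code):
    # bind the runner configuration to the quant method.
    quant_method.create_moe_runner(layer, moe_runner_config)

    # Phase 2 (every forward pass): bundle, apply, unwrap.
    dispatch_output = StandardDispatchOutput(
        hidden_states=hidden_states, topk_output=topk_output
    )
    combine_input = quant_method.apply(layer, dispatch_output)
    return combine_input.hidden_states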