sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +375 -51
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -1,10 +1,6 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

- import datetime
- import glob
  import logging
- import os
- import sys
  from enum import Enum
  from typing import List, Optional, Tuple

@@ -22,12 +18,18 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
  use_symmetric_memory,
  )
  from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
- from sglang.srt.layers.moe.topk import StandardTopKOutput
- from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
+ from sglang.srt.layers.moe import (
+ MoeRunnerConfig,
+ get_moe_runner_backend,
+ should_use_flashinfer_trtllm_moe,
+ )
+ from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
  from sglang.srt.layers.quantization.base_config import (
  QuantizationConfig,
  QuantizeMethodBase,
  )
+ from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
+ from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
  from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
@@ -109,9 +111,8 @@ class FusedMoE(torch.nn.Module):
  hidden_size: Input hidden state size of the transformer
  intermediate_size: Intermediate size of the experts
  params_dtype: Data type for the parameters.
- reduce_results: Whether to all all_reduce on the output of the layer
- renomalize: Whether to renormalize the logits in the fused_moe kernel
- quant_config: Quantization configure.
+ reduce_results: Whether to apply all_reduce on the output of the layer
+ quant_config: Quantization configuration.
  inplace: suggestion to compute inplace (modify input activation).
  """

@@ -126,7 +127,6 @@
  params_dtype: Optional[torch.dtype] = None,
  reduce_results: bool = False,
  quant_config: Optional[QuantizationConfig] = None,
- tp_size: Optional[int] = None,
  prefix: str = "",
  activation: str = "silu",
  apply_router_weight_on_input: bool = False,
@@ -134,9 +134,8 @@
  inplace: bool = True,
  no_combine: bool = False,
  routed_scaling_factor: Optional[float] = None,
- enable_flashinfer_cutlass_moe: Optional[bool] = False,
- activation_alpha: Optional[float] = None,
- swiglu_limit: Optional[float] = None,
+ gemm1_alpha: Optional[float] = None,
+ gemm1_clamp_limit: Optional[float] = None,
  use_weight_loader_fused: bool = False,
  with_bias=False,
  ):
@@ -153,9 +152,17 @@
  self.expert_map_cpu = None
  self.expert_map_gpu = None

- # For activation
- self.activation_alpha = activation_alpha
- self.swiglu_limit = swiglu_limit
+ self.moe_runner_config = MoeRunnerConfig(
+ activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ inplace=inplace,
+ no_combine=no_combine,
+ routed_scaling_factor=routed_scaling_factor,
+ gemm1_alpha=gemm1_alpha,
+ gemm1_clamp_limit=gemm1_clamp_limit,
+ )
+
+ enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()

  if enable_flashinfer_cutlass_moe and quant_config is None:
  logger.warning("Disable flashinfer MoE when quantization config is None.")
@@ -174,9 +181,6 @@
  self.expert_map_cpu = torch.full(
  (self.num_experts,), -1, dtype=torch.int32, device="cpu"
  )
- self.expert_map_cpu = torch.full(
- (self.num_experts,), -1, dtype=torch.int32, device="cpu"
- )
  # Create a expert map for the local experts
  self.expert_map_cpu[
  self.moe_ep_rank
@@ -184,20 +188,12 @@
  * self.num_local_experts
  ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")

- self.routed_scaling_factor = routed_scaling_factor
  assert intermediate_size % self.moe_tp_size == 0
  self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
  self.reduce_results = reduce_results
- self.activation = activation
- self.apply_router_weight_on_input = apply_router_weight_on_input
  self.use_presharded_weights = use_presharded_weights
- self.inplace = inplace
- self.no_combine = no_combine
-
- self.use_triton_kernels = (
- not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
- )

+ self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
  if quant_config is None:
  self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
  self.use_triton_kernels
@@ -207,14 +203,12 @@
  assert self.quant_method is not None

  self.quant_config = quant_config
- self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
- "enable_flashinfer_mxfp4_moe", False
- )
+ self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
  # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
  if (
  self.quant_config is not None
  and self.quant_config.get_name() == "mxfp4"
- and self.use_enable_flashinfer_mxfp4_moe
+ and self.use_flashinfer_mxfp4_moe
  ):
  hidden_size = round_up(hidden_size, 256)
  self.quant_method.create_weights(
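The hunks above drop the ad-hoc flags (the `enable_flashinfer_cutlass_moe` constructor argument, `global_server_args_dict["enable_triton_kernel_moe"]`, `enable_flashinfer_mxfp4_moe`) in favor of a single `get_moe_runner_backend()` query with `is_flashinfer_cutlass()`, `is_triton_kernel()`, and `is_flashinfer_mxfp4()` predicates. A minimal sketch of that predicate-enum pattern, with placeholder member names and a hard-coded selection rather than sglang's actual backend registry:

```python
from enum import Enum


class MoeRunnerBackend(Enum):
    # Member names are illustrative; the real registry lives in
    # sglang.srt.layers.moe.utils and may differ.
    AUTO = "auto"
    TRITON_KERNEL = "triton_kernel"
    FLASHINFER_CUTLASS = "flashinfer_cutlass"
    FLASHINFER_MXFP4 = "flashinfer_mxfp4"

    def is_triton_kernel(self) -> bool:
        return self is MoeRunnerBackend.TRITON_KERNEL

    def is_flashinfer_cutlass(self) -> bool:
        return self is MoeRunnerBackend.FLASHINFER_CUTLASS

    def is_flashinfer_mxfp4(self) -> bool:
        return self is MoeRunnerBackend.FLASHINFER_MXFP4


# Hypothetical module-level selection; in sglang this is derived from server args.
_MOE_RUNNER_BACKEND = MoeRunnerBackend.TRITON_KERNEL


def get_moe_runner_backend() -> MoeRunnerBackend:
    return _MOE_RUNNER_BACKEND


print(get_moe_runner_backend().is_triton_kernel())  # True for the default above
```

Centralizing the choice this way lets each call site ask one object a yes/no question instead of re-reading raw server arguments.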
@@ -477,6 +471,7 @@
  not expert_id
  and self.quant_config is not None
  and self.quant_config.get_name() == "mxfp4"
+ and self.quant_config.is_static_cfg()
  ):
  if "bias" in weight_name:
  dim1 = loaded_weight.shape[1]
@@ -625,9 +620,7 @@

  if "ModelOpt" in self.quant_method.__class__.__name__:
  # Determine per-tensor weight scale patterns based on variant
- is_fp4_variant = (
- "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
- )
+ is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)

  # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
  per_tensor_conditions = (
@@ -729,7 +722,11 @@
  ) -> None:
  tp_rank = self.moe_tp_rank

- if self.quant_config is not None and self.quant_config.get_name() == "mxfp4":
+ if (
+ self.quant_config is not None
+ and self.quant_config.get_name() == "mxfp4"
+ and self.quant_config.is_static_cfg()
+ ):
  if "bias" in weight_name:
  dim1 = loaded_weight.shape[1]
  param.data[:, :dim1].copy_(loaded_weight)
@@ -794,7 +791,7 @@
  f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
  )

- def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
+ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
  origin_hidden_states_dim = hidden_states.shape[-1]
  assert self.quant_method is not None

@@ -803,40 +800,22 @@
  # If we are in EP mode, we need to move the expert map to GPU.
  self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")

- if self.expert_map_gpu is not None and isinstance(
- topk_output, StandardTopKOutput
- ):
- topk_output = topk_output._replace(
- topk_ids=self.expert_map_gpu[topk_output.topk_ids]
- )
+ if self.expert_map_gpu is not None:
+ if TopKOutputChecker.format_is_standard(topk_output):
+ topk_output = topk_output._replace(
+ topk_ids=self.expert_map_gpu[topk_output.topk_ids]
+ )
+ elif TopKOutputChecker.format_is_triton_kernel(topk_output):
+ raise NotImplementedError()

  # Matrix multiply.
  with use_symmetric_memory(get_tp_group()) as sm:
- kwargs = {}
- if self.activation_alpha is not None:
- kwargs["activation_alpha"] = self.activation_alpha
- if self.swiglu_limit is not None:
- kwargs["swiglu_limit"] = self.swiglu_limit

  final_hidden_states = self.quant_method.apply(
  layer=self,
  x=hidden_states,
  topk_output=topk_output,
- activation=self.activation,
- apply_router_weight_on_input=self.apply_router_weight_on_input,
- routed_scaling_factor=self.routed_scaling_factor,
- **(
- dict(
- tp_rank=self.moe_tp_rank,
- tp_size=self.moe_tp_size,
- ep_rank=self.moe_ep_rank,
- ep_size=self.moe_ep_size,
- )
- if self.quant_method.__class__.__name__
- == "ModelOptNvFp4FusedMoEMethod"
- else {}
- ),
- **kwargs,
+ moe_runner_config=self.moe_runner_config,
  )
  sm.tag(final_hidden_states)

@@ -941,53 +920,39 @@
  for shard_id in ["w1", "w2", "w3"]
  ]

+ def should_fuse_routed_scaling_factor_in_topk(self):
+ return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
+ isinstance(self.quant_method, Fp8MoEMethod)
+ and self.quant_method.use_cutlass_fused_experts_fp8
+ )
+

  class FlashInferFusedMoE(FusedMoE):
  def __init__(self, *args, **kwargs):
- renormalize = kwargs.pop("renormalize", True)
- num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
- use_grouped_topk = kwargs.pop("use_grouped_topk", False)
- num_expert_group = kwargs.pop("num_expert_group", None)
- topk_group = kwargs.pop("topk_group", None)
- correction_bias = kwargs.pop("correction_bias", None)
  super().__init__(*args, **kwargs)
- self.renormalize = renormalize
- self.num_fused_shared_experts = num_fused_shared_experts
- self.use_grouped_topk = use_grouped_topk
- if self.use_grouped_topk:
- assert num_expert_group is not None and topk_group is not None
- self.num_expert_group = num_expert_group
- self.topk_group = topk_group
- self.correction_bias = correction_bias
  self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()

- def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
+ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
  assert self.use_flashinfer_trtllm_moe
  assert (
- self.activation == "silu"
+ self.moe_runner_config.activation == "silu"
  ), "Only silu is supported for flashinfer blockscale fp8 moe"
  assert self.quant_method is not None
  assert (
- self.renormalize
+ topk_output.topk_config.renormalize
  ), "Renormalize is required for flashinfer blockscale fp8 moe"
  assert (
  self.num_fused_shared_experts == 0
  ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"

- # TRTLLM mode expects (TopK_config, router_logits) tuple
- if not isinstance(topk_output, tuple) or len(topk_output) != 2:
- raise ValueError(
- f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
- )
- _, router_logits = topk_output
+ assert TopKOutputChecker.format_is_bypassed(topk_output)

  # Matrix multiply.
  final_hidden_states = self.quant_method.apply_with_router_logits(
  layer=self,
  x=hidden_states,
- router_logits=router_logits,
- activation=self.activation,
- routed_scaling_factor=self.routed_scaling_factor,
+ topk_output=topk_output,
+ moe_runner_config=self.moe_runner_config,
  )

  if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
@@ -1000,28 +965,8 @@ class FlashInferFP4MoE(FusedMoE):
  """FP4 TRTLLM MoE implementation using FlashInfer."""

  def __init__(self, *args, **kwargs):
- # Extract DeepSeek-specific parameters
- renormalize = kwargs.pop("renormalize", True)
- num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
- use_grouped_topk = kwargs.pop("use_grouped_topk", False)
- num_expert_group = kwargs.pop("num_expert_group", None)
- topk_group = kwargs.pop("topk_group", None)
- correction_bias = kwargs.pop("correction_bias", None)
-
- # Extract additional TopK parameters that were previously extracted in forward
- routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
-
  super().__init__(*args, **kwargs)

- # Store DeepSeek parameters
- self.renormalize = renormalize
- self.num_fused_shared_experts = num_fused_shared_experts
- self.use_grouped_topk = use_grouped_topk
- self.num_expert_group = num_expert_group
- self.topk_group = topk_group
- self.correction_bias = correction_bias
- self.routed_scaling_factor = routed_scaling_factor
-
  # ---------------------------------------------------------------------
  # Helper: quantize hidden states to FP4 each forward pass
  # ---------------------------------------------------------------------
@@ -1052,21 +997,19 @@

  return hs_fp4, hs_sf

- def forward(self, hidden_states: torch.Tensor, topk_output):
+ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
  """Forward pass using FP4 TRTLLM kernel.

  Args:
  hidden_states: Input tensor
- topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
+ topk_output: TopKOutput object with Bypassed format
  """
+ assert isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)

- # TRTLLM mode expects (TopK_config, router_logits) tuple
- if not isinstance(topk_output, tuple) or len(topk_output) != 2:
- raise ValueError(
- f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
- )
+ assert TopKOutputChecker.format_is_bypassed(topk_output)

- _, router_logits = topk_output
+ router_logits = topk_output.router_logits
+ topk_config = topk_output.topk_config

  hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)

@@ -1074,7 +1017,7 @@

  result = trtllm_fp4_block_scale_moe(
  routing_logits=router_logits,
- routing_bias=self.correction_bias.to(hidden_states.dtype),
+ routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
  hidden_states=hs_fp4,
  hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
  gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
@@ -1094,15 +1037,15 @@
  output1_scale_gate_scalar=self.g1_alphas.data,
  output2_scale_scalar=self.g2_alphas.data,
  num_experts=self.num_experts,
- top_k=self.top_k,
- n_group=self.num_expert_group,
- topk_group=self.topk_group,
+ top_k=topk_config.top_k,
+ n_group=topk_config.num_expert_group,
+ topk_group=topk_config.topk_group,
  intermediate_size=self.intermediate_size_per_partition,
  local_expert_offset=self.moe_ep_rank * self.num_local_experts,
  local_num_experts=self.num_local_experts,
- routed_scaling_factor=self.routed_scaling_factor,
+ routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
  tile_tokens_dim=_get_tile_tokens_dim(
- hidden_states.shape[0], self.top_k, self.num_local_experts
+ hidden_states.shape[0], topk_config.top_k, self.num_local_experts
  ),
  routing_method_type=RoutingMethodType.DeepSeekV3,
  do_finalize=True,
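The `forward` paths above now branch on the format of `topk_output` through `TopKOutputChecker` instead of `isinstance`/tuple checks: standard outputs get their `topk_ids` remapped through the expert map for expert parallelism, while the FlashInfer subclasses require the "bypassed" format carrying the raw `router_logits` together with a `topk_config`. A rough, self-contained sketch of that checker/dispatch idea; the class and field definitions below are simplified stand-ins, not the actual ones in `sglang.srt.layers.moe.topk`:

```python
from typing import NamedTuple

import torch


class StandardTopKOutput(NamedTuple):
    # Simplified stand-in: per-token expert weights and ids.
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor


class BypassedTopKOutput(NamedTuple):
    # Simplified stand-in: routing is deferred to the fused kernel.
    router_logits: torch.Tensor
    topk_config: dict  # the real code carries a TopKConfig object


class TopKOutputChecker:
    @staticmethod
    def format_is_standard(out) -> bool:
        return isinstance(out, StandardTopKOutput)

    @staticmethod
    def format_is_bypassed(out) -> bool:
        return isinstance(out, BypassedTopKOutput)


def remap_expert_ids(out, expert_map_gpu: torch.Tensor):
    """Mirror the EP remapping step in FusedMoE.forward (sketch only)."""
    if TopKOutputChecker.format_is_standard(out):
        return out._replace(topk_ids=expert_map_gpu[out.topk_ids])
    return out  # bypassed outputs carry raw logits; nothing to remap here


expert_map = torch.tensor([-1, 0, 1, -1], dtype=torch.int32)  # 2 local experts
out = StandardTopKOutput(
    topk_weights=torch.rand(2, 2),
    topk_ids=torch.tensor([[1, 2], [2, 1]]),
)
print(remap_expert_ids(out, expert_map).topk_ids)  # local ids: [[0, 1], [1, 0]]
```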
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
@@ -18,6 +18,7 @@ from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
  from triton_kernels.swiglu import swiglu_fn

  if TYPE_CHECKING:
+ from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
  from sglang.srt.layers.moe.topk import TopKOutput


@@ -55,8 +56,7 @@
  w1: torch.Tensor,
  w2: torch.Tensor,
  topk_output: TopKOutput,
- inplace: bool = False,
- activation: str = "silu",
+ moe_runner_config: MoeRunnerConfig,
  apply_router_weight_on_input: bool = False,
  use_fp8_w8a8: bool = False,
  per_channel_quant: bool = False,
@@ -69,7 +69,10 @@
  block_shape: Optional[list[int]] = None,
  ) -> torch.Tensor:

- assert topk_output.format.is_triton_kernel()
+ from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+ assert TopKOutputChecker.format_is_triton_kernel(topk_output)
+
  routing_data, gather_idx, scatter_idx = topk_output

  return triton_kernel_fused_experts(
@@ -79,8 +82,8 @@
  routing_data,
  gather_idx,
  scatter_idx,
- inplace=inplace,
- activation=activation,
+ inplace=False, # triton kernel doesn't support inplace
+ activation=moe_runner_config.activation,
  apply_router_weight_on_input=apply_router_weight_on_input,
  use_fp8_w8a8=use_fp8_w8a8,
  per_channel_quant=per_channel_quant,
@@ -192,8 +195,7 @@
  w2_pcg,
  b2: torch.Tensor,
  topk_output: TopKOutput,
- inplace: bool = False,
- activation: str = "silu",
+ moe_runner_config: MoeRunnerConfig,
  use_fp8_w8a8: bool = False,
  per_channel_quant: bool = False,
  global_num_experts: int = -1,
@@ -203,10 +205,11 @@
  a1_scale: Optional[torch.Tensor] = None,
  a2_scale: Optional[torch.Tensor] = None,
  block_shape: Optional[list[int]] = None,
- activation_alpha: Optional[float] = None,
- swiglu_limit: Optional[int] = None,
  ) -> torch.Tensor:
- assert topk_output.format.is_triton_kernel()
+ from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+ assert TopKOutputChecker.format_is_triton_kernel(topk_output)
+
  routing_data, gather_idx, scatter_idx = topk_output

  return triton_kernel_fused_experts_with_bias(
@@ -220,8 +223,8 @@
  routing_data=routing_data,
  gather_indx=gather_idx,
  scatter_indx=scatter_idx,
- inplace=inplace,
- activation=activation,
+ inplace=False, # triton kernel doesn't support inplace
+ activation=moe_runner_config.activation,
  use_fp8_w8a8=use_fp8_w8a8,
  per_channel_quant=per_channel_quant,
  global_num_experts=global_num_experts,
@@ -231,8 +234,8 @@
  a1_scale=a1_scale,
  a2_scale=a2_scale,
  block_shape=block_shape,
- activation_alpha=activation_alpha,
- swiglu_limit=swiglu_limit,
+ gemm1_alpha=moe_runner_config.gemm1_alpha,
+ gemm1_clamp_limit=moe_runner_config.gemm1_clamp_limit,
  )


@@ -258,10 +261,9 @@
  a1_scale: Optional[torch.Tensor] = None,
  a2_scale: Optional[torch.Tensor] = None,
  block_shape: Optional[list[int]] = None,
- activation_alpha: Optional[float] = None,
- swiglu_limit: Optional[int] = None,
+ gemm1_alpha: Optional[float] = None,
+ gemm1_clamp_limit: Optional[float] = None,
  ) -> torch.Tensor:
- # print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
  assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
  assert per_channel_quant == False, "per_channel_quant is not supported"
  assert expert_map == None, "expert_map is not supported"
@@ -307,7 +309,7 @@

  act = FusedActivation(
  FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
- (activation_alpha, swiglu_limit),
+ (gemm1_alpha, gemm1_clamp_limit),
  2,
  )

sglang/srt/layers/moe/moe_runner/__init__.py (new file)
@@ -0,0 +1,3 @@
+ from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+
+ __all__ = ["MoeRunnerConfig"]
sglang/srt/layers/moe/moe_runner/base.py (new file)
@@ -0,0 +1,13 @@
+ from dataclasses import dataclass
+ from typing import Optional
+
+
+ @dataclass
+ class MoeRunnerConfig:
+ activation: str = "silu"
+ apply_router_weight_on_input: bool = False
+ inplace: bool = True
+ no_combine: bool = False
+ routed_scaling_factor: Optional[float] = None
+ gemm1_alpha: Optional[float] = None
+ gemm1_clamp_limit: Optional[float] = None
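`MoeRunnerConfig` is the dataclass that the earlier constructor hunk builds once per layer and that `quant_method.apply(...)` now receives as a single `moe_runner_config` argument, replacing the old per-call `activation`/`inplace`/`routed_scaling_factor`/`activation_alpha`/`swiglu_limit` keyword plumbing. A small stand-alone usage sketch; the config is re-declared here only so the snippet runs by itself, and the consumer function and numeric values are placeholders:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MoeRunnerConfig:
    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = True
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None
    gemm1_alpha: Optional[float] = None
    gemm1_clamp_limit: Optional[float] = None


def apply_moe(moe_runner_config: MoeRunnerConfig) -> str:
    # Hypothetical consumer: reads everything from one object instead of
    # a pile of keyword arguments.
    assert moe_runner_config.activation in ("silu", "gelu")
    return moe_runner_config.activation


config = MoeRunnerConfig(
    activation="silu",
    routed_scaling_factor=2.5,  # placeholder value
    gemm1_clamp_limit=7.0,      # placeholder value
)
print(apply_moe(config))  # "silu"
```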
sglang/srt/layers/moe/rocm_moe_utils.py (new file)
@@ -0,0 +1,141 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1rc2/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ from enum import IntEnum
+ from functools import cache
+ from typing import Optional
+
+ import torch
+
+ from sglang.srt.utils import direct_register_custom_op, get_bool_env_var, is_hip
+
+ _is_hip = is_hip()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+
+ class ActivationMethod(IntEnum):
+ # This allows interfacing with AITER ActivationType enum
+ # without importing the ActivationType enum from AITER globally.
+ SILU = 0
+ GELU = 1
+
+
+ def rocm_aiter_asm_moe_tkw1_impl(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ fc1_scale: Optional[torch.Tensor] = None,
+ fc2_scale: Optional[torch.Tensor] = None,
+ fc1_smooth_scale: Optional[torch.Tensor] = None,
+ fc2_smooth_scale: Optional[torch.Tensor] = None,
+ a16: bool = False,
+ per_tensor_quant_scale: Optional[torch.Tensor] = None,
+ expert_mask: Optional[torch.Tensor] = None,
+ activation_method: int = ActivationMethod.SILU.value,
+ ) -> torch.Tensor:
+
+ from aiter import ActivationType
+ from aiter.fused_moe_bf16_asm import asm_moe_tkw1
+
+ activation = ActivationType(activation_method)
+
+ return asm_moe_tkw1(
+ hidden_states,
+ w1,
+ w2,
+ topk_weights,
+ topk_ids,
+ fc1_scale=fc1_scale,
+ fc2_scale=fc2_scale,
+ fc1_smooth_scale=fc1_smooth_scale,
+ fc2_smooth_scale=fc2_smooth_scale,
+ a16=a16,
+ per_tensor_quant_scale=per_tensor_quant_scale,
+ expert_mask=expert_mask,
+ activation=activation,
+ )
+
+
+ def rocm_aiter_asm_moe_tkw1_fake(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ fc1_scale: Optional[torch.Tensor] = None,
+ fc2_scale: Optional[torch.Tensor] = None,
+ fc1_smooth_scale: Optional[torch.Tensor] = None,
+ fc2_smooth_scale: Optional[torch.Tensor] = None,
+ a16: bool = False,
+ per_tensor_quant_scale: Optional[torch.Tensor] = None,
+ expert_mask: Optional[torch.Tensor] = None,
+ activation_method: int = ActivationMethod.SILU.value,
+ ) -> torch.Tensor:
+ return torch.empty_like(hidden_states)
+
+
+ if _use_aiter:
+
+ direct_register_custom_op(
+ op_name="rocm_aiter_asm_moe_tkw1",
+ op_func=rocm_aiter_asm_moe_tkw1_impl,
+ mutates_args=[],
+ fake_impl=rocm_aiter_asm_moe_tkw1_fake,
+ )
+
+
+ def rocm_fused_experts_tkw1(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
+ use_fp8_w8a8: bool = False,
+ per_channel_quant: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[list[int]] = None,
+ ) -> torch.Tensor:
+
+ activation_method = (
+ ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU
+ )
+ # All AITER Fused MoE kernels are expecting the following datatypes
+ topk_weights = topk_weights.to(torch.float32)
+ topk_ids = topk_ids.to(torch.int32)
+
+ # w8a8 per-channel quantization
+ if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
+ # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
+ # This applies topk_weights on the GEMM output of the first FC layer
+ # rather than the second FC.
+ assert (
+ topk_weights.dim() == 2
+ ), "`topk_weights` should be in shape (num_tokens, topk)"
+ assert topk_weights.shape[-1] == 1, (
+ "Only support topk=1 when" " `apply_router_weight_on_input` is True"
+ )
+
+ return torch.ops.sglang.rocm_aiter_asm_moe_tkw1(
+ hidden_states,
+ w1,
+ w2,
+ topk_weights,
+ topk_ids,
+ fc1_scale=w1_scale,
+ fc2_scale=w2_scale,
+ fc1_smooth_scale=None,
+ fc2_smooth_scale=None,
+ a16=False,
+ per_tensor_quant_scale=None,
+ expert_mask=None,
+ activation_method=activation_method,
+ )
+ else:
+ assert False, "This should not be called."
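The new module registers the AITER `asm_moe_tkw1` kernel as a torch custom op through sglang's `direct_register_custom_op` helper, pairing the real implementation with a `*_fake` shape-only implementation so the op can be traced or compiled without executing the ROCm kernel. A rough equivalent of that impl/fake pairing using only the public `torch.library` API (assuming PyTorch >= 2.4; the `demo::scale` op below is a made-up example, not part of sglang):

```python
import torch


# Real implementation (stands in for the AITER asm_moe_tkw1 call above).
@torch.library.custom_op("demo::scale", mutates_args=())
def demo_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


# Shape/dtype-only implementation used during tracing, playing the role of
# rocm_aiter_asm_moe_tkw1_fake.
@demo_scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)


out = torch.ops.demo.scale(torch.ones(4), 2.0)
print(out)
```

On a ROCm build with AITER installed, the entry point above is `rocm_fused_experts_tkw1`, which only takes the tkw1 path for per-channel FP8 weights with `apply_router_weight_on_input=True` and topk of 1.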
sglang/srt/layers/moe/router.py
@@ -45,11 +45,14 @@ def fused_moe_router_kernel(
  logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)

  # logit softcap
- logits_scaled = logits / moe_softcapping
- exped = tl.exp(2 * logits_scaled)
- top = exped - 1
- bottom = exped + 1
- logits_softcapped = top / bottom * moe_softcapping
+ if moe_softcapping == 0:
+ logits_softcapped = logits
+ else:
+ logits_scaled = logits / moe_softcapping
+ exped = tl.exp(2 * logits_scaled)
+ top = exped - 1
+ bottom = exped + 1
+ logits_softcapped = top / bottom * moe_softcapping

  # Add bias after softcapping
  if is_correction_bias:
@@ -207,9 +210,12 @@
  b_ptrs += BLOCK_SIZE_K

  # 4. logit softcap
- logits_scaled = acc / moe_softcapping
- exped = tl.exp(2 * logits_scaled)
- logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
+ if moe_softcapping == 0:
+ logits_softcapped = acc
+ else:
+ logits_scaled = acc / moe_softcapping
+ exped = tl.exp(2 * logits_scaled)
+ logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping

  # 5. top1
  arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
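Both router hunks now skip the softcap entirely when `moe_softcapping == 0` (softcapping disabled) instead of dividing by zero; when enabled, the kernel computes `cap * tanh(logits / cap)` through the identity `tanh(t) = (exp(2t) - 1) / (exp(2t) + 1)`. A quick host-side check of that identity in plain PyTorch (not the Triton kernel itself):

```python
import torch


def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Same algebra as the kernel above; cap == 0 means "no softcapping".
    if cap == 0:
        return logits
    scaled = logits / cap
    exped = torch.exp(2 * scaled)
    return (exped - 1) / (exped + 1) * cap


x = torch.randn(8, 4)
assert torch.allclose(softcap(x, 30.0), 30.0 * torch.tanh(x / 30.0), atol=1e-6)
assert torch.equal(softcap(x, 0), x)  # disabled path passes logits through
```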
@@ -234,7 +240,7 @@

  # 7. handle topk == 2
  if topk == 2:
- cond_top2 = (arange_block_size_n < num_experts) and (
+ cond_top2 = (arange_block_size_n < num_experts) & (
  arange_block_size_n != top1[:, None]
  )
  top2 = tl.argmax(
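The last hunk swaps Python's short-circuit `and` for the elementwise `&` when building the top-2 candidate mask: the condition combines two tensors of per-expert booleans, so it needs the bitwise operator rather than a scalar truth test. The same distinction shown in plain PyTorch (illustrative; the kernel itself operates on Triton tensors):

```python
import torch

arange = torch.arange(8)[None, :]  # expert indices
top1 = torch.tensor([[3]])         # previously selected expert

# Elementwise mask, as the kernel needs: "valid expert AND not already top-1".
mask = (arange < 6) & (arange != top1)
print(mask)

# Using `and` instead would try to collapse each tensor to a single bool and
# raise: "Boolean value of Tensor with more than one element is ambiguous".
```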