sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py

@@ -14,9 +14,18 @@
 
 from __future__ import annotations
 
+import logging
 import math
+from dataclasses import dataclass
 from enum import Enum, auto
-from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
+from typing import (
+    Callable,
+    NamedTuple,
+    Optional,
+    Protocol,
+    TypeGuard,
+    runtime_checkable,
+)
 
 import torch
 import torch.nn.functional as F
@@ -28,7 +37,10 @@ from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.layers.moe import (
+    get_moe_runner_backend,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -43,6 +55,7 @@ try:
     from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
 except ImportError:
     pass
+logger = logging.getLogger(__name__)
 
 
 _is_cuda = is_cuda()
@@ -65,13 +78,48 @@ if _use_aiter:
 if _is_npu:
     import torch_npu
 
+# -------------------------------- TopKConfig ---------------------------------------
+
+
+@dataclass
+class TopKConfig:
+    top_k: int
+    use_grouped_topk: bool = False
+    topk_group: Optional[int] = None
+    num_expert_group: Optional[int] = None
+    renormalize: bool = True
+    num_fused_shared_experts: int = 0
+    custom_routing_function: Optional[Callable] = None
+    correction_bias: Optional[torch.Tensor] = None
+    torch_native: bool = False
+    routed_scaling_factor: Optional[float] = None
+    apply_routed_scaling_factor_on_output: bool = False
+
 
 # -------------------------------- TopKOutput ---------------------------------------
 
 
+class TopKOutputChecker:
+
+    @staticmethod
+    def format_is_standard(topk_output: TopKOutput) -> TypeGuard[StandardTopKOutput]:
+        return topk_output.format.is_standard()
+
+    @staticmethod
+    def format_is_triton_kernel(
+        topk_output: TopKOutput,
+    ) -> TypeGuard[TritonKernelTopKOutput]:
+        return topk_output.format.is_triton_kernel()
+
+    @staticmethod
+    def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
+        return topk_output.format.is_bypassed()
+
+
 class TopKOutputFormat(Enum):
     STANDARD = auto()
     TRITON_KERNEL = auto()
+    BYPASSED = auto()
 
     def is_standard(self) -> bool:
         return self == TopKOutputFormat.STANDARD
@@ -79,6 +127,9 @@ class TopKOutputFormat(Enum):
     def is_triton_kernel(self) -> bool:
         return self == TopKOutputFormat.TRITON_KERNEL
 
+    def is_bypassed(self) -> bool:
+        return self == TopKOutputFormat.BYPASSED
+
 
 @runtime_checkable
 class TopKOutput(Protocol):
@@ -114,6 +165,20 @@ class TritonKernelTopKOutput(NamedTuple):
         return TopKOutputFormat.TRITON_KERNEL
 
 
+class BypassedTopKOutput(NamedTuple):
+    """Bypassed top-k output format."""
+
+    hidden_states: torch.Tensor
+    router_logits: torch.Tensor
+    topk_config: TopKConfig
+    num_token_non_padded: Optional[torch.Tensor] = None
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.BYPASSED
+
+
 # -------------------------------- TopK ---------------------------------------
 
 
@@ -132,23 +197,31 @@ class TopK(CustomOp):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
+        apply_routed_scaling_factor_on_output: Optional[bool] = False,
+        force_topk: bool = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
         super().__init__()
+
         if use_grouped_topk:
             assert num_expert_group is not None and topk_group is not None
-        self.top_k = top_k
-        self.use_grouped_topk = use_grouped_topk
-        self.renormalize = renormalize
-        self.topk_group = topk_group
-        self.num_expert_group = num_expert_group
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.custom_routing_function = custom_routing_function
-        self.correction_bias = correction_bias
-        self.routed_scaling_factor = routed_scaling_factor
-
-        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+
+        self.topk_config = TopKConfig(
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
+        )
+
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.force_topk = force_topk
 
     def forward_native(
         self,
@@ -158,20 +231,11 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        torch_native = True
+        self.topk_config.torch_native = True
         return select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=self.use_grouped_topk,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            custom_routing_function=self.custom_routing_function,
-            correction_bias=self.correction_bias,
-            torch_native=torch_native,
-            routed_scaling_factor=self.routed_scaling_factor,
+            topk_config=self.topk_config,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
         )
@@ -187,24 +251,28 @@ class TopK(CustomOp):
         if self.use_triton_kernels:
             # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
-                router_logits, self.top_k, sm_first=not self.renormalize
+                router_logits,
+                self.topk_config.top_k,
+                sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
+        elif not self.force_topk and (
+            should_use_flashinfer_trtllm_moe()
+            or get_moe_runner_backend().is_flashinfer_mxfp4()
+        ):
+            return BypassedTopKOutput(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                topk_config=self.topk_config,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )
         else:
-            torch_native = False
+            self.topk_config.torch_native = False
             return select_experts(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=self.use_grouped_topk,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                custom_routing_function=self.custom_routing_function,
-                correction_bias=self.correction_bias,
-                torch_native=torch_native,
-                routed_scaling_factor=self.routed_scaling_factor,
+                topk_config=self.topk_config,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
             )
@@ -220,15 +288,7 @@ class TopK(CustomOp):
         return select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=self.use_grouped_topk,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            custom_routing_function=self.custom_routing_function,
-            correction_bias=self.correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
+            topk_config=self.topk_config,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
         )
@@ -244,39 +304,40 @@ class TopK(CustomOp):
         global_num_experts = router_logits.shape[-1]
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256:
+        if global_num_experts == 256 and self.topk_config.renormalize is False:
+
+            routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
             router_logits = router_logits.to(torch.float32)
+
             return torch_npu.npu_moe_gating_top_k(
                 router_logits,
-                k=self.top_k,
-                bias=self.correction_bias.to(torch.float32),
-                k_group=self.topk_group,
-                group_count=self.num_expert_group,
+                k=self.topk_config.top_k,
+                bias=self.topk_config.correction_bias.to(torch.float32),
+                k_group=self.topk_config.topk_group,
+                group_count=self.topk_config.num_expert_group,
                 group_select_mode=1,
                 renorm=0,
                 norm_type=1,
-                routed_scaling_factor=1,
+                routed_scaling_factor=routed_scaling_factor,
                 eps=float(1e-20),
             )
         else:
-            torch_native = True
+            self.topk_config.torch_native = True
             return select_experts(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=self.use_grouped_topk,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                custom_routing_function=self.custom_routing_function,
-                correction_bias=self.correction_bias,
-                torch_native=torch_native,
-                routed_scaling_factor=self.routed_scaling_factor,
+                topk_config=self.topk_config,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
             )
 
+    def empty_topk_output(self, device: torch.device) -> TopKOutput:
+        topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
+        topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
+        topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
+        router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
+        return StandardTopKOutput(topk_weights, topk_idx, router_logits)
+
 
 # ------------------------------- TopK implementation -------------------------------------
 
@@ -370,12 +431,13 @@ def grouped_topk_gpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -423,6 +485,8 @@
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
+    if apply_routed_scaling_factor_on_output:
+        topk_weights *= routed_scaling_factor
 
     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -435,8 +499,8 @@ def grouped_topk_cpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -465,12 +529,13 @@ def biased_grouped_topk_impl(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -522,6 +587,8 @@
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
+    if apply_routed_scaling_factor_on_output:
+        topk_weights *= routed_scaling_factor
 
     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -558,12 +625,13 @@ def biased_grouped_topk_gpu(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert (
         routed_scaling_factor is not None
@@ -583,6 +651,7 @@
             topk,
             num_fused_shared_experts,
             routed_scaling_factor,
+            apply_routed_scaling_factor_on_output,
         )
         # TODO merge into kernel
         if (expert_location_dispatch_info is not None) or (
@@ -593,6 +662,7 @@
             )
         return topk_weights, topk_ids
     elif _use_aiter:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        token = gating_output.shape[0]
        device = gating_output.device
        assert (
@@ -624,6 +694,7 @@
             routed_scaling_factor=routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
         )
 
 
@@ -633,15 +704,17 @@ def biased_grouped_topk_cpu(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     compiled: bool = True,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert expert_location_dispatch_info is None
+    assert not apply_routed_scaling_factor_on_output, "Not implemented"
     return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
         hidden_states,
         gating_output,
@@ -670,20 +743,26 @@ else:
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
-    top_k: int,
+    topk_config: TopKConfig,
     *,
-    use_grouped_topk: bool = False,
-    renormalize: bool = False,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    num_fused_shared_experts: int = 0,
-    custom_routing_function: Optional[Callable] = None,
-    correction_bias: Optional[torch.Tensor] = None,
-    torch_native: bool = False,
-    routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-) -> TopKOutput:
+) -> StandardTopKOutput:
+
+    top_k = topk_config.top_k
+    use_grouped_topk = topk_config.use_grouped_topk
+    topk_group = topk_config.topk_group
+    num_expert_group = topk_config.num_expert_group
+    renormalize = topk_config.renormalize
+    num_fused_shared_experts = topk_config.num_fused_shared_experts
+    custom_routing_function = topk_config.custom_routing_function
+    correction_bias = topk_config.correction_bias
+    torch_native = topk_config.torch_native
+    routed_scaling_factor = topk_config.routed_scaling_factor
+    apply_routed_scaling_factor_on_output = (
+        topk_config.apply_routed_scaling_factor_on_output
+    )
+
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
             router_logits=router_logits,
@@ -708,6 +787,7 @@ def select_experts(
                 routed_scaling_factor=routed_scaling_factor,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
+                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(
@@ -722,12 +802,14 @@ def select_experts(
                 routed_scaling_factor=routed_scaling_factor,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
+                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             )
     elif torch_native and custom_routing_function is None:
         assert (
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in fused_topk_native"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
735
817
  renormalize=renormalize,
736
818
  )
737
819
  elif custom_routing_function is None:
820
+ assert not apply_routed_scaling_factor_on_output, "Not implemented"
738
821
  # Qwen3MOE uses fused_topk
739
822
  topk_weights, topk_ids = fused_topk(
740
823
  hidden_states=hidden_states,
@@ -749,6 +832,7 @@ def select_experts(
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in custom_routing_function"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
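
For context on the hunks above: this release folds the long per-argument signatures of TopK and select_experts into the new TopKConfig dataclass, adds a BypassedTopKOutput so backends that do their own routing (the FlashInfer TRT-LLM and MXFP4 paths) receive the raw router logits plus the config instead of precomputed weights, and uses typing.TypeGuard in TopKOutputChecker so static type checkers can narrow a TopKOutput protocol value to its concrete variant. The sketch below is not taken from the package; it is a minimal plain-PyTorch illustration with hypothetical names (ToyTopKConfig, toy_select_experts) that mirror a subset of the TopKConfig fields, showing what the standard top-k routing path computes: softmax over router logits, top-k selection, optional renormalization, and an optional routed scaling factor.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyTopKConfig:
    # Hypothetical stand-in mirroring a subset of the TopKConfig fields above.
    top_k: int
    renormalize: bool = True
    routed_scaling_factor: Optional[float] = None
    apply_routed_scaling_factor_on_output: bool = False


def toy_select_experts(router_logits: torch.Tensor, cfg: ToyTopKConfig):
    # Softmax over experts, then keep the k largest routing weights per token.
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, cfg.top_k, dim=-1)
    if cfg.renormalize:
        # Make the selected weights sum to 1 for each token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    if cfg.apply_routed_scaling_factor_on_output and cfg.routed_scaling_factor:
        topk_weights = topk_weights * cfg.routed_scaling_factor
    return topk_weights, topk_ids.to(torch.int32)


logits = torch.randn(4, 8)  # 4 tokens, 8 experts
weights, ids = toy_select_experts(logits, ToyTopKConfig(top_k=2))
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])

Grouping the routing knobs in a single config object is what lets the bypassed FlashInfer paths defer all of this work to the backend kernel while still carrying the same parameters.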