sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py
CHANGED
```diff
@@ -14,9 +14,18 @@
 
 from __future__ import annotations
 
+import logging
 import math
+from dataclasses import dataclass
 from enum import Enum, auto
-from typing import …
+from typing import (
+    Callable,
+    NamedTuple,
+    Optional,
+    Protocol,
+    TypeGuard,
+    runtime_checkable,
+)
 
 import torch
 import torch.nn.functional as F
@@ -28,7 +37,10 @@ from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
-from sglang.srt.…
+from sglang.srt.layers.moe import (
+    get_moe_runner_backend,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -43,6 +55,7 @@ try:
     from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
 except ImportError:
     pass
+logger = logging.getLogger(__name__)
 
 
 _is_cuda = is_cuda()
```
```diff
@@ -65,13 +78,48 @@
 if _is_npu:
     import torch_npu
 
+# -------------------------------- TopKConfig ---------------------------------------
+
+
+@dataclass
+class TopKConfig:
+    top_k: int
+    use_grouped_topk: bool = False
+    topk_group: Optional[int] = None
+    num_expert_group: Optional[int] = None
+    renormalize: bool = True
+    num_fused_shared_experts: int = 0
+    custom_routing_function: Optional[Callable] = None
+    correction_bias: Optional[torch.Tensor] = None
+    torch_native: bool = False
+    routed_scaling_factor: Optional[float] = None
+    apply_routed_scaling_factor_on_output: bool = False
+
 
 # -------------------------------- TopKOutput ---------------------------------------
 
 
+class TopKOutputChecker:
+
+    @staticmethod
+    def format_is_standard(topk_output: TopKOutput) -> TypeGuard[StandardTopKOutput]:
+        return topk_output.format.is_standard()
+
+    @staticmethod
+    def format_is_triton_kernel(
+        topk_output: TopKOutput,
+    ) -> TypeGuard[TritonKernelTopKOutput]:
+        return topk_output.format.is_triton_kernel()
+
+    @staticmethod
+    def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
+        return topk_output.format.is_bypassed()
+
+
 class TopKOutputFormat(Enum):
     STANDARD = auto()
     TRITON_KERNEL = auto()
+    BYPASSED = auto()
 
     def is_standard(self) -> bool:
         return self == TopKOutputFormat.STANDARD
@@ -79,6 +127,9 @@ class TopKOutputFormat(Enum):
     def is_triton_kernel(self) -> bool:
         return self == TopKOutputFormat.TRITON_KERNEL
 
+    def is_bypassed(self) -> bool:
+        return self == TopKOutputFormat.BYPASSED
+
 
 @runtime_checkable
 class TopKOutput(Protocol):
@@ -114,6 +165,20 @@ class TritonKernelTopKOutput(NamedTuple):
         return TopKOutputFormat.TRITON_KERNEL
 
 
+class BypassedTopKOutput(NamedTuple):
+    """Bypassed top-k output format."""
+
+    hidden_states: torch.Tensor
+    router_logits: torch.Tensor
+    topk_config: TopKConfig
+    num_token_non_padded: Optional[torch.Tensor] = None
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.BYPASSED
+
+
 # -------------------------------- TopK ---------------------------------------
```
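The `TypeGuard` returns above are what make the new output taxonomy ergonomic: a consumer branches on `TopKOutputChecker`, and inside each branch a static type checker narrows the value to the concrete `NamedTuple`. A minimal consumer sketch follows; the `combine_inputs` helper is hypothetical, and the field names on the first two variants are inferred from how this diff constructs them, not verbatim sglang code:

```python
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker

def combine_inputs(topk_output: TopKOutput):
    # Hypothetical dispatcher: each TypeGuard narrows topk_output, so the
    # attribute accesses below type-check against the concrete NamedTuple.
    if TopKOutputChecker.format_is_standard(topk_output):
        # Routing already ran; dense (weights, ids) tensors are available.
        return topk_output.topk_weights, topk_output.topk_ids
    if TopKOutputChecker.format_is_triton_kernel(topk_output):
        # Routing metadata destined for the triton_kernels MoE path.
        return topk_output.routing_data
    if TopKOutputChecker.format_is_bypassed(topk_output):
        # Routing was deferred: the raw logits travel with the TopKConfig so
        # a fused kernel (e.g. FlashInfer TRT-LLM MoE) can do top-k itself.
        return topk_output.router_logits, topk_output.topk_config
    raise NotImplementedError(f"Unhandled format: {topk_output.format}")
```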
```diff
@@ -132,23 +197,31 @@ class TopK(CustomOp):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
+        apply_routed_scaling_factor_on_output: Optional[bool] = False,
+        force_topk: bool = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
         super().__init__()
+
         if use_grouped_topk:
             assert num_expert_group is not None and topk_group is not None
-
-        self.…
-        …
+
+        self.topk_config = TopKConfig(
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
+        )
+
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.force_topk = force_topk
 
     def forward_native(
         self,
```
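The constructor change is a pure consolidation: the old per-attribute assignments (truncated in this rendering) become a single `TopKConfig` that the forward paths hand to `select_experts` wholesale. A sketch of how a caller sees this, with illustrative DeepSeek-style grouped-routing values; the keyword names follow the `TopKConfig` fields above, while the full `TopK.__init__` signature is only partly visible in the diff:

```python
from sglang.srt.layers.moe.topk import TopK

# Illustrative grouped-routing setup; use_grouped_topk requires both
# topk_group and num_expert_group (the assert above enforces this).
topk_module = TopK(
    top_k=6,
    use_grouped_topk=True,
    renormalize=True,
    topk_group=3,
    num_expert_group=8,
    routed_scaling_factor=2.5,
)

# Every routing knob now lives on one dataclass instead of loose attributes.
cfg = topk_module.topk_config
assert cfg.top_k == 6 and cfg.num_expert_group == 8 and cfg.renormalize
```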
```diff
@@ -158,20 +231,11 @@
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        torch_native = True
+        self.topk_config.torch_native = True
         return select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
-            …
-            use_grouped_topk=self.use_grouped_topk,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            custom_routing_function=self.custom_routing_function,
-            correction_bias=self.correction_bias,
-            torch_native=torch_native,
-            routed_scaling_factor=self.routed_scaling_factor,
+            topk_config=self.topk_config,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
         )
@@ -187,24 +251,28 @@
         if self.use_triton_kernels:
             # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
-                router_logits, …
+                router_logits,
+                self.topk_config.top_k,
+                sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
+        elif not self.force_topk and (
+            should_use_flashinfer_trtllm_moe()
+            or get_moe_runner_backend().is_flashinfer_mxfp4()
+        ):
+            return BypassedTopKOutput(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                topk_config=self.topk_config,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )
         else:
-            torch_native = False
+            self.topk_config.torch_native = False
             return select_experts(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
-                …
-                use_grouped_topk=self.use_grouped_topk,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                custom_routing_function=self.custom_routing_function,
-                correction_bias=self.correction_bias,
-                torch_native=torch_native,
-                routed_scaling_factor=self.routed_scaling_factor,
+                topk_config=self.topk_config,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
             )
```
|
|
220
288
|
return select_experts(
|
221
289
|
hidden_states=hidden_states,
|
222
290
|
router_logits=router_logits,
|
223
|
-
|
224
|
-
use_grouped_topk=self.use_grouped_topk,
|
225
|
-
renormalize=self.renormalize,
|
226
|
-
topk_group=self.topk_group,
|
227
|
-
num_expert_group=self.num_expert_group,
|
228
|
-
num_fused_shared_experts=self.num_fused_shared_experts,
|
229
|
-
custom_routing_function=self.custom_routing_function,
|
230
|
-
correction_bias=self.correction_bias,
|
231
|
-
routed_scaling_factor=self.routed_scaling_factor,
|
291
|
+
topk_config=self.topk_config,
|
232
292
|
num_token_non_padded=num_token_non_padded,
|
233
293
|
expert_location_dispatch_info=expert_location_dispatch_info,
|
234
294
|
)
|
@@ -244,39 +304,40 @@ class TopK(CustomOp):
|
|
244
304
|
global_num_experts = router_logits.shape[-1]
|
245
305
|
|
246
306
|
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
247
|
-
if global_num_experts == 256:
|
307
|
+
if global_num_experts == 256 and self.topk_config.renormalize is False:
|
308
|
+
|
309
|
+
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
|
248
310
|
router_logits = router_logits.to(torch.float32)
|
311
|
+
|
249
312
|
return torch_npu.npu_moe_gating_top_k(
|
250
313
|
router_logits,
|
251
|
-
k=self.top_k,
|
252
|
-
bias=self.correction_bias.to(torch.float32),
|
253
|
-
k_group=self.topk_group,
|
254
|
-
group_count=self.num_expert_group,
|
314
|
+
k=self.topk_config.top_k,
|
315
|
+
bias=self.topk_config.correction_bias.to(torch.float32),
|
316
|
+
k_group=self.topk_config.topk_group,
|
317
|
+
group_count=self.topk_config.num_expert_group,
|
255
318
|
group_select_mode=1,
|
256
319
|
renorm=0,
|
257
320
|
norm_type=1,
|
258
|
-
routed_scaling_factor=
|
321
|
+
routed_scaling_factor=routed_scaling_factor,
|
259
322
|
eps=float(1e-20),
|
260
323
|
)
|
261
324
|
else:
|
262
|
-
torch_native = True
|
325
|
+
self.topk_config.torch_native = True
|
263
326
|
return select_experts(
|
264
327
|
hidden_states=hidden_states,
|
265
328
|
router_logits=router_logits,
|
266
|
-
|
267
|
-
use_grouped_topk=self.use_grouped_topk,
|
268
|
-
renormalize=self.renormalize,
|
269
|
-
topk_group=self.topk_group,
|
270
|
-
num_expert_group=self.num_expert_group,
|
271
|
-
num_fused_shared_experts=self.num_fused_shared_experts,
|
272
|
-
custom_routing_function=self.custom_routing_function,
|
273
|
-
correction_bias=self.correction_bias,
|
274
|
-
torch_native=torch_native,
|
275
|
-
routed_scaling_factor=self.routed_scaling_factor,
|
329
|
+
topk_config=self.topk_config,
|
276
330
|
num_token_non_padded=num_token_non_padded,
|
277
331
|
expert_location_dispatch_info=expert_location_dispatch_info,
|
278
332
|
)
|
279
333
|
|
334
|
+
def empty_topk_output(self, device: torch.device) -> TopKOutput:
|
335
|
+
topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
|
336
|
+
topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
|
337
|
+
topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
|
338
|
+
router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
|
339
|
+
return StandardTopKOutput(topk_weights, topk_idx, router_logits)
|
340
|
+
|
280
341
|
|
281
342
|
# ------------------------------- TopK implementation -------------------------------------
|
282
343
|
|
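`empty_topk_output` builds a zero-token `StandardTopKOutput`, plausibly for ranks that receive no tokens in a step but must still produce a well-typed routing result; its call sites are not shown in this diff. The shape arithmetic mirrors the method body:

```python
import torch

# Shapes produced by empty_topk_output for top_k=8 with one fused shared
# expert: zero rows, top_k - num_fused_shared_experts columns.
top_k, num_fused_shared_experts = 8, 1
k = top_k - num_fused_shared_experts
topk_weights = torch.empty((0, k), dtype=torch.float32)
topk_idx = torch.full((0, k), -1, dtype=torch.int32)
assert topk_weights.shape == topk_idx.shape == (0, 7)
```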
```diff
@@ -370,12 +431,13 @@ def grouped_topk_gpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = …
-    topk_group: int = …
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -423,6 +485,8 @@
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
+    if apply_routed_scaling_factor_on_output:
+        topk_weights *= routed_scaling_factor
 
     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
```
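`apply_routed_scaling_factor_on_output` folds the routed scaling factor into the weights that `grouped_topk_gpu` returns, so the expert-combine step no longer needs its own multiply. Numerically, with renormalized weights and a factor of 2.5:

```python
import torch

topk_weights = torch.tensor([[0.5, 0.3, 0.2]])  # renormalized: row sums to 1
routed_scaling_factor = 2.5

# With apply_routed_scaling_factor_on_output=True, this multiply happens
# inside grouped_topk_gpu and the caller receives pre-scaled weights.
topk_weights = topk_weights * routed_scaling_factor
print(topk_weights)  # tensor([[1.2500, 0.7500, 0.5000]]); row now sums to 2.5
```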
```diff
@@ -435,8 +499,8 @@ def grouped_topk_cpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = …
-    topk_group: int = …
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -465,12 +529,13 @@ def biased_grouped_topk_impl(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = …
-    topk_group: int = …
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -522,6 +587,8 @@
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
+    if apply_routed_scaling_factor_on_output:
+        topk_weights *= routed_scaling_factor
 
     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -558,12 +625,13 @@ def biased_grouped_topk_gpu(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = …
-    topk_group: int = …
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert (
         routed_scaling_factor is not None
@@ -583,6 +651,7 @@
             topk,
             num_fused_shared_experts,
             routed_scaling_factor,
+            apply_routed_scaling_factor_on_output,
         )
         # TODO merge into kernel
         if (expert_location_dispatch_info is not None) or (
@@ -593,6 +662,7 @@
         )
         return topk_weights, topk_ids
     elif _use_aiter:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         token = gating_output.shape[0]
         device = gating_output.device
         assert (
@@ -624,6 +694,7 @@
         routed_scaling_factor=routed_scaling_factor,
         num_token_non_padded=num_token_non_padded,
         expert_location_dispatch_info=expert_location_dispatch_info,
+        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
     )
 
 
@@ -633,15 +704,17 @@ def biased_grouped_topk_cpu(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = …
-    topk_group: int = …
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     compiled: bool = True,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert expert_location_dispatch_info is None
+    assert not apply_routed_scaling_factor_on_output, "Not implemented"
     return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
         hidden_states,
         gating_output,
```
```diff
@@ -670,20 +743,26 @@ else:
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
-    …
+    topk_config: TopKConfig,
     *,
-    use_grouped_topk: bool = False,
-    renormalize: bool = False,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    num_fused_shared_experts: int = 0,
-    custom_routing_function: Optional[Callable] = None,
-    correction_bias: Optional[torch.Tensor] = None,
-    torch_native: bool = False,
-    routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-) -> …
+) -> StandardTopKOutput:
+
+    top_k = topk_config.top_k
+    use_grouped_topk = topk_config.use_grouped_topk
+    topk_group = topk_config.topk_group
+    num_expert_group = topk_config.num_expert_group
+    renormalize = topk_config.renormalize
+    num_fused_shared_experts = topk_config.num_fused_shared_experts
+    custom_routing_function = topk_config.custom_routing_function
+    correction_bias = topk_config.correction_bias
+    torch_native = topk_config.torch_native
+    routed_scaling_factor = topk_config.routed_scaling_factor
+    apply_routed_scaling_factor_on_output = (
+        topk_config.apply_routed_scaling_factor_on_output
+    )
+
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
             router_logits=router_logits,
```
```diff
@@ -708,6 +787,7 @@ def select_experts(
             routed_scaling_factor=routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
         )
     else:
         topk_weights, topk_ids = biased_grouped_topk(
@@ -722,12 +802,14 @@ def select_experts(
             routed_scaling_factor=routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
         )
     elif torch_native and custom_routing_function is None:
         assert (
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in fused_topk_native"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -735,6 +817,7 @@
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         # Qwen3MOE uses fused_topk
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
@@ -749,6 +832,7 @@
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in custom_routing_function"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
```