sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +375 -51
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py CHANGED

@@ -1,10 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

-import datetime
-import glob
 import logging
-import os
-import sys
 from enum import Enum
 from typing import List, Optional, Tuple

@@ -22,12 +18,18 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     use_symmetric_memory,
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
-from sglang.srt.layers.moe
-
+from sglang.srt.layers.moe import (
+    MoeRunnerConfig,
+    get_moe_runner_backend,
+    should_use_flashinfer_trtllm_moe,
+)
+from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
@@ -109,9 +111,8 @@ class FusedMoE(torch.nn.Module):
         hidden_size: Input hidden state size of the transformer
         intermediate_size: Intermediate size of the experts
         params_dtype: Data type for the parameters.
-        reduce_results: Whether to
-
-        quant_config: Quantization configure.
+        reduce_results: Whether to apply all_reduce on the output of the layer
+        quant_config: Quantization configuration.
         inplace: suggestion to compute inplace (modify input activation).
     """

@@ -126,7 +127,6 @@ class FusedMoE(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
-        tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
@@ -134,9 +134,8 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
-
-
-        swiglu_limit: Optional[float] = None,
+        gemm1_alpha: Optional[float] = None,
+        gemm1_clamp_limit: Optional[float] = None,
         use_weight_loader_fused: bool = False,
         with_bias=False,
     ):
@@ -153,9 +152,17 @@ class FusedMoE(torch.nn.Module):
         self.expert_map_cpu = None
         self.expert_map_gpu = None

-
-
-
+        self.moe_runner_config = MoeRunnerConfig(
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            inplace=inplace,
+            no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
+            gemm1_alpha=gemm1_alpha,
+            gemm1_clamp_limit=gemm1_clamp_limit,
+        )
+
+        enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()

         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
@@ -174,9 +181,6 @@ class FusedMoE(torch.nn.Module):
             self.expert_map_cpu = torch.full(
                 (self.num_experts,), -1, dtype=torch.int32, device="cpu"
             )
-            self.expert_map_cpu = torch.full(
-                (self.num_experts,), -1, dtype=torch.int32, device="cpu"
-            )
             # Create a expert map for the local experts
             self.expert_map_cpu[
                 self.moe_ep_rank
@@ -184,20 +188,12 @@ class FusedMoE(torch.nn.Module):
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")

-        self.routed_scaling_factor = routed_scaling_factor
         assert intermediate_size % self.moe_tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
         self.reduce_results = reduce_results
-        self.activation = activation
-        self.apply_router_weight_on_input = apply_router_weight_on_input
         self.use_presharded_weights = use_presharded_weights
-        self.inplace = inplace
-        self.no_combine = no_combine
-
-        self.use_triton_kernels = (
-            not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
-        )

+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
                 self.use_triton_kernels
@@ -207,14 +203,12 @@ class FusedMoE(torch.nn.Module):
         assert self.quant_method is not None

         self.quant_config = quant_config
-        self.
-            "enable_flashinfer_mxfp4_moe", False
-        )
+        self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
         # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
         if (
             self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
-            and self.
+            and self.use_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
         self.quant_method.create_weights(
@@ -477,6 +471,7 @@ class FusedMoE(torch.nn.Module):
             not expert_id
             and self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
+            and self.quant_config.is_static_cfg()
         ):
             if "bias" in weight_name:
                 dim1 = loaded_weight.shape[1]
@@ -625,9 +620,7 @@ class FusedMoE(torch.nn.Module):

         if "ModelOpt" in self.quant_method.__class__.__name__:
             # Determine per-tensor weight scale patterns based on variant
-            is_fp4_variant = (
-                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
-            )
+            is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)

             # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
             per_tensor_conditions = (
@@ -729,7 +722,11 @@ class FusedMoE(torch.nn.Module):
     ) -> None:
         tp_rank = self.moe_tp_rank

-        if
+        if (
+            self.quant_config is not None
+            and self.quant_config.get_name() == "mxfp4"
+            and self.quant_config.is_static_cfg()
+        ):
             if "bias" in weight_name:
                 dim1 = loaded_weight.shape[1]
                 param.data[:, :dim1].copy_(loaded_weight)
@@ -794,7 +791,7 @@ class FusedMoE(torch.nn.Module):
                 f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
             )

-    def forward(self, hidden_states: torch.Tensor, topk_output:
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         origin_hidden_states_dim = hidden_states.shape[-1]
         assert self.quant_method is not None

@@ -803,40 +800,22 @@ class FusedMoE(torch.nn.Module):
             # If we are in EP mode, we need to move the expert map to GPU.
             self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")

-        if self.expert_map_gpu is not None
-            topk_output
-
-
-
-        )
+        if self.expert_map_gpu is not None:
+            if TopKOutputChecker.format_is_standard(topk_output):
+                topk_output = topk_output._replace(
+                    topk_ids=self.expert_map_gpu[topk_output.topk_ids]
+                )
+            elif TopKOutputChecker.format_is_triton_kernel(topk_output):
+                raise NotImplementedError()

         # Matrix multiply.
         with use_symmetric_memory(get_tp_group()) as sm:
-            kwargs = {}
-            if self.activation_alpha is not None:
-                kwargs["activation_alpha"] = self.activation_alpha
-            if self.swiglu_limit is not None:
-                kwargs["swiglu_limit"] = self.swiglu_limit

             final_hidden_states = self.quant_method.apply(
                 layer=self,
                 x=hidden_states,
                 topk_output=topk_output,
-
-                apply_router_weight_on_input=self.apply_router_weight_on_input,
-                routed_scaling_factor=self.routed_scaling_factor,
-                **(
-                    dict(
-                        tp_rank=self.moe_tp_rank,
-                        tp_size=self.moe_tp_size,
-                        ep_rank=self.moe_ep_rank,
-                        ep_size=self.moe_ep_size,
-                    )
-                    if self.quant_method.__class__.__name__
-                    == "ModelOptNvFp4FusedMoEMethod"
-                    else {}
-                ),
-                **kwargs,
+                moe_runner_config=self.moe_runner_config,
             )
             sm.tag(final_hidden_states)

@@ -941,53 +920,39 @@ class FusedMoE(torch.nn.Module):
             for shard_id in ["w1", "w2", "w3"]
         ]

+    def should_fuse_routed_scaling_factor_in_topk(self):
+        return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
+            isinstance(self.quant_method, Fp8MoEMethod)
+            and self.quant_method.use_cutlass_fused_experts_fp8
+        )
+

 class FlashInferFusedMoE(FusedMoE):
     def __init__(self, *args, **kwargs):
-        renormalize = kwargs.pop("renormalize", True)
-        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
-        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
-        num_expert_group = kwargs.pop("num_expert_group", None)
-        topk_group = kwargs.pop("topk_group", None)
-        correction_bias = kwargs.pop("correction_bias", None)
         super().__init__(*args, **kwargs)
-        self.renormalize = renormalize
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
         self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()

-    def forward(self, hidden_states: torch.Tensor, topk_output:
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         assert self.use_flashinfer_trtllm_moe
         assert (
-            self.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
         assert self.quant_method is not None
         assert (
-
+            topk_output.topk_config.renormalize
         ), "Renormalize is required for flashinfer blockscale fp8 moe"
         assert (
             self.num_fused_shared_experts == 0
         ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"

-
-        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
-            raise ValueError(
-                f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
-            )
-        _, router_logits = topk_output
+        assert TopKOutputChecker.format_is_bypassed(topk_output)

         # Matrix multiply.
         final_hidden_states = self.quant_method.apply_with_router_logits(
             layer=self,
             x=hidden_states,
-
-
-            routed_scaling_factor=self.routed_scaling_factor,
+            topk_output=topk_output,
+            moe_runner_config=self.moe_runner_config,
         )

         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
@@ -1000,28 +965,8 @@ class FlashInferFP4MoE(FusedMoE):
     """FP4 TRTLLM MoE implementation using FlashInfer."""

     def __init__(self, *args, **kwargs):
-        # Extract DeepSeek-specific parameters
-        renormalize = kwargs.pop("renormalize", True)
-        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
-        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
-        num_expert_group = kwargs.pop("num_expert_group", None)
-        topk_group = kwargs.pop("topk_group", None)
-        correction_bias = kwargs.pop("correction_bias", None)
-
-        # Extract additional TopK parameters that were previously extracted in forward
-        routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
-
         super().__init__(*args, **kwargs)

-        # Store DeepSeek parameters
-        self.renormalize = renormalize
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.use_grouped_topk = use_grouped_topk
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.routed_scaling_factor = routed_scaling_factor
-
     # ---------------------------------------------------------------------
     # Helper: quantize hidden states to FP4 each forward pass
     # ---------------------------------------------------------------------
@@ -1052,21 +997,19 @@ class FlashInferFP4MoE(FusedMoE):

         return hs_fp4, hs_sf

-    def forward(self, hidden_states: torch.Tensor, topk_output):
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         """Forward pass using FP4 TRTLLM kernel.

         Args:
             hidden_states: Input tensor
-            topk_output:
+            topk_output: TopKOutput object with Bypassed format
         """
+        assert isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)

-
-        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
-            raise ValueError(
-                f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
-            )
+        assert TopKOutputChecker.format_is_bypassed(topk_output)

-
+        router_logits = topk_output.router_logits
+        topk_config = topk_output.topk_config

         hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)

@@ -1074,7 +1017,7 @@ class FlashInferFP4MoE(FusedMoE):

         result = trtllm_fp4_block_scale_moe(
             routing_logits=router_logits,
-            routing_bias=
+            routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
             hidden_states=hs_fp4,
             hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
             gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
@@ -1094,15 +1037,15 @@ class FlashInferFP4MoE(FusedMoE):
             output1_scale_gate_scalar=self.g1_alphas.data,
             output2_scale_scalar=self.g2_alphas.data,
             num_experts=self.num_experts,
-            top_k=
-            n_group=
-            topk_group=
+            top_k=topk_config.top_k,
+            n_group=topk_config.num_expert_group,
+            topk_group=topk_config.topk_group,
             intermediate_size=self.intermediate_size_per_partition,
             local_expert_offset=self.moe_ep_rank * self.num_local_experts,
             local_num_experts=self.num_local_experts,
-            routed_scaling_factor=self.routed_scaling_factor,
+            routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
             tile_tokens_dim=_get_tile_tokens_dim(
-                hidden_states.shape[0],
+                hidden_states.shape[0], topk_config.top_k, self.num_local_experts
             ),
             routing_method_type=RoutingMethodType.DeepSeekV3,
             do_finalize=True,
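The expert-parallel branch added to `FusedMoE.forward` above assumes a standard-format `TopKOutput` behaves like a NamedTuple, so `_replace` can swap in expert ids remapped through `expert_map_gpu`. A toy sketch of that remapping pattern (the `ToyTopKOutput` type and the tensors below are illustrative stand-ins, not the actual sglang classes):

```python
from typing import NamedTuple

import torch


class ToyTopKOutput(NamedTuple):  # stand-in for sglang's standard TopKOutput
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor


# Global-to-local expert map: this EP rank owns experts 2 and 3; others map to -1.
expert_map = torch.tensor([-1, -1, 0, 1], dtype=torch.int32)

out = ToyTopKOutput(
    topk_weights=torch.tensor([[0.7, 0.3]]),
    topk_ids=torch.tensor([[2, 0]]),
)

# Same pattern as the diff: build a new output with remapped ids via _replace.
out = out._replace(topk_ids=expert_map[out.topk_ids])
print(out.topk_ids)  # tensor([[ 0, -1]], dtype=torch.int32)
```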
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py CHANGED

@@ -18,6 +18,7 @@ from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
 from triton_kernels.swiglu import swiglu_fn

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput


@@ -55,8 +56,7 @@ def triton_kernel_moe_forward(
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_output: TopKOutput,
-
-    activation: str = "silu",
+    moe_runner_config: MoeRunnerConfig,
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
@@ -69,7 +69,10 @@ def triton_kernel_moe_forward(
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:

-
+    from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+    assert TopKOutputChecker.format_is_triton_kernel(topk_output)
+
     routing_data, gather_idx, scatter_idx = topk_output

     return triton_kernel_fused_experts(
@@ -79,8 +82,8 @@ def triton_kernel_moe_forward(
         routing_data,
         gather_idx,
         scatter_idx,
-        inplace=inplace
-        activation=activation,
+        inplace=False,  # triton kernel doesn't support inplace
+        activation=moe_runner_config.activation,
         apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         per_channel_quant=per_channel_quant,
@@ -192,8 +195,7 @@ def triton_kernel_moe_with_bias_forward(
     w2_pcg,
     b2: torch.Tensor,
     topk_output: TopKOutput,
-
-    activation: str = "silu",
+    moe_runner_config: MoeRunnerConfig,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,
@@ -203,10 +205,11 @@ def triton_kernel_moe_with_bias_forward(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[list[int]] = None,
-    activation_alpha: Optional[float] = None,
-    swiglu_limit: Optional[int] = None,
 ) -> torch.Tensor:
-
+    from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+    assert TopKOutputChecker.format_is_triton_kernel(topk_output)
+
     routing_data, gather_idx, scatter_idx = topk_output

     return triton_kernel_fused_experts_with_bias(
@@ -220,8 +223,8 @@ def triton_kernel_moe_with_bias_forward(
         routing_data=routing_data,
         gather_indx=gather_idx,
         scatter_indx=scatter_idx,
-        inplace=inplace
-        activation=activation,
+        inplace=False,  # triton kernel doesn't support inplace
+        activation=moe_runner_config.activation,
         use_fp8_w8a8=use_fp8_w8a8,
         per_channel_quant=per_channel_quant,
         global_num_experts=global_num_experts,
@@ -231,8 +234,8 @@ def triton_kernel_moe_with_bias_forward(
         a1_scale=a1_scale,
         a2_scale=a2_scale,
         block_shape=block_shape,
-
-
+        gemm1_alpha=moe_runner_config.gemm1_alpha,
+        gemm1_clamp_limit=moe_runner_config.gemm1_clamp_limit,
     )


@@ -258,10 +261,9 @@ def triton_kernel_fused_experts_with_bias(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[list[int]] = None,
-
-
+    gemm1_alpha: Optional[float] = None,
+    gemm1_clamp_limit: Optional[float] = None,
 ) -> torch.Tensor:
-    # print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
     assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
     assert per_channel_quant == False, "per_channel_quant is not supported"
     assert expert_map == None, "expert_map is not supported"
@@ -307,7 +309,7 @@ def triton_kernel_fused_experts_with_bias(

     act = FusedActivation(
         FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
-        (
+        (gemm1_alpha, gemm1_clamp_limit),
         2,
     )

sglang/srt/layers/moe/moe_runner/base.py ADDED

@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass
+class MoeRunnerConfig:
+    activation: str = "silu"
+    apply_router_weight_on_input: bool = False
+    inplace: bool = True
+    no_combine: bool = False
+    routed_scaling_factor: Optional[float] = None
+    gemm1_alpha: Optional[float] = None
+    gemm1_clamp_limit: Optional[float] = None
sglang/srt/layers/moe/rocm_moe_utils.py ADDED

@@ -0,0 +1,141 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1rc2/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from enum import IntEnum
+from functools import cache
+from typing import Optional
+
+import torch
+
+from sglang.srt.utils import direct_register_custom_op, get_bool_env_var, is_hip
+
+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+
+class ActivationMethod(IntEnum):
+    # This allows interfacing with AITER ActivationType enum
+    # without importing the ActivationType enum from AITER globally.
+    SILU = 0
+    GELU = 1
+
+
+def rocm_aiter_asm_moe_tkw1_impl(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    fc1_scale: Optional[torch.Tensor] = None,
+    fc2_scale: Optional[torch.Tensor] = None,
+    fc1_smooth_scale: Optional[torch.Tensor] = None,
+    fc2_smooth_scale: Optional[torch.Tensor] = None,
+    a16: bool = False,
+    per_tensor_quant_scale: Optional[torch.Tensor] = None,
+    expert_mask: Optional[torch.Tensor] = None,
+    activation_method: int = ActivationMethod.SILU.value,
+) -> torch.Tensor:
+
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe_tkw1
+
+    activation = ActivationType(activation_method)
+
+    return asm_moe_tkw1(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        fc1_scale=fc1_scale,
+        fc2_scale=fc2_scale,
+        fc1_smooth_scale=fc1_smooth_scale,
+        fc2_smooth_scale=fc2_smooth_scale,
+        a16=a16,
+        per_tensor_quant_scale=per_tensor_quant_scale,
+        expert_mask=expert_mask,
+        activation=activation,
+    )
+
+
+def rocm_aiter_asm_moe_tkw1_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    fc1_scale: Optional[torch.Tensor] = None,
+    fc2_scale: Optional[torch.Tensor] = None,
+    fc1_smooth_scale: Optional[torch.Tensor] = None,
+    fc2_smooth_scale: Optional[torch.Tensor] = None,
+    a16: bool = False,
+    per_tensor_quant_scale: Optional[torch.Tensor] = None,
+    expert_mask: Optional[torch.Tensor] = None,
+    activation_method: int = ActivationMethod.SILU.value,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+if _use_aiter:
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_asm_moe_tkw1",
+        op_func=rocm_aiter_asm_moe_tkw1_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_asm_moe_tkw1_fake,
+    )
+
+
+def rocm_fused_experts_tkw1(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    activation_method = (
+        ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU
+    )
+    # All AITER Fused MoE kernels are expecting the following datatypes
+    topk_weights = topk_weights.to(torch.float32)
+    topk_ids = topk_ids.to(torch.int32)
+
+    # w8a8 per-channel quantization
+    if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
+        # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
+        # This applies topk_weights on the GEMM output of the first FC layer
+        # rather than the second FC.
+        assert (
+            topk_weights.dim() == 2
+        ), "`topk_weights` should be in shape (num_tokens, topk)"
+        assert topk_weights.shape[-1] == 1, (
+            "Only support topk=1 when" " `apply_router_weight_on_input` is True"
+        )
+
+        return torch.ops.sglang.rocm_aiter_asm_moe_tkw1(
+            hidden_states,
+            w1,
+            w2,
+            topk_weights,
+            topk_ids,
+            fc1_scale=w1_scale,
+            fc2_scale=w2_scale,
+            fc1_smooth_scale=None,
+            fc2_smooth_scale=None,
+            a16=False,
+            per_tensor_quant_scale=None,
+            expert_mask=None,
+            activation_method=activation_method,
+        )
+    else:
+        assert False, "This should not be called."
sglang/srt/layers/moe/router.py CHANGED

@@ -45,11 +45,14 @@ def fused_moe_router_kernel(
     logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)

     # logit softcap
-
-
-
-
-
+    if moe_softcapping == 0:
+        logits_softcapped = logits
+    else:
+        logits_scaled = logits / moe_softcapping
+        exped = tl.exp(2 * logits_scaled)
+        top = exped - 1
+        bottom = exped + 1
+        logits_softcapped = top / bottom * moe_softcapping

     # Add bias after softcapping
     if is_correction_bias:
@@ -207,9 +210,12 @@ def fused_moe_router_large_bs_kernel(
         b_ptrs += BLOCK_SIZE_K

     # 4. logit softcap
-
-
-
+    if moe_softcapping == 0:
+        logits_softcapped = acc
+    else:
+        logits_scaled = acc / moe_softcapping
+        exped = tl.exp(2 * logits_scaled)
+        logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping

     # 5. top1
     arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
@@ -234,7 +240,7 @@ def fused_moe_router_large_bs_kernel(

     # 7. handle topk == 2
     if topk == 2:
-        cond_top2 = (arange_block_size_n < num_experts)
+        cond_top2 = (arange_block_size_n < num_experts) & (
             arange_block_size_n != top1[:, None]
         )
         top2 = tl.argmax(