sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -15,9 +15,16 @@ from vllm import _custom_ops as ops

 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.
+from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    get_bool_env_var,
+    get_device_name,
+    is_cuda_available,
+    is_hip,
+)

-
+is_hip_ = is_hip()


 logger = logging.getLogger(__name__)
@@ -86,6 +93,7 @@ def fused_moe_kernel(
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
     use_fp8_w8a8: tl.constexpr,
+    use_int8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
     even_Ks: tl.constexpr,
 ):
@@ -159,7 +167,7 @@ def fused_moe_kernel(
         )
         b_scale = tl.load(b_scale_ptrs)

-    if use_fp8_w8a8:
+    if use_fp8_w8a8 or use_int8_w8a8:
         if group_k > 0 and group_n > 0:
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
             offs_bsn = offs_bn // group_n
@@ -198,7 +206,7 @@ def fused_moe_kernel(
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
-        elif use_fp8_w8a8:
+        elif use_fp8_w8a8 or use_int8_w8a8:
             if group_k > 0 and group_n > 0:
                 k_start = k * BLOCK_SIZE_K
                 offs_ks = k_start // group_k
@@ -221,7 +229,7 @@ def fused_moe_kernel(
         accumulator = accumulator * moe_weight[:, None]
     if use_int8_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
-    elif use_fp8_w8a8:
+    elif use_fp8_w8a8 or use_int8_w8a8:
         if group_k > 0 and group_n > 0:
             accumulator = accumulator.to(compute_type)
         else:
@@ -477,8 +485,10 @@ def invoke_fused_moe_kernel(
     config: Dict[str, Any],
     compute_type: tl.dtype,
     use_fp8_w8a8: bool,
+    use_int8_w8a8: bool,
     use_int8_w8a16: bool,
     block_shape: Optional[List[int]] = None,
+    no_combine: bool = False,
 ) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -499,6 +509,18 @@ def invoke_fused_moe_kernel(
             assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
             assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
             assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
+    elif use_int8_w8a8:
+        assert B_scale is not None
+        if block_shape is None:
+            padded_size = padding_size
+            A, A_scale = ops.scaled_int8_quant(A, A_scale)
+        else:
+            assert len(block_shape) == 2
+            block_n, block_k = block_shape[0], block_shape[1]
+            A, A_scale = per_token_group_quant_int8(A, block_k)
+            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
     elif use_int8_w8a16:
         assert B_scale is not None
     else:
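The new `use_int8_w8a8` branch quantizes activations on the fly: without a `block_shape` it calls `ops.scaled_int8_quant`, otherwise it quantizes each group of `block_k` channels per token with `per_token_group_quant_int8`. A minimal pure-PyTorch sketch of what such groupwise int8 quantization computes (the shipped Triton kernel in `int8_kernel.py` is the real implementation; the abs-max mapping, epsilon, and rounding details below are assumptions):

import torch

def per_token_group_quant_int8_ref(x: torch.Tensor, group_size: int):
    # Reference sketch: one int8 scale per (token, channel-group).
    assert x.shape[-1] % group_size == 0
    g = x.reshape(-1, x.shape[-1] // group_size, group_size)
    # Abs-max of each group mapped onto the int8 range [-127, 127].
    scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = (g / scale).round().clamp(-128, 127).to(torch.int8)
    return q.reshape(x.shape), scale.reshape(x.shape[0], -1)

x = torch.randn(4, 256)
q, s = per_token_group_quant_int8_ref(x, group_size=128)
print(q.shape, s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])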
@@ -548,6 +570,7 @@ def invoke_fused_moe_kernel(
         top_k=top_k,
         compute_type=compute_type,
         use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         even_Ks=even_Ks,
         **config,
@@ -625,7 +648,7 @@ def get_default_config(
             "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 32,
             "num_warps": 8,
-            "num_stages": 2 if
+            "num_stages": 2 if is_hip_ else 4,
         }
         if M <= E:
             config = {
@@ -634,7 +657,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": 128,
                 "GROUP_SIZE_M": 1,
                 "num_warps": 4,
-                "num_stages": 2 if
+                "num_stages": 2 if is_hip_ else 4,
             }
     else:
         # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
@@ -644,7 +667,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": block_shape[1],
                 "GROUP_SIZE_M": 32,
                 "num_warps": 4,
-                "num_stages": 2 if
+                "num_stages": 2 if is_hip_ else 3,
             }
         else:
             config = {
@@ -701,9 +724,12 @@ def get_config_dtype_str(
     dtype: torch.dtype,
     use_int8_w8a16: Optional[bool] = False,
    use_fp8_w8a8: Optional[bool] = False,
+    use_int8_w8a8: Optional[bool] = False,
 ):
     if use_fp8_w8a8:
         return "fp8_w8a8"
+    elif use_int8_w8a8:
+        return "int8_w8a8"
     elif use_int8_w8a16:
         return "int8_w8a16"
     elif dtype == torch.float:
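The string returned by `get_config_dtype_str` feeds the tuned-config lookup; the new `dtype=int8_w8a8` JSON files listed above are keyed by exactly this value. A hypothetical sketch of how such a config filename is assembled (the real helper in `fused_moe.py` may differ in detail):

from typing import List, Optional

def config_file_name(E: int, N: int, device_name: str,
                     dtype: Optional[str], block_shape: Optional[List[int]]) -> str:
    # Mirrors the naming visible in the file list above (hypothetical helper).
    dtype_sel = f",dtype={dtype}" if dtype else ""
    block_sel = f",block_shape={block_shape}" if block_shape else ""
    return f"E={E},N={N},device_name={device_name}{dtype_sel}{block_sel}.json"

print(config_file_name(256, 128, "NVIDIA_A100-SXM4-80GB", "int8_w8a8", [128, 128]))
# -> E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json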
@@ -721,6 +747,7 @@ def inplace_fused_experts(
     topk_ids: torch.Tensor,
     activation: str = "silu",
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
@@ -737,6 +764,7 @@ def inplace_fused_experts(
         True,
         activation,
         use_fp8_w8a8,
+        use_int8_w8a8,
         use_int8_w8a16,
         w1_scale,
         w2_scale,
@@ -754,6 +782,7 @@ def inplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     activation: str = "silu",
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
@@ -780,12 +809,14 @@ def outplace_fused_experts(
     topk_ids: torch.Tensor,
     activation: str = "silu",
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    no_combine: bool = False,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -796,12 +827,14 @@ def outplace_fused_experts(
         False,
         activation,
         use_fp8_w8a8,
+        use_int8_w8a8,
         use_int8_w8a16,
         w1_scale,
         w2_scale,
         a1_scale,
         a2_scale,
         block_shape,
+        no_combine=no_combine,
     )


@@ -813,12 +846,14 @@ def outplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     activation: str = "silu",
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    no_combine: bool = False,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)

@@ -840,14 +875,17 @@ def fused_experts(
     inplace: bool = False,
     activation: str = "silu",
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    no_combine: bool = False,
 ):
     if inplace:
+        assert not no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
             hidden_states,
             w1,
@@ -856,6 +894,7 @@ def fused_experts(
             topk_ids,
             activation,
             use_fp8_w8a8,
+            use_int8_w8a8,
             use_int8_w8a16,
             w1_scale,
             w2_scale,
@@ -873,12 +912,14 @@ def fused_experts(
             topk_ids,
             activation,
             use_fp8_w8a8,
+            use_int8_w8a8,
             use_int8_w8a16,
             w1_scale,
             w2_scale,
             a1_scale,
             a2_scale,
             block_shape,
+            no_combine=no_combine,
         )


@@ -891,15 +932,21 @@ def fused_experts_impl(
     inplace: bool = False,
     activation: str = "silu",
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    no_combine: bool = False,
 ):
     padded_size = padding_size
-    if
+    if (
+        not (use_fp8_w8a8 or use_int8_w8a8)
+        or block_shape is not None
+        or (is_hip_ and get_bool_env_var("CK_MOE"))
+    ):
         padded_size = 0

     # Check constraints.
@@ -918,6 +965,7 @@ def fused_experts_impl(
     M = min(num_tokens, CHUNK_SIZE)
     config_dtype = get_config_dtype_str(
         use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         dtype=hidden_states.dtype,
     )
@@ -933,25 +981,33 @@ def fused_experts_impl(

     config = get_config_func(M)

-    intermediate_cache1 = torch.empty(
-        (M, topk_ids.shape[1], N),
+    cache = torch.empty(
+        M * topk_ids.shape[1] * max(N, w2.shape[1]),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )
+    intermediate_cache1 = cache[: M * topk_ids.shape[1] * N].view(
+        (M, topk_ids.shape[1], N),
+    )
     intermediate_cache2 = torch.empty(
         (M * topk_ids.shape[1], N // 2),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )
-    intermediate_cache3 = torch.empty(
+    intermediate_cache3 = cache[: M * topk_ids.shape[1] * w2.shape[1]].view(
         (M, topk_ids.shape[1], w2.shape[1]),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
     )

     compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

-    if inplace:
+    if no_combine:
+        assert not inplace
+        out_hidden_states = torch.empty(
+            (num_tokens, topk_ids.shape[1], w2.shape[1]),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+    elif inplace:
         out_hidden_states = hidden_states
     else:
         out_hidden_states = torch.empty_like(hidden_states)
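The rewritten allocation above is a buffer-reuse trick: `intermediate_cache1` (output of the first grouped GEMM) and `intermediate_cache3` (output of the second) are never live at the same time, so both become views over one flat buffer sized for the larger of the two. A toy demonstration of the aliasing (sizes below are hypothetical):

import torch

M, topk, N, hidden = 8, 2, 512, 256  # toy sizes
cache = torch.empty(M * topk * max(N, hidden))
intermediate_cache1 = cache[: M * topk * N].view(M, topk, N)
intermediate_cache3 = cache[: M * topk * hidden].view(M, topk, hidden)
# Both views start at the same storage offset: memory is reused, not doubled.
assert intermediate_cache1.data_ptr() == intermediate_cache3.data_ptr()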
@@ -1000,6 +1056,7 @@ def fused_experts_impl(
             config,
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             block_shape=block_shape,
         )
@@ -1020,7 +1077,11 @@ def fused_experts_impl(
         invoke_fused_moe_kernel(
             intermediate_cache2,
             w2,
-            intermediate_cache3,
+            (
+                intermediate_cache3
+                if not no_combine and topk_ids.shape[1] != 1
+                else out_hidden_states[begin_chunk_idx:end_chunk_idx]
+            ),
             a2_scale,
             w2_scale,
             curr_topk_weights,
@@ -1033,20 +1094,21 @@ def fused_experts_impl(
             config,
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             block_shape=block_shape,
         )

-        if is_hip_:
+        if no_combine:
+            pass
+        elif is_hip_:
             ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
             )
         else:
             if topk_ids.shape[1] == 1:
-                out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
-                    intermediate_cache3[:, 0]
-                )
+                pass  # we write directly into out_hidden_states
             elif topk_ids.shape[1] == 2:
                 torch.add(
                     intermediate_cache3[:, 0],
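Because the routing weights are already multiplied in during the second kernel invocation (the `accumulator = accumulator * moe_weight[:, None]` path shown earlier), the final combine is just a sum over the top-k dimension; `no_combine=True` skips it and keeps per-expert outputs, and for `topk == 1` the kernel now writes straight into the output buffer. A sketch of what the combine computes:

import torch

num_tokens, topk, hidden = 4, 2, 8
expert_out = torch.randn(num_tokens, topk, hidden)  # plays the role of intermediate_cache3
combined = expert_out.sum(dim=1)                    # what the moe_sum / torch.add paths produce
assert torch.allclose(combined, expert_out[:, 0] + expert_out[:, 1])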
@@ -1077,12 +1139,14 @@ def fused_moe(
     topk_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    no_combine: bool = False,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1104,6 +1168,8 @@ def fused_moe(
         note: Deepseek V2/V3/R1 series models use grouped_topk
     - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
+    - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
+        products for w1 and w2. Defaults to False.
     - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -1143,10 +1209,12 @@ def fused_moe(
         inplace=inplace,
         activation=activation,
         use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
         block_shape=block_shape,
+        no_combine=no_combine,
     )
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -29,6 +29,9 @@ import logging

 is_hip_ = is_hip()

+if is_hip_:
+    from aiter import ck_moe
+
 logger = logging.getLogger(__name__)


@@ -125,6 +128,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -138,6 +143,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
             activation=activation,
+            inplace=inplace,
+            no_combine=no_combine,
         )

     def forward_cuda(
@@ -153,6 +160,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -167,17 +176,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         )

         if is_hip_ and get_bool_env_var("CK_MOE"):
-
-
-
-
-
-
-
-
-
-
-
+            assert not no_combine, "unsupported"
+            return ck_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                None,
+                None,
+                None,
+                None,
+                32,
+                None,
+                activation,
             )
         else:
             return fused_experts(
@@ -186,8 +198,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=True,
+            inplace=inplace and not no_combine,
             activation=activation,
+            no_combine=no_combine,
         )

     def forward_cpu(
@@ -202,6 +215,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        inplace: bool = True,
     ) -> torch.Tensor:
         return moe_forward_native(
             layer,
@@ -241,6 +255,7 @@ class FusedMoE(torch.nn.Module):
         reduce_results: Whether to all all_reduce on the output of the layer
         renomalize: Whether to renormalize the logits in the fused_moe kernel
         quant_config: Quantization configure.
+        inplace: suggestion to compute inplace (modify input activation).
     """

     def __init__(
@@ -262,6 +277,8 @@ class FusedMoE(torch.nn.Module):
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
         use_presharded_weights: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
     ):
         super().__init__()

@@ -285,6 +302,9 @@ class FusedMoE(torch.nn.Module):
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
         self.activation = activation
+        self.use_presharded_weights = use_presharded_weights
+        self.inplace = inplace
+        self.no_combine = no_combine

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -304,7 +324,6 @@ class FusedMoE(torch.nn.Module):
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
-        self.use_presharded_weights = use_presharded_weights

     def _load_per_tensor_weight_scale(
         self,
@@ -598,6 +617,8 @@ class FusedMoE(torch.nn.Module):
             custom_routing_function=self.custom_routing_function,
             correction_bias=self.correction_bias,
             activation=self.activation,
+            inplace=self.inplace,
+            no_combine=self.no_combine,
         )

         if self.reduce_results and self.tp_size > 1:
sglang/srt/layers/moe/topk.py
CHANGED
@@ -75,7 +75,6 @@ def fused_topk(
     return topk_weights, topk_ids


-# This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def grouped_topk(
     hidden_states: torch.Tensor,
@@ -84,10 +83,17 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    scoring_func: str = "softmax",
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

-    scores = torch.softmax(gating_output, dim=-1)
+    if scoring_func == "softmax":
+        scores = torch.softmax(gating_output, dim=-1)
+    elif scoring_func == "sigmoid":
+        scores = gating_output.sigmoid()
+    else:
+        raise ValueError(f"Scoring function '{scoring_func}' is not supported.")
+
     num_token = scores.shape[0]
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
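The new `scoring_func` switch matters because sigmoid scores, unlike softmax, are independent per expert and do not sum to one (sigmoid scoring is the DeepSeek-V3-style router convention), so `renormalize=True` carries more weight on that path. A small comparison with toy logits:

import torch

gating_output = torch.randn(2, 8)  # (num_tokens, num_experts), toy logits
softmax_scores = torch.softmax(gating_output, dim=-1)
sigmoid_scores = gating_output.sigmoid()
print(softmax_scores.sum(dim=-1))  # ~1.0 per token
print(sigmoid_scores.sum(dim=-1))  # unnormalized; renormalize=True rescales the selected top-k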
@@ -111,6 +117,7 @@
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


+# DeepSeek V2/V3/R1 uses biased_grouped_topk
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
@@ -141,7 +148,9 @@ def biased_grouped_topk(
         .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
         .reshape(num_token, -1)
     )  # [n, e]
-    tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]
+    tmp_scores = scores_for_choice.masked_fill(
+        ~score_mask.bool(), float("-inf")
+    )  # [n, e]
     _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
     topk_weights = scores.gather(1, topk_ids)

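The reflowed `masked_fill` is what keeps only experts inside the selected groups eligible: filling with `-inf` guarantees the masked experts can never appear in the following `torch.topk`. For example:

import torch

scores = torch.tensor([[0.3, 0.9, 0.5, 0.1]])
score_mask = torch.tensor([[1, 1, 0, 0]])  # only the first group is kept
tmp = scores.masked_fill(~score_mask.bool(), float("-inf"))
print(torch.topk(tmp, k=2, dim=-1).indices)  # tensor([[1, 0]])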
@@ -163,7 +172,7 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
 ):
-    #
+    # DeepSeek V2/V3/R1 uses biased_grouped_topk
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None