sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
New file: one of the fused_moe_triton tuning-config JSON files listed above (+146 lines):

```diff
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
```
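Each of these tuning files maps a token count M (the JSON keys) to a Triton kernel launch configuration. A minimal sketch of how such a table can be consulted, assuming the usual nearest-M lookup (the helper name here is hypothetical; sglang's real lookup lives in fused_moe.py's config-loading helpers):

```python
import json

def lookup_moe_config(path: str, num_tokens: int) -> dict:
    # Hypothetical helper: load the per-M table and return the entry
    # tuned for the M closest to the actual token count.
    with open(path) as f:
        table = {int(m): cfg for m, cfg in json.load(f).items()}
    nearest_m = min(table, key=lambda m: abs(m - num_tokens))
    return table[nearest_m]

# e.g. num_tokens=50 resolves to the "48" entry above
# ({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, ..., "num_stages": 3}).
```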
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py (all remaining hunks below are from this file; the +35/-45 totals match its entry in the list above):

```diff
@@ -6,13 +6,13 @@ import functools
 import json
 import logging
 import os
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     scaled_fp8_quant,
```
```diff
@@ -39,11 +39,20 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
 elif _is_cpu and _is_cpu_amx_available:
     pass
+elif _is_hip:
+    from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+
+    if _use_aiter:
+        try:
+            from aiter import moe_sum
+        except ImportError:
+            raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 else:
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
```
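The new module-level `_use_aiter` flag gates AMD's aiter kernels behind an environment variable. For orientation, a sketch of the semantics a `get_bool_env_var` helper typically has (the real one lives in `sglang.srt.utils` and may accept different spellings):

```python
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Sketch only: treat common truthy spellings as True.
    return os.getenv(name, default).strip().lower() in ("1", "true")

# With this gating, running under SGLANG_USE_AITER=1 on a HIP build
# switches the MoE reduction to aiter's moe_sum (see the hunks below).
```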
```diff
@@ -752,14 +761,13 @@ def moe_align_block_size(
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
-    sorted_ids.fill_(topk_ids.numel())
-
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
     if enable_moe_align_block_size_triton:
+        sorted_ids.fill_(topk_ids.numel())
         moe_align_block_size_triton(
             topk_ids,
             num_experts,
@@ -778,6 +786,11 @@ def moe_align_block_size(
             device=topk_ids.device,
         )
 
+        # Threshold based on benchmark results
+        fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+        if not fuse_sorted_ids_padding:
+            sorted_ids.fill_(topk_ids.numel())
+
         sgl_moe_align_block_size(
             topk_ids,
             num_experts,
@@ -787,6 +800,7 @@ def moe_align_block_size(
             num_tokens_post_pad,
             token_cnts_buffer,
             cumsum_buffer,
+            fuse_sorted_ids_padding,
         )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
```
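Taken together, these hunks make the sentinel fill of `sorted_ids` conditional: the tensor used to be unconditionally pre-filled with `topk_ids.numel()` (an out-of-range marker for padding slots), and now the CUDA kernel performs that padding itself whenever the tensor is small enough. A condensed, runnable sketch of the new control flow (the function name is hypothetical):

```python
import torch

def prepare_sorted_ids(sorted_ids: torch.Tensor, sentinel: int, use_triton_align: bool) -> bool:
    # Returns the flag forwarded to sgl_moe_align_block_size as
    # `fuse_sorted_ids_padding`.
    if use_triton_align:
        sorted_ids.fill_(sentinel)  # Triton path: always pre-fill
        return False
    fuse = sorted_ids.shape[0] <= 4096  # benchmark-derived threshold (per the diff)
    if not fuse:
        sorted_ids.fill_(sentinel)  # large tensors keep the separate fill_ launch
    return fuse  # small tensors: let the kernel write the padding itself
```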
```diff
@@ -1328,8 +1342,7 @@
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
+    topk_output: TopKOutput,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -1348,7 +1361,7 @@ def fused_experts(
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
 ):
-
+    topk_weights, topk_ids, _ = topk_output
     if inplace:
         assert not no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
```
|
|
1517
1530
|
routed_scaling_factor: Optional[float] = None,
|
1518
1531
|
):
|
1519
1532
|
padded_size = padding_size
|
1520
|
-
if (
|
1521
|
-
not (use_fp8_w8a8 or use_int8_w8a8)
|
1522
|
-
or block_shape is not None
|
1523
|
-
or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
|
1524
|
-
):
|
1533
|
+
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
1525
1534
|
padded_size = 0
|
1526
1535
|
|
1527
1536
|
# Check constraints.
|
```diff
@@ -1719,6 +1728,17 @@ def fused_experts_impl(
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 routed_scaling_factor,
             )
+        elif _is_hip:
+            if _use_aiter:
+                moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
+            else:
+                vllm_ops.moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
         else:
             vllm_ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
```
|
|
1732
1752
|
hidden_states: torch.Tensor,
|
1733
1753
|
w1: torch.Tensor,
|
1734
1754
|
w2: torch.Tensor,
|
1735
|
-
|
1736
|
-
topk: int,
|
1737
|
-
renormalize: bool,
|
1755
|
+
topk_output: TopKOutput,
|
1738
1756
|
inplace: bool = False,
|
1739
1757
|
activation: str = "silu",
|
1740
1758
|
apply_router_weight_on_input: bool = False,
|
1741
|
-
use_grouped_topk: bool = False,
|
1742
|
-
num_expert_group: Optional[int] = None,
|
1743
|
-
num_fused_shared_experts: int = 0,
|
1744
|
-
topk_group: Optional[int] = None,
|
1745
|
-
custom_routing_function: Optional[Callable] = None,
|
1746
1759
|
use_fp8_w8a8: bool = False,
|
1747
1760
|
use_int8_w8a8: bool = False,
|
1748
1761
|
use_int8_w8a16: bool = False,
|
@@ -1766,16 +1779,9 @@ def fused_moe(
|
|
1766
1779
|
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
1767
1780
|
- w1 (torch.Tensor): The first set of expert weights.
|
1768
1781
|
- w2 (torch.Tensor): The second set of expert weights.
|
1769
|
-
-
|
1770
|
-
(before softmax).
|
1771
|
-
- topk (int): The number of top-k experts to select.
|
1772
|
-
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
|
1782
|
+
- topk_output (TopKOutput): The top-k output of the experts.
|
1773
1783
|
- inplace (bool): If True, perform the operation in-place.
|
1774
1784
|
Defaults to False.
|
1775
|
-
- num_expert_group: Optional[int]: additional parameter for grouped_topk
|
1776
|
-
- topk_group: Optional[int]: additional parameter for grouped_topk
|
1777
|
-
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
|
1778
|
-
note: Deepseek V2/V3/R1 series models use grouped_topk
|
1779
1785
|
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
1780
1786
|
products for w1 and w2. Defaults to False.
|
1781
1787
|
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
|
@@ -1799,28 +1805,12 @@ def fused_moe(
|
|
1799
1805
|
Returns:
|
1800
1806
|
- torch.Tensor: The output tensor after applying the MoE layer.
|
1801
1807
|
"""
|
1802
|
-
# Check constraints.
|
1803
|
-
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
1804
|
-
|
1805
|
-
topk_weights, topk_ids = select_experts(
|
1806
|
-
hidden_states=hidden_states,
|
1807
|
-
router_logits=gating_output,
|
1808
|
-
use_grouped_topk=use_grouped_topk,
|
1809
|
-
top_k=topk,
|
1810
|
-
renormalize=renormalize,
|
1811
|
-
topk_group=topk_group,
|
1812
|
-
num_expert_group=num_expert_group,
|
1813
|
-
num_fused_shared_experts=num_fused_shared_experts,
|
1814
|
-
custom_routing_function=custom_routing_function,
|
1815
|
-
routed_scaling_factor=routed_scaling_factor,
|
1816
|
-
)
|
1817
1808
|
|
1818
1809
|
return fused_experts(
|
1819
1810
|
hidden_states,
|
1820
1811
|
w1,
|
1821
1812
|
w2,
|
1822
|
-
|
1823
|
-
topk_ids,
|
1813
|
+
topk_output,
|
1824
1814
|
inplace=inplace,
|
1825
1815
|
activation=activation,
|
1826
1816
|
apply_router_weight_on_input=apply_router_weight_on_input,
|