sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -1,60 +1,29 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
-import
-from abc import abstractmethod
+import logging
 from enum import Enum
-from typing import
+from typing import List, Optional, Tuple
 
 import torch
 
-from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.
-from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
-from sglang.srt.utils import
-    cpu_has_amx_support,
-    get_bool_env_var,
-    is_cpu,
-    is_hip,
-    set_weight_attrs,
-    use_intel_amx_backend,
-)
-
-has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
-
-if torch.cuda.is_available():
-    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-    if has_triton_kernels:
-        from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-            triton_kernel_moe_forward,
-        )
-else:
-    fused_experts = None  # type: ignore
-
-import logging
+from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip
 
 _is_hip = is_hip()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
-_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-
-if _use_aiter:
-    from aiter import ActivationType
-    from aiter.fused_moe import fused_moe
-    from aiter.fused_moe_bf16_asm import ck_moe_2stages
-    from aiter.ops.shuffle import shuffle_weight
 
 logger = logging.getLogger(__name__)
 
@@ -66,333 +35,6 @@ class FusedMoeWeightScaleSupported(Enum):
     BLOCK = "block"
 
 
-class FusedMoEMethodBase(QuantizeMethodBase):
-
-    @abstractmethod
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        raise NotImplementedError
-
-    @abstractmethod
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-
-class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
-    """MoE method without quantization."""
-
-    def __init__(self, use_triton_kernels: bool = False):
-        super().__init__()
-        self.use_triton_kernels = use_triton_kernels
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        # Fused gate_up_proj (column parallel)
-        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
-        if self.use_triton_kernels:
-            w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
-        w13_weight = torch.nn.Parameter(
-            torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        # down_proj (row parallel)
-        w2_weight_n, w2_weight_k = (
-            hidden_size,
-            intermediate_size,
-        )
-        if self.use_triton_kernels:
-            w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
-        w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if _use_aiter:
-            layer.w13_weight = torch.nn.Parameter(
-                shuffle_weight(layer.w13_weight.data, (16, 16)),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
-            layer.w2_weight = torch.nn.Parameter(
-                shuffle_weight(layer.w2_weight.data, (16, 16)),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
-
-        # Pack weight for get better performance on CPU
-        if _is_cpu and _is_cpu_amx_available:
-            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-
-        return
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-    ) -> torch.Tensor:
-        return self.forward(
-            x=x,
-            layer=layer,
-            router_logits=router_logits,
-            top_k=top_k,
-            renormalize=renormalize,
-            use_grouped_topk=use_grouped_topk,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
-        )
-
-    def forward_cuda(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        use_grouped_topk: bool,
-        top_k: int,
-        router_logits: torch.Tensor,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-    ) -> torch.Tensor:
-
-        if self.use_triton_kernels:
-            return triton_kernel_moe_forward(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                num_fused_shared_experts=num_fused_shared_experts,
-                custom_routing_function=custom_routing_function,
-                correction_bias=correction_bias,
-                routed_scaling_factor=routed_scaling_factor,
-            )
-
-            if _use_aiter:
-                assert not no_combine, "unsupported"
-                if apply_router_weight_on_input:
-                    assert (
-                        topk_weights.dim() == 2
-                    ), "`topk_weights` should be in shape (num_tokens, topk)"
-                    _, topk = topk_weights.shape
-                    assert (
-                        topk == 1
-                    ), "Only support topk=1 when `apply_router_weight_on_input` is True"
-                    x = x * topk_weights.to(x.dtype)
-                    topk_weights = torch.ones_like(
-                        topk_weights, dtype=torch.float32
-                    )  # topk_weights must be FP32 (float32)
-
-                return fused_moe(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    activation=(
-                        ActivationType.Silu
-                        if activation == "silu"
-                        else ActivationType.Gelu
-                    ),
-                )
-            else:
-                return fused_experts(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    topk_weights=topk_weights,
-                    topk_ids=topk_ids,
-                    inplace=inplace and not no_combine,
-                    activation=activation,
-                    apply_router_weight_on_input=apply_router_weight_on_input,
-                    no_combine=no_combine,
-                    routed_scaling_factor=routed_scaling_factor,
-                )
-
-    def forward_cpu(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        use_grouped_topk: bool,
-        top_k: int,
-        router_logits: torch.Tensor,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-    ) -> torch.Tensor:
-        assert activation == "silu", f"activation = {activation} is not supported."
-
-        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                num_fused_shared_experts=num_fused_shared_experts,
-                custom_routing_function=custom_routing_function,
-                correction_bias=correction_bias,
-                routed_scaling_factor=routed_scaling_factor,
-            )
-
-            # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
-            return torch.ops.sgl_kernel.fused_experts_cpu(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                False,  # inplace # See [Note] inplace should be False in fused_experts.
-                False,  # use_int8_w8a8
-                False,  # use_fp8_w8a16
-                None,  # w1_scale
-                None,  # w2_scale
-                None,  # block_size
-                None,  # a1_scale
-                None,  # a2_scale
-                True,  # is_vnni
-            )
-        else:
-            return moe_forward_native(
-                layer,
-                x,
-                use_grouped_topk,
-                top_k,
-                router_logits,
-                renormalize,
-                topk_group,
-                num_expert_group,
-                num_fused_shared_experts,
-                custom_routing_function,
-                correction_bias,
-                activation,
-                apply_router_weight_on_input,
-                inplace,
-                no_combine,
-                routed_scaling_factor,
-            )
-
-    def forward_npu(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        use_grouped_topk: bool,
-        top_k: int,
-        router_logits: torch.Tensor,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-    ) -> torch.Tensor:
-        return moe_forward_native(
-            layer,
-            x,
-            use_grouped_topk,
-            top_k,
-            router_logits,
-            renormalize,
-            topk_group,
-            num_expert_group,
-            num_fused_shared_experts,
-            custom_routing_function,
-            correction_bias,
-            activation,
-            apply_router_weight_on_input,
-            inplace,
-            no_combine,
-            routed_scaling_factor,
-        )
-
-    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
-        raise NotImplementedError("The TPU backend currently does not support MoE.")
-
-    forward_native = forward_cpu
-
-
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.
 
@@ -418,22 +60,15 @@ class FusedMoE(torch.nn.Module):
     def __init__(
        self,
         num_experts: int,
-        top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        top_k: Optional[int] = None,
         layer_id: Optional[int] = None,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         use_presharded_weights: bool = False,
@@ -448,6 +83,7 @@ class FusedMoE(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
+        self.top_k = top_k
         self.hidden_size = hidden_size
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
@@ -485,19 +121,9 @@ class FusedMoE(torch.nn.Module):
         self.ep_rank = 0
         self.local_num_experts = num_experts
         self.routed_scaling_factor = routed_scaling_factor
-        self.top_k = top_k
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
-        self.renormalize = renormalize
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.topk_group = topk_group
-        self.custom_routing_function = custom_routing_function
-        self.correction_bias = correction_bias
         self.activation = activation
         self.apply_router_weight_on_input = apply_router_weight_on_input
         self.use_presharded_weights = use_presharded_weights
@@ -553,7 +179,7 @@ class FusedMoE(torch.nn.Module):
         shard_dim: int,
         expert_data: torch.Tensor,
         shard_id: str,
-        loaded_weight: torch.
+        loaded_weight: torch.Tensor,
         tp_rank: int,
     ):
         # Load grouped weight scales for group quantization
@@ -580,7 +206,7 @@ class FusedMoE(torch.nn.Module):
         expert_data: torch.Tensor,
         shard_dim: int,
         shard_id: str,
-        loaded_weight: torch.
+        loaded_weight: torch.Tensor,
         tp_rank: int,
     ):
         # for per channel weight quantization
@@ -600,7 +226,7 @@ class FusedMoE(torch.nn.Module):
         expert_data: torch.Tensor,
         shard_dim: int,
         shard_id: str,
-        loaded_weight: torch.
+        loaded_weight: torch.Tensor,
         tp_rank: int,
     ):
 
@@ -645,7 +271,7 @@ class FusedMoE(torch.nn.Module):
         expert_data: torch.Tensor,
         shard_dim: int,
         shard_id: str,
-        loaded_weight: torch.
+        loaded_weight: torch.Tensor,
         tp_rank: int,
     ):
         """Load w2 weights for down projection.
@@ -717,7 +343,7 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_data: torch.Tensor,
         shard_dim: int,
-        loaded_weight: torch.
+        loaded_weight: torch.Tensor,
         tp_rank: int,
     ):
 
@@ -921,22 +547,14 @@ class FusedMoE(torch.nn.Module):
             )
             return
 
-    def forward(self, hidden_states: torch.Tensor,
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         assert self.quant_method is not None
 
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
             x=hidden_states,
-
-            top_k=self.top_k,
-            renormalize=self.renormalize,
-            use_grouped_topk=self.use_grouped_topk,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            custom_routing_function=self.custom_routing_function,
-            correction_bias=self.correction_bias,
+            topk_output=topk_output,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
             routed_scaling_factor=self.routed_scaling_factor,