sglang-0.4.9.post2-py3-none-any.whl → sglang-0.4.9.post3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/unquant.py (new file)
@@ -0,0 +1,422 @@
from __future__ import annotations

import importlib
from typing import TYPE_CHECKING, Callable, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from sglang.srt.custom_op import CustomOp
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    LinearMethodBase,
    QuantizeMethodBase,
)
from sglang.srt.utils import (
    cpu_has_amx_support,
    get_bool_env_var,
    is_cpu,
    is_hip,
    set_weight_attrs,
    use_intel_amx_backend,
)

if TYPE_CHECKING:
    from sglang.srt.layers.moe.topk import TopKOutput

has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None


_is_cpu_amx_available = cpu_has_amx_support()
_is_hip = is_hip()
_is_cpu = is_cpu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _use_aiter:
    from aiter import ActivationType
    from aiter.fused_moe import fused_moe
    from aiter.ops.shuffle import shuffle_weight


class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Unquantized method for embeddings."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for embedding layer."""
        weight = Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return F.linear(x, layer.weight, bias)

    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
        return F.embedding(input_, layer.weight)


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        weight = Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if _is_cpu and _is_cpu_amx_available:
            _amx_process_weight_after_loading(layer, ["weight"])

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        if use_intel_amx_backend(layer):
            return torch.ops.sgl_kernel.weight_packed_linear(
                x, layer.weight, bias, True  # is_vnni
            )

        return F.linear(x, layer.weight, bias)


class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def __init__(self, use_triton_kernels: bool = False):
        super().__init__()
        self.use_triton_kernels = use_triton_kernels

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Fused gate_up_proj (column parallel)
        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
        if self.use_triton_kernels:
            w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
        w13_weight = torch.nn.Parameter(
            torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight_n, w2_weight_k = (
            hidden_size,
            intermediate_size,
        )
        if self.use_triton_kernels:
            w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
        w2_weight = torch.nn.Parameter(
            torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if _use_aiter:
            layer.w13_weight = torch.nn.Parameter(
                shuffle_weight(layer.w13_weight.data, (16, 16)),
                requires_grad=False,
            )
            torch.cuda.empty_cache()
            layer.w2_weight = torch.nn.Parameter(
                shuffle_weight(layer.w2_weight.data, (16, 16)),
                requires_grad=False,
            )
            torch.cuda.empty_cache()

        # Pack weight for get better performance on CPU
        if _is_cpu and _is_cpu_amx_available:
            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

        return

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        return self.forward(
            x=x,
            layer=layer,
            topk_output=topk_output,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            inplace=inplace,
            no_combine=no_combine,
            routed_scaling_factor=routed_scaling_factor,
        )

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:

        if self.use_triton_kernels:
            # TODO(ch-wan): re-enable the Triton kernel
            raise NotImplementedError("The Triton kernel is temporarily disabled.")
            # return triton_kernel_moe_forward(
            #     hidden_states=x,
            #     w1=layer.w13_weight,
            #     w2=layer.w2_weight,
            #     gating_output=router_logits,
            #     topk=top_k,
            #     renormalize=renormalize,
            # )
        else:
            if _use_aiter:
                assert not no_combine, "unsupported"
                topk_weights, topk_ids, _ = topk_output
                if apply_router_weight_on_input:
                    assert (
                        topk_weights.dim() == 2
                    ), "`topk_weights` should be in shape (num_tokens, topk)"
                    _, topk = topk_weights.shape
                    assert (
                        topk == 1
                    ), "Only support topk=1 when `apply_router_weight_on_input` is True"
                    x = x * topk_weights.to(x.dtype)
                    topk_weights = torch.ones_like(
                        topk_weights, dtype=torch.float32
                    )  # topk_weights must be FP32 (float32)
                return fused_moe(
                    x,
                    layer.w13_weight,
                    layer.w2_weight,
                    topk_weights,
                    topk_ids,
                    activation=(
                        ActivationType.Silu
                        if activation == "silu"
                        else ActivationType.Gelu
                    ),
                )
            else:
                from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
                    fused_experts,
                )

                return fused_experts(
                    hidden_states=x,
                    w1=layer.w13_weight,
                    w2=layer.w2_weight,
                    topk_output=topk_output,
                    inplace=inplace and not no_combine,
                    activation=activation,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                    no_combine=no_combine,
                    routed_scaling_factor=routed_scaling_factor,
                )

    def forward_cpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        assert activation == "silu", f"activation = {activation} is not supported."

        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

            topk_weights, topk_ids, _ = topk_output
            x, topk_weights = apply_topk_weights_cpu(
                apply_router_weight_on_input, topk_weights, x
            )
            return torch.ops.sgl_kernel.fused_experts_cpu(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                False,  # inplace # See [Note] inplace should be False in fused_experts.
                False,  # use_int8_w8a8
                False,  # use_fp8_w8a16
                None,  # w1_scale
                None,  # w2_scale
                None,  # block_size
                None,  # a1_scale
                None,  # a2_scale
                True,  # is_vnni
            )
        else:
            from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

            return moe_forward_native(
                layer,
                x,
                topk_output,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                inplace=inplace,
                no_combine=no_combine,
                routed_scaling_factor=routed_scaling_factor,
            )

    def forward_npu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

        return moe_forward_native(
            layer,
            x,
            topk_output,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            inplace=inplace,
            no_combine=no_combine,
            routed_scaling_factor=routed_scaling_factor,
        )

    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("The TPU backend currently does not support MoE.")

    forward_native = forward_cpu


class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts_per_partition: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                2 * intermediate_size,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                hidden_size,
                intermediate_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # scale
        layer.register_parameter("w13_input_scale", None)
        layer.register_parameter("w13_weight_scale", None)

        ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)

        w2_input_scale = torch.nn.Parameter(
            ones_tensor,
            requires_grad=False,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)
        set_weight_attrs(w2_input_scale, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            ones_tensor,
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        topk_output: TopKOutput,
    ) -> torch.Tensor:
        raise NotImplementedError