sglang 0.3.6__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_one_batch.py +2 -4
- sglang/bench_serving.py +75 -26
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +13 -15
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +38 -57
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +13 -13
- sglang/srt/layers/attention/flashinfer_backend.py +13 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +13 -14
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +13 -15
- sglang/srt/layers/logits_processor.py +13 -15
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +25 -19
- sglang/srt/managers/detokenizer_manager.py +13 -16
- sglang/srt/managers/io_struct.py +43 -28
- sglang/srt/managers/schedule_batch.py +55 -26
- sglang/srt/managers/schedule_policy.py +13 -15
- sglang/srt/managers/scheduler.py +89 -70
- sglang/srt/managers/session_controller.py +14 -15
- sglang/srt/managers/tokenizer_manager.py +29 -22
- sglang/srt/managers/tp_worker.py +13 -15
- sglang/srt/managers/tp_worker_overlap_thread.py +13 -15
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +20 -19
- sglang/srt/model_executor/forward_batch_info.py +19 -17
- sglang/srt/model_executor/model_runner.py +42 -30
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +15 -15
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +24 -19
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +13 -15
- sglang/srt/models/llavavid.py +13 -15
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +21 -19
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +15 -17
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +13 -15
- sglang/srt/openai_api/protocol.py +13 -15
- sglang/srt/sampling/sampling_batch_info.py +4 -1
- sglang/srt/sampling/sampling_params.py +13 -15
- sglang/srt/server.py +59 -34
- sglang/srt/server_args.py +22 -22
- sglang/srt/utils.py +196 -17
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +13 -14
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +24 -15
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.6.dist-info/RECORD +0 -161
- /sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
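The most significant structural change in this release is the reorganization of the fused MoE code: the old sglang/srt/layers/fused_moe package is renamed to fused_moe_grok, fused_moe/patch.py moves up to fused_moe_patch.py, and a new Triton-based package sglang/srt/layers/fused_moe_triton is added (its layer.py is reproduced in full below). As a rough orientation, the sketch below shows import paths that follow from the file moves and from the names defined in the new layer.py; whether fused_moe_triton/__init__.py re-exports these names is not shown in this diff, so treat the exact paths as an assumption.

# Sketch of the new module layout in 0.3.6.post1 (paths inferred from the file
# list above; the __init__.py re-exports are not visible in this diff).
from sglang.srt.layers.fused_moe_triton.layer import (
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from sglang.srt.layers.fused_moe_triton.fused_moe import (
    fused_experts,
    fused_topk,
    grouped_topk,
)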
sglang/srt/layers/fused_moe_triton/layer.py
ADDED
@@ -0,0 +1,633 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
+
+from abc import abstractmethod
+from enum import Enum
+from typing import Callable, List, Optional, Tuple
+
+import torch
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.utils import set_weight_attrs
+
+if torch.cuda.is_available() or torch.hip.is_available():
+    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+else:
+    fused_experts = None  # type: ignore
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class FusedMoeWeightScaleSupported(Enum):
+    TENSOR = "tensor"
+    CHANNEL = "channel"
+    GROUP = "group"
+
+
+class FusedMoEMethodBase(QuantizeMethodBase):
+
+    @abstractmethod
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
+@register_custom_op("sglang_unquantized_fused_moe")
+class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
+    """MoE method without quantization."""
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+        return self.forward(
+            x=x,
+            layer=layer,
+            router_logits=router_logits,
+            top_k=top_k,
+            renormalize=renormalize,
+            use_grouped_topk=use_grouped_topk,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+        )
+
+    def forward_cuda(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+        )
+
+        return fused_experts(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+        )
+
+    def forward_cpu(self, *args, **kwargs):
+        raise NotImplementedError("The CPU backend currently does not support MoE.")
+
+    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError("The TPU backend currently does not support MoE.")
+
+    forward_native = forward_cuda
+
+
+class FusedMoE(torch.nn.Module):
+    """FusedMoE layer for MoE models.
+
+    This layer contains both MergedColumnParallel weights (gate_up_proj /
+    w13) and RowParallelLinear weights (down_proj/ w2).
+
+    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
+    copy that naming convention here and handle any remapping in the
+    load_weights function in each model implementation.
+
+    Args:
+        num_experts: Number of experts in the model
+        top_k: Number of experts selected for each token
+        hidden_size: Input hidden state size of the transformer
+        intermediate_size: Intermediate size of the experts
+        params_dtype: Data type for the parameters.
+        reduce_results: Whether to all all_reduce on the output of the layer
+        renomalize: Whether to renormalize the logits in the fused_moe kernel
+        quant_config: Quantization configure.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = False,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
+        custom_routing_function: Optional[Callable] = None,
+    ):
+        super().__init__()
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+
+        self.tp_size = (
+            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
+        )
+        self.top_k = top_k
+        self.num_experts = num_experts
+        self.intermediate_size_per_partition = intermediate_size // self.tp_size
+        self.reduce_results = reduce_results
+        self.renormalize = renormalize
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.custom_routing_function = custom_routing_function
+
+        if quant_config is None:
+            self.quant_method: Optional[QuantizeMethodBase] = (
+                UnquantizedFusedMoEMethod()
+            )
+        else:
+            self.quant_method = quant_config.get_quant_method(self, prefix)
+        assert self.quant_method is not None
+
+        self.quant_method.create_weights(
+            layer=self,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=self.intermediate_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=self.weight_loader,
+        )
+
+    def _load_per_tensor_weight_scale(
+        self,
+        shard_id: str,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        expert_id: int,
+    ):
+        param_data = param.data
+        # for per tensor weight quantization
+        if shard_id in ("w1", "w3"):
+            # We have to keep the weight scales of w1 and w3 because
+            # we need to re-quantize w1/w3 weights after weight loading.
+            idx = 0 if shard_id == "w1" else 1
+            param_data[expert_id][idx] = loaded_weight
+        # If we are in the row parallel case (down_proj)
+        elif shard_id == "w2":
+            param_data[expert_id] = loaded_weight
+
+    def _load_model_weight_or_group_weight_scale(
+        self,
+        shard_dim: int,
+        expert_data: torch.Tensor,
+        shard_id: str,
+        loaded_weight: torch.tensor,
+        tp_rank: int,
+    ):
+        # Load grouped weight scales for group quantization
+        # or model weights
+        if shard_id == "w2":
+            self._load_w2(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+            )
+        elif shard_id in ("w1", "w3"):
+            self._load_w13(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+            )
+
+    def _load_per_channel_weight_scale(
+        self,
+        expert_data: torch.Tensor,
+        shard_dim: int,
+        shard_id: str,
+        loaded_weight: torch.tensor,
+        tp_rank: int,
+    ):
+        # for per channel weight quantization
+        if shard_id == "w2":
+            expert_data.copy_(loaded_weight)
+        elif shard_id in ("w1", "w3"):
+            self._load_w13(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+            )
+
+    def _load_w13(
+        self,
+        expert_data: torch.Tensor,
+        shard_dim: int,
+        shard_id: str,
+        loaded_weight: torch.tensor,
+        tp_rank: int,
+    ):
+
+        # Index the loaded weight for tp sharding.
+        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
+        shard_size = expert_data.shape[shard_dim] // 2
+        loaded_weight = loaded_weight.narrow(
+            shard_dim, shard_size * tp_rank, shard_size
+        )
+        # Narrow parameter and load.
+        # w1, gate_proj: Load into first logical weight of w13.
+        if shard_id == "w1":
+            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
+        # w3, up_proj: Load into second logical weight of w13.
+        else:
+            assert shard_id == "w3"
+            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+        expert_data.copy_(loaded_weight)
+
+    def _load_w2(
+        self,
+        expert_data: torch.Tensor,
+        shard_dim: int,
+        shard_id: str,
+        loaded_weight: torch.tensor,
+        tp_rank: int,
+    ):
+
+        # Index the loaded weight for tp sharding.
+        # down_proj: "RowParallel" so tp sharding on input_dim
+        # Narrow parameter and load.
+        shard_size = expert_data.shape[shard_dim]
+        loaded_weight = loaded_weight.narrow(
+            shard_dim, shard_size * tp_rank, shard_size
+        )
+        # w2, down_proj: Load into only logical weight of w2.
+        expert_data.copy_(loaded_weight)
+
+    def _load_single_value(
+        self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
+    ):
+        param_data = param.data
+
+        # Input scales can be loaded directly and should be equal.
+        param_data[expert_id] = loaded_weight
+
+    def _load_g_idx(
+        self,
+        shard_id: str,
+        expert_data: torch.Tensor,
+        shard_dim: int,
+        loaded_weight: torch.tensor,
+        tp_rank: int,
+    ):
+
+        if shard_id == "w2":
+            self._load_w2(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+            )
+        else:
+            assert shard_id in ("w1", "w3")
+            expert_data.copy_(loaded_weight)
+
+    def weight_loader(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
+
+        # compressed-tensors checkpoints with packed weights are stored flipped
+        # TODO (mgoin): check self.quant_method.quant_config.quant_format
+        # against known CompressionFormat enum values that have this quality
+        loaded_weight = (
+            loaded_weight.t().contiguous()
+            if (
+                self.quant_method.__class__.__name__
+                == "CompressedTensorsWNA16MoEMethod"
+            )
+            else loaded_weight
+        )
+
+        if shard_id not in ("w1", "w2", "w3"):
+            raise ValueError(
+                f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
+            )
+
+        WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
+        # Fetch the dim to shard the parameter/loaded weight
+        # based on the shard id. This will be whatever
+        # dimension intermediate_size is used.
+        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
+
+        expert_data = param.data[expert_id]
+        tp_rank = get_tensor_model_parallel_rank()
+
+        # is_transposed: if the dim to shard the weight
+        # should be flipped. Required by GPTQ, compressed-tensors
+        # should be whatever dimension intermediate_size is
+        is_transposed = getattr(param, "is_transposed", False)
+        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
+        if is_transposed:
+            shard_dim = ~shard_dim
+
+        # Case input scale: input_scale loading is only supported for fp8
+        if "input_scale" in weight_name:
+            # this is needed for compressed-tensors only
+            loaded_weight = loaded_weight.to(param.data.device)
+
+            if (
+                param.data[expert_id] != 1
+                and (param.data[expert_id] - loaded_weight).abs() > 1e-5
+            ):
+                raise ValueError(
+                    "input_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param.data[expert_id]} "
+                    f"vs. {loaded_weight}"
+                )
+
+            self._load_single_value(
+                param=param, loaded_weight=loaded_weight, expert_id=expert_id
+            )
+            return
+
+        # Case g_idx
+        if "g_idx" in weight_name:
+            self._load_g_idx(
+                shard_dim=0,
+                shard_id=shard_id,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+            )
+            return
+
+        # Case weight scales and zero_points
+        if "scale" in weight_name or "zero" in weight_name:
+            # load the weight scales and zp based on the quantization scheme
+            # supported weight scales/zp can be found in
+            # FusedMoeWeightScaleSupported
+            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
+            # specific to each case
+            quant_method = getattr(param, "quant_method", None)
+            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
+                self._load_per_channel_weight_scale(
+                    shard_id=shard_id,
+                    shard_dim=shard_dim,
+                    loaded_weight=loaded_weight,
+                    expert_data=expert_data,
+                    tp_rank=tp_rank,
+                )
+            elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
+                self._load_model_weight_or_group_weight_scale(
+                    shard_id=shard_id,
+                    shard_dim=shard_dim,
+                    loaded_weight=loaded_weight,
+                    expert_data=expert_data,
+                    tp_rank=tp_rank,
+                )
+            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
+                self._load_per_tensor_weight_scale(
+                    shard_id=shard_id,
+                    param=param,
+                    loaded_weight=loaded_weight,
+                    expert_id=expert_id,
+                )
+            else:
+                raise ValueError(
+                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}"
+                )
+            return
+
+        # Case weight_shape
+        if "weight_shape" in weight_name:
+            # only required by compressed-tensors
+            self._load_single_value(
+                param=param, loaded_weight=loaded_weight, expert_id=expert_id
+            )
+            return
+
+        # Case model weights
+        if "weight" in weight_name:
+            self._load_model_weight_or_group_weight_scale(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+            )
+            return
+
+    @staticmethod
+    def select_experts(
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        use_grouped_topk: bool,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ):
+        from sglang.srt.layers.fused_moe_triton.fused_moe import (
+            fused_topk,
+            grouped_topk,
+        )
+
+        # DeekSeekv2 uses grouped_top_k
+        if use_grouped_topk:
+            assert topk_group is not None
+            assert num_expert_group is not None
+            topk_weights, topk_ids = grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+            )
+        elif custom_routing_function is None:
+            topk_weights, topk_ids = fused_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+        else:
+            topk_weights, topk_ids = custom_routing_function(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+
+        return topk_weights, topk_ids
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            custom_routing_function=self.custom_routing_function,
+        )
+
+        if self.reduce_results and self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states
+
+    @classmethod
+    def make_expert_params_mapping(
+        cls,
+        ckpt_gate_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_up_proj_name: str,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+
+        return [
+            # (param_name, weight_name, expert_id, shard_id)
+            (
+                (
+                    "experts.w13_"
+                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
+                    else "experts.w2_"
+                ),
+                f"experts.{expert_id}.{weight_name}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id, weight_name in [
+                ("w1", ckpt_gate_proj_name),
+                ("w2", ckpt_down_proj_name),
+                ("w3", ckpt_up_proj_name),
+            ]
+        ]
+
+    def _load_fp8_scale(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
+        param_data = param.data
+
+        # Input scales can be loaded directly and should be equal.
+        if "input_scale" in weight_name:
+            if (
+                param_data[expert_id] != 1
+                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
+            ):
+                raise ValueError(
+                    "input_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param_data[expert_id]} "
+                    f"vs. {loaded_weight}"
+                )
+            param_data[expert_id] = loaded_weight
+        # Weight scales
+        elif "weight_scale" in weight_name:
+            # If we are in merged column case (gate_up_proj)
+            if shard_id in ("w1", "w3"):
+                # We have to keep the weight scales of w1 and w3 because
+                # we need to re-quantize w1/w3 weights after weight loading.
+                idx = 0 if shard_id == "w1" else 1
+                param_data[expert_id][idx] = loaded_weight
+            # If we are in the row parallel case (down_proj)
+            else:
+                param_data[expert_id] = loaded_weight
sglang/srt/layers/layernorm.py
CHANGED
@@ -1,18 +1,16 @@
-""" … """   (the 15 removed lines are the previous license notice, formerly a module-level docstring)
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Fused operators for normalization layers."""
 
 import logging
sglang/srt/layers/logits_processor.py
CHANGED
@@ -1,18 +1,16 @@
-""" … """   (the 15 removed lines are the previous license notice, formerly a module-level docstring)
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Logits processing."""
 
 import dataclasses