sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- sglang/api.py +7 -1
- sglang/bench_latency.py +9 -6
- sglang/bench_serving.py +46 -22
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +4 -2
- sglang/lang/ir.py +16 -7
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/layers/activation.py +32 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +9 -2
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +7 -2
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +40 -16
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +115 -97
- sglang/srt/managers/tokenizer_manager.py +194 -112
- sglang/srt/managers/tp_worker.py +290 -359
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +71 -25
- sglang/srt/model_executor/forward_batch_info.py +293 -156
- sglang/srt/model_executor/model_runner.py +77 -57
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/deepseek.py +2 -2
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +11 -6
- sglang/srt/models/grok.py +50 -396
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/mixtral.py +56 -254
- sglang/srt/models/mixtral_quant.py +1 -4
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +2 -13
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +187 -48
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -8
- sglang/srt/server.py +91 -29
- sglang/srt/server_args.py +32 -19
- sglang/srt/utils.py +32 -15
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +81 -73
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +36 -7
- sglang/test/test_utils.py +24 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
- sglang-0.2.13.dist-info/RECORD +112 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/sglang/srt/layers/fused_moe/layer.py
@@ -0,0 +1,587 @@
+# Adapted from
+# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
+from abc import abstractmethod
+from typing import List, Optional, Tuple
+
+import torch
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.logger import init_logger
+from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.utils import set_weight_attrs
+
+logger = init_logger(__name__)
+
+
+class FusedMoEMethodBase(QuantizeMethodBase):
+
+    @abstractmethod
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
+    """MoE method without quantization."""
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+    ) -> torch.Tensor:
+        return self.forward(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            router_logits,
+            top_k,
+            renormalize,
+            use_grouped_topk,
+            num_expert_group,
+            topk_group,
+        )
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+    ) -> torch.Tensor:
+        from sglang.srt.layers.fused_moe.fused_moe import fused_moe
+
+        return fused_moe(
+            x,
+            w1,
+            w2,
+            router_logits,
+            top_k,
+            renormalize=renormalize,
+            inplace=True,
+            use_grouped_topk=use_grouped_topk,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+        )
+
+    def forward_cpu(self, *args, **kwargs):
+        raise NotImplementedError("The CPU backend currently does not support MoE.")
+
+    def forward_tpu(
+        self,
+        x: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
+
+        assert not use_grouped_topk
+        assert num_expert_group is None
+        assert topk_group is None
+        return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
+
+
+class FusedMoE(torch.nn.Module):
+    """FusedMoE layer for MoE models.
+
+    This layer contains both MergedColumnParallel weights (gate_up_proj /
+    w13) and RowParallelLinear weights (down_proj/ w2).
+
+    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
+    copy that naming convention here and handle any remapping in the
+    load_weights function in each model implementation.
+
+    Args:
+        num_experts: Number of experts in the model
+        top_k: Number of experts selected for each token
+        hidden_size: Input hidden state size of the transformer
+        intermediate_size: Intermediate size of the experts
+        params_dtype: Data type for the parameters.
+        reduce_results: Whether to all all_reduce on the output of the layer
+        renomalize: Whether to renormalize the logits in the fused_moe kernel
+        quant_config: Quantization configure.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = False,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+
+        self.tp_size = (
+            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
+        )
+        self.top_k = top_k
+        self.num_experts = num_experts
+        self.intermediate_size_per_partition = intermediate_size // self.tp_size
+        self.reduce_results = reduce_results
+        self.renormalize = renormalize
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+
+        if quant_config is None:
+            self.quant_method: Optional[QuantizeMethodBase] = (
+                UnquantizedFusedMoEMethod()
+            )
+        else:
+            if isinstance(quant_config, Fp8Config):
+                self.quant_method = Fp8MoEMethod(quant_config)
+            else:
+                self.quant_method = quant_config.get_quant_method(self, prefix)
+            assert self.quant_method is not None
+
+        self.quant_method.create_weights(
+            layer=self,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=self.intermediate_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=self.weight_loader,
+        )
+
+    def weight_loader(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: int,
+        expert_id: int,
+        pre_sharded: bool,
+    ):
+        param_data = param.data
+
+        # Input scales can be loaded directly and should be equal.
+        if "input_scale" in weight_name:
+            if (
+                param_data[expert_id] != 1
+                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
+            ):
+                raise ValueError(
+                    "input_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param_data[expert_id]} "
+                    f"vs. {loaded_weight}"
+                )
+            param_data[expert_id] = loaded_weight
+        # Weight scales
+        elif "weight_scale" in weight_name:
+            # If we are in merged column case (gate_up_proj)
+            # shard_id 0 == gate_proj / w1
+            # shard_id 2 == up_proj / w3
+            if shard_id == 0 or shard_id == 2:
+                # We have to keep the weight scales of w1 and w3 because
+                # we need to re-quantize w1/w3 weights after weight loading.
+                idx = 0 if shard_id == 0 else 1
+                param_data[expert_id][idx] = loaded_weight
+            # If we are in the row parallel case (down_proj)
+            # shard_id 1 == down_proj / w2
+            else:
+                param_data[expert_id] = loaded_weight
+        # Weights
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            shard_size = self.intermediate_size_per_partition
+            if pre_sharded:
+                shard = slice(None)
+            else:
+                shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+
+            # w1, gate_proj case: Load into first shard of w13.
+            if shard_id == 0:
+                param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+            # w3, up_proj case: Load into second shard of w13.
+            elif shard_id == 2:
+                param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
+                    shard, :
+                ]
+            # w2, down_proj case: Load into only shard of w2.
+            elif shard_id == 1:
+                param_data[expert_id, :, :] = loaded_weight[:, shard]
+            else:
+                raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}")
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply(
+            self,
+            x=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            num_expert_group=self.num_expert_group,
+            topk_group=self.topk_group,
+        )
+
+        if self.reduce_results and self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states
+
+    @classmethod
+    def make_expert_params_mapping(
+        cls,
+        ckpt_gate_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_up_proj_name: str,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, int]]:
+
+        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
+        gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name]
+
+        return (
+            [
+                # These are the weight scales for the experts
+                # (param_name, weight_name, expert_id, shard_id)
+                (
+                    (
+                        "experts.w13_scale"
+                        if weight_name in gate_up
+                        else "experts.w2_scale"
+                    ),
+                    f"experts.{expert_id}.{weight_name}.weight_scale",
+                    expert_id,
+                    shard_id,
+                )
+                for expert_id in range(num_experts)
+                for shard_id, weight_name in enumerate(gate_down_up)
+            ]
+            + [
+                # These are the weights for the experts
+                # (param_name, weight_name, expert_id, shard_id)
+                (
+                    (
+                        "experts.w13_weight"
+                        if weight_name in gate_up
+                        else "experts.w2_weight"
+                    ),
+                    f"experts.{expert_id}.{weight_name}.weight",
+                    expert_id,
+                    shard_id,
+                )
+                for expert_id in range(num_experts)
+                for shard_id, weight_name in enumerate(gate_down_up)
+            ]
+            + [
+                # These are the weight scales for the experts
+                # (param_name, weight_name, expert_id, shard_id)
+                (
+                    (
+                        "experts.a13_scale"
+                        if weight_name in gate_up
+                        else "experts.a2_scale"
+                    ),
+                    f"experts.{expert_id}.{weight_name}.input_scale",
+                    expert_id,
+                    shard_id,
+                )
+                for expert_id in range(num_experts)
+                for shard_id, weight_name in enumerate(gate_down_up)
+            ]
+        )
+
+
+import torch
+from torch.nn import Module
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    all_close_1d,
+    per_tensor_dequantize,
+)
+from vllm.utils import print_warning_once
+
+
+class Fp8MoEMethod(FusedMoEMethodBase):
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_scale", w13_scale)
+
+        w2_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w2_scale", w2_scale)
+
+        # If loading fp8 checkpoint, pass the weight loaders.
+        # If loading an fp16 checkpoint, do not (we will quantize in
+        # process_weights_after_loading()
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            set_weight_attrs(w13_scale, extra_weight_attrs)
+            set_weight_attrs(w2_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.quant_config.activation_scheme == "static":
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    "Found static activation scheme for checkpoint that "
+                    "was not serialized fp8."
+                )
+
+            a13_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("a13_scale", a13_scale)
+            set_weight_attrs(a13_scale, extra_weight_attrs)
+
+            a2_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("a2_scale", a2_scale)
+            set_weight_attrs(a2_scale, extra_weight_attrs)
+        else:
+            layer.a13_scale = None
+            layer.a2_scale = None
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+
+        # If checkpoint is fp16, quantize in place.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(
+                layer.w13_weight.data, dtype=torch.float8_e4m3fn
+            )
+            w2_weight = torch.empty_like(
+                layer.w2_weight.data, dtype=torch.float8_e4m3fn
+            )
+
+            # Re-initialize w13_scale because we directly quantize
+            # merged w13 weights and generate a single scaling factor.
+            layer.w13_scale = torch.nn.Parameter(
+                torch.ones(
+                    layer.num_experts, dtype=torch.float32, device=w13_weight.device
+                ),
+                requires_grad=False,
+            )
+            for expert in range(layer.num_experts):
+                w13_weight[expert, :, :], layer.w13_scale[expert] = (
+                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant(
+                    layer.w2_weight.data[expert, :, :]
+                )
+            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            return
+
+        # If checkpoint is fp8, we need to handle that the
+        # MoE kernels require single activation scale and single weight
+        # scale for w13 per expert.
+        else:
+            # Fp8 moe kernels require a single activation scale.
+            # We take the max of all the scales in case they differ.
+            if self.quant_config.activation_scheme == "static":
+                if layer.a13_scale is None or layer.a2_scale is None:
+                    raise ValueError(
+                        "QuantConfig has static quantization, but found "
+                        "activation scales are None."
+                    )
+                if not all_close_1d(layer.a13_scale) or not all_close_1d(
+                    layer.a2_scale
+                ):
+                    print_warning_once(
+                        "Found input_scales that are not equal for "
+                        "fp8 MoE layer. Using the maximum across experts "
+                        "for each layer. "
+                    )
+                layer.a13_scale = torch.nn.Parameter(
+                    layer.a13_scale.max(), requires_grad=False
+                )
+                layer.a2_scale = torch.nn.Parameter(
+                    layer.a2_scale.max(), requires_grad=False
+                )
+
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_scale.max(dim=1).values
+            for expert_id in range(layer.num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        layer.w13_scale[expert_id][shard_id],
+                    )
+                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
+                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    )
+                    start += shard_size
+
+            layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+            return
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+    ) -> torch.Tensor:
+
+        from sglang.srt.layers.fused_moe.fused_moe import fused_moe
+
+        return fused_moe(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            router_logits,
+            top_k,
+            renormalize=renormalize,
+            inplace=True,
+            use_fp8=True,
+            w1_scale=layer.w13_scale,
+            w2_scale=layer.w2_scale,
+            a1_scale=layer.a13_scale,
+            a2_scale=layer.a2_scale,
+            use_grouped_topk=use_grouped_topk,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+        )
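
For orientation, the new `FusedMoE` module centralizes the expert weights and routing that models previously implemented by hand, which lines up with the large removals in `mixtral.py` and `grok.py` above. A model wires it up by computing router logits with a small gating linear layer and passing them, together with the hidden states, to `FusedMoE.forward`. The sketch below is illustrative only: the block name, layer sizes, and gate are assumptions rather than code from the wheel, and the forward pass still needs a CUDA device for the fused kernel.

```python
import torch
from torch import nn

from sglang.srt.layers.fused_moe.layer import FusedMoE


class SparseMoeBlock(nn.Module):
    """Illustrative MoE block: a router gate followed by the fused experts."""

    def __init__(self, hidden_size: int, intermediate_size: int,
                 num_experts: int = 8, top_k: int = 2):
        super().__init__()
        # Router: one logit per expert for every token.
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        # Fused expert weights: w13 (gate_up_proj) and w2 (down_proj) per expert.
        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            renormalize=True,     # renormalize the top-k routing weights
            reduce_results=True,  # all-reduce across TP ranks (no-op at tp_size=1)
            tp_size=1,            # assumption: single GPU; omit inside a TP-initialized server
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [num_tokens, hidden_size]
        router_logits = self.gate(hidden_states)
        return self.experts(hidden_states, router_logits)
```

Checkpoint weights reach the layer through `FusedMoE.weight_loader`, typically driven by the `(param_name, weight_name, expert_id, shard_id)` tuples produced by `make_expert_params_mapping`.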
--- /dev/null
+++ b/sglang/srt/layers/layernorm.py
@@ -0,0 +1,65 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Fused operators for normalization layers."""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+from vllm.model_executor.custom_op import CustomOp
+
+
+class RMSNorm(CustomOp):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+
+        if residual is not None:
+            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+        if residual is not None:
+            x = x + residual.to(torch.float32)
+            residual = x.to(orig_dtype)
+
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x.to(orig_dtype) * self.weight
+        if residual is None:
+            return x
+        else:
+            return x, residual
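
The CUDA path above dispatches to flashinfer's `rmsnorm` / `fused_add_rmsnorm` kernels, while `forward_native` keeps a pure-PyTorch fallback. As a quick sanity check of what that fallback computes, here is a standalone sketch (not code from the wheel) of RMSNorm without a residual; the newly added `sglang/test/test_layernorm.py` presumably exercises the same equivalence:

```python
import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Reference RMSNorm matching RMSNorm.forward_native without a residual."""
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    # Scale by the reciprocal root-mean-square of the last dimension, then apply the weight.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return x.to(orig_dtype) * weight


x = torch.randn(4, 8, dtype=torch.float16)
weight = torch.ones(8, dtype=torch.float16)
print(rms_norm_reference(x, weight).shape)  # torch.Size([4, 8])
```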
--- a/sglang/srt/layers/logits_processor.py
+++ b/sglang/srt/layers/logits_processor.py
@@ -164,9 +164,9 @@ class LogitsProcessor(nn.Module):
         last_logits = last_logits[:, : self.config.vocab_size].float()
 
         if hasattr(self.config, "final_logit_softcapping"):
-            last_logits
+            last_logits.div_(self.config.final_logit_softcapping)
             last_logits = torch.tanh(last_logits)
-            last_logits
+            last_logits.mul_(self.config.final_logit_softcapping)
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
@@ -208,6 +208,11 @@ class LogitsProcessor(nn.Module):
         all_logits = tensor_model_parallel_all_gather(all_logits)
         all_logits = all_logits[:, : self.config.vocab_size].float()
 
+        if hasattr(self.config, "final_logit_softcapping"):
+            all_logits.div_(self.config.final_logit_softcapping)
+            all_logits = torch.tanh(all_logits)
+            all_logits.mul_(self.config.final_logit_softcapping)
+
         all_logprobs = all_logits
         del all_logits, hidden_states
         all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)