sglang 0.3.6.post1__py3-none-any.whl → 0.3.6.post3__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +4 -8
- sglang/bench_one_batch_server.py +6 -5
- sglang/check_env.py +7 -1
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +2 -4
- sglang/srt/configs/model_config.py +2 -6
- sglang/srt/layers/attention/flashinfer_backend.py +3 -3
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -6
- sglang/srt/managers/image_processor.py +7 -10
- sglang/srt/managers/io_struct.py +0 -10
- sglang/srt/managers/schedule_batch.py +51 -13
- sglang/srt/managers/scheduler.py +41 -29
- sglang/srt/managers/session_controller.py +15 -7
- sglang/srt/managers/tokenizer_manager.py +4 -33
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -2
- sglang/srt/models/grok.py +11 -48
- sglang/srt/models/llava.py +16 -9
- sglang/srt/models/olmo2.py +392 -0
- sglang/srt/models/qwen2_vl.py +10 -3
- sglang/srt/openai_api/adapter.py +1 -1
- sglang/srt/server.py +48 -45
- sglang/srt/server_args.py +1 -1
- sglang/srt/utils.py +22 -24
- sglang/test/test_utils.py +21 -8
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/METADATA +4 -2
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/RECORD +34 -36
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/fused_moe_grok/layer.py (deleted)
@@ -1,630 +0,0 @@
-# Adapted from
-# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
-import os
-from abc import abstractmethod
-from typing import List, Optional, Tuple
-
-import torch
-import torch.nn.functional as F
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce,
-)
-from vllm.logger import init_logger
-from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig,
-    QuantizeMethodBase,
-)
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
-from vllm.model_executor.utils import set_weight_attrs
-
-from sglang.srt.layers.fused_moe_grok.fused_moe import padding_size
-from sglang.srt.utils import is_hip
-
-logger = init_logger(__name__)
-
-
-class FusedMoEMethodBase(QuantizeMethodBase):
-
-    @abstractmethod
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        raise NotImplementedError
-
-    @abstractmethod
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        topk_group: Optional[int] = None,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-
-class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
-    """MoE method without quantization."""
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-
-        # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        topk_group: Optional[int] = None,
-    ) -> torch.Tensor:
-        return self.forward(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            router_logits,
-            top_k,
-            renormalize,
-            use_grouped_topk,
-            num_expert_group,
-            topk_group,
-        )
-
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        num_expert_group: Optional[int],
-        topk_group: Optional[int],
-    ) -> torch.Tensor:
-        from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe
-
-        return fused_moe(
-            x,
-            w1,
-            w2,
-            router_logits,
-            top_k,
-            renormalize=renormalize,
-            inplace=True,
-            use_grouped_topk=use_grouped_topk,
-            num_expert_group=num_expert_group,
-            topk_group=topk_group,
-        )
-
-    def forward_cpu(self, *args, **kwargs):
-        raise NotImplementedError("The CPU backend currently does not support MoE.")
-
-    def forward_tpu(
-        self,
-        x: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        num_expert_group: Optional[int],
-        topk_group: Optional[int],
-    ) -> torch.Tensor:
-        raise NotImplementedError("The TPU backend currently does not support MoE.")
-
-
-class FusedMoE(torch.nn.Module):
-    """FusedMoE layer for MoE models.
-
-    This layer contains both MergedColumnParallel weights (gate_up_proj /
-    w13) and RowParallelLinear weights (down_proj/ w2).
-
-    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
-    copy that naming convention here and handle any remapping in the
-    load_weights function in each model implementation.
-
-    Args:
-        num_experts: Number of experts in the model
-        top_k: Number of experts selected for each token
-        hidden_size: Input hidden state size of the transformer
-        intermediate_size: Intermediate size of the experts
-        params_dtype: Data type for the parameters.
-        reduce_results: Whether to all all_reduce on the output of the layer
-        renomalize: Whether to renormalize the logits in the fused_moe kernel
-        quant_config: Quantization configure.
-    """
-
-    def __init__(
-        self,
-        num_experts: int,
-        top_k: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: Optional[torch.dtype] = None,
-        reduce_results: bool = False,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        topk_group: Optional[int] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        tp_size: Optional[int] = None,
-        prefix: str = "",
-    ):
-        super().__init__()
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-
-        self.tp_size = (
-            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
-        )
-        self.top_k = top_k
-        self.num_experts = num_experts
-        self.intermediate_size_per_partition = intermediate_size // self.tp_size
-        self.reduce_results = reduce_results
-        self.renormalize = renormalize
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
-            )
-        else:
-            if isinstance(quant_config, Fp8Config):
-                self.quant_method = Fp8MoEMethod(quant_config)
-            else:
-                self.quant_method = quant_config.get_quant_method(self, prefix)
-        assert self.quant_method is not None
-
-        self.quant_method.create_weights(
-            layer=self,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=self.intermediate_size_per_partition,
-            params_dtype=params_dtype,
-            weight_loader=self.weight_loader,
-        )
-
-    def weight_loader(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: int,
-        expert_id: int,
-        use_presharded_weights: bool = False,
-    ):
-        param_data = param.data
-
-        # Input scales can be loaded directly and should be equal.
-        if "input_scale" in weight_name:
-            if (
-                param_data[expert_id] != 1
-                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
-            ):
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}"
-                )
-            param_data[expert_id] = loaded_weight
-        # Weight scales
-        elif "weight_scale" in weight_name:
-            # If we are in merged column case (gate_up_proj)
-            # shard_id 0 == gate_proj / w1
-            # shard_id 2 == up_proj / w3
-            if shard_id == 0 or shard_id == 2:
-                # We have to keep the weight scales of w1 and w3 because
-                # we need to re-quantize w1/w3 weights after weight loading.
-                idx = 0 if shard_id == 0 else 1
-                param_data[expert_id][idx] = loaded_weight
-            # If we are in the row parallel case (down_proj)
-            # shard_id 1 == down_proj / w2
-            else:
-                param_data[expert_id] = loaded_weight
-        # Weights
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            shard_size = self.intermediate_size_per_partition
-            if use_presharded_weights:
-                shard = slice(None)
-            else:
-                shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-
-            # w1, gate_proj case: Load into first shard of w13.
-            if shard_id == 0:
-                param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-            # w3, up_proj case: Load into second shard of w13.
-            elif shard_id == 2:
-                param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
-                    shard, :
-                ]
-            # w2, down_proj case: Load into only shard of w2.
-            elif shard_id == 1:
-                param_data[expert_id, :, :] = loaded_weight[:, shard]
-            else:
-                raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}")
-
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert self.quant_method is not None
-
-        # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
-            self,
-            x=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            renormalize=self.renormalize,
-            use_grouped_topk=self.use_grouped_topk,
-            num_expert_group=self.num_expert_group,
-            topk_group=self.topk_group,
-        )
-
-        if self.reduce_results and self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
-        return final_hidden_states
-
-    @classmethod
-    def make_expert_params_mapping(
-        cls,
-        ckpt_gate_proj_name: str,
-        ckpt_down_proj_name: str,
-        ckpt_up_proj_name: str,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, int]]:
-
-        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
-        gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name]
-
-        return (
-            [
-                # These are the weight scales for the experts
-                # (param_name, weight_name, expert_id, shard_id)
-                (
-                    (
-                        "experts.w13_scale"
-                        if weight_name in gate_up
-                        else "experts.w2_scale"
-                    ),
-                    f"experts.{expert_id}.{weight_name}.weight_scale",
-                    expert_id,
-                    shard_id,
-                )
-                for expert_id in range(num_experts)
-                for shard_id, weight_name in enumerate(gate_down_up)
-            ]
-            + [
-                # These are the weights for the experts
-                # (param_name, weight_name, expert_id, shard_id)
-                (
-                    (
-                        "experts.w13_weight"
-                        if weight_name in gate_up
-                        else "experts.w2_weight"
-                    ),
-                    f"experts.{expert_id}.{weight_name}.weight",
-                    expert_id,
-                    shard_id,
-                )
-                for expert_id in range(num_experts)
-                for shard_id, weight_name in enumerate(gate_down_up)
-            ]
-            + [
-                # These are the weight scales for the experts
-                # (param_name, weight_name, expert_id, shard_id)
-                (
-                    (
-                        "experts.a13_scale"
-                        if weight_name in gate_up
-                        else "experts.a2_scale"
-                    ),
-                    f"experts.{expert_id}.{weight_name}.input_scale",
-                    expert_id,
-                    shard_id,
-                )
-                for expert_id in range(num_experts)
-                for shard_id, weight_name in enumerate(gate_down_up)
-            ]
-        )
-
-
-import torch
-from torch.nn import Module
-from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    all_close_1d,
-    normalize_e4m3fn_to_e4m3fnuz,
-    per_tensor_dequantize,
-)
-from vllm.utils import print_warning_once
-
-
-class Fp8MoEMethod(FusedMoEMethodBase):
-    """MoE method for FP8.
-    Supports loading FP8 checkpoints with static weight scale and
-    dynamic/static activation scale.
-
-    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
-    activation scaling. The weight scaling factor will be initialized after
-    the model weights are loaded.
-
-    Args:
-        quant_config: The quantization config.
-    """
-
-    def __init__(self, quant_config: Fp8Config):
-        self.quant_config = quant_config
-
-    def create_weights(
-        self,
-        layer: Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_scale", w13_scale)
-
-        w2_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w2_scale", w2_scale)
-
-        # If loading fp8 checkpoint, pass the weight loaders.
-        # If loading an fp16 checkpoint, do not (we will quantize in
-        # process_weights_after_loading()
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            set_weight_attrs(w13_scale, extra_weight_attrs)
-            set_weight_attrs(w2_scale, extra_weight_attrs)
-
-        # INPUT_SCALES
-        if self.quant_config.activation_scheme == "static":
-            if not self.quant_config.is_checkpoint_fp8_serialized:
-                raise ValueError(
-                    "Found static activation scheme for checkpoint that "
-                    "was not serialized fp8."
-                )
-
-            a13_scale = torch.nn.Parameter(
-                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-            )
-            layer.register_parameter("a13_scale", a13_scale)
-            set_weight_attrs(a13_scale, extra_weight_attrs)
-
-            a2_scale = torch.nn.Parameter(
-                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-            )
-            layer.register_parameter("a2_scale", a2_scale)
-            set_weight_attrs(a2_scale, extra_weight_attrs)
-        else:
-            layer.a13_scale = None
-            layer.a2_scale = None
-
-    def process_weights_after_loading(self, layer: Module) -> None:
-
-        # If checkpoint is fp16 or bfloat16, quantize in place.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
-            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
-            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
-
-            # Re-initialize w13_scale because we directly quantize
-            # merged w13 weights and generate a single scaling factor.
-            layer.w13_scale = torch.nn.Parameter(
-                torch.ones(
-                    layer.num_experts, dtype=torch.float32, device=w13_weight.device
-                ),
-                requires_grad=False,
-            )
-            for expert in range(layer.num_experts):
-                w13_weight[expert, :, :], layer.w13_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant(
-                    layer.w2_weight.data[expert, :, :]
-                )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-
-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-            return
-
-        # If checkpoint is fp8, we need to handle that the
-        # MoE kernels require single activation scale and single weight
-        # scale for w13 per expert.
-        else:
-            # Fp8 moe kernels require a single activation scale.
-            # We take the max of all the scales in case they differ.
-            if self.quant_config.activation_scheme == "static":
-                if layer.a13_scale is None or layer.a2_scale is None:
-                    raise ValueError(
-                        "QuantConfig has static quantization, but found "
-                        "activation scales are None."
-                    )
-                if not all_close_1d(layer.a13_scale) or not all_close_1d(
-                    layer.a2_scale
-                ):
-                    print_warning_once(
-                        "Found input_scales that are not equal for "
-                        "fp8 MoE layer. Using the maximum across experts "
-                        "for each layer. "
-                    )
-                layer.a13_scale = torch.nn.Parameter(
-                    layer.a13_scale.max(), requires_grad=False
-                )
-                layer.a2_scale = torch.nn.Parameter(
-                    layer.a2_scale.max(), requires_grad=False
-                )
-
-            # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
-                # Normalize the weights and scales
-                w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
-                    layer.w13_weight, layer.w13_scale, layer.a13_scale
-                )
-                w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
-                    layer.w2_weight, layer.w2_scale, layer.a2_scale
-                )
-                # Reset the parameters
-                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-                layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
-                if a13_scale is not None:
-                    layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-                layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
-                if a2_scale is not None:
-                    layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
-
-            # Fp8 moe kernel needs single weight scale for w13 per expert.
-            # We take the max then dequant and requant each expert.
-            assert layer.w13_scale is not None
-            shard_size = layer.intermediate_size_per_partition
-            max_w13_scales = layer.w13_scale.max(dim=1).values
-            for expert_id in range(layer.num_experts):
-                start = 0
-                for shard_id in range(2):
-                    dq_weight = per_tensor_dequantize(
-                        layer.w13_weight[expert_id][start : start + shard_size, :],
-                        layer.w13_scale[expert_id][shard_id],
-                    )
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                    start += shard_size
-
-            layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-            return
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        topk_group: Optional[int] = None,
-    ) -> torch.Tensor:
-
-        from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe
-
-        return fused_moe(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            router_logits,
-            top_k,
-            renormalize=renormalize,
-            inplace=True,
-            use_fp8=True,
-            w1_scale=layer.w13_scale,
-            w2_scale=layer.w2_scale,
-            a1_scale=layer.a13_scale,
-            a2_scale=layer.a2_scale,
-            use_grouped_topk=use_grouped_topk,
-            num_expert_group=num_expert_group,
-            topk_group=topk_group,
-        )
File without changes
|
File without changes
|
File without changes
|