sglang 0.3.6.post2__py3-none-any.whl → 0.3.6.post3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,630 +0,0 @@
- # Adapted from
- # https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
- import os
- from abc import abstractmethod
- from typing import List, Optional, Tuple
-
- import torch
- import torch.nn.functional as F
- from vllm.distributed import (
-     get_tensor_model_parallel_rank,
-     get_tensor_model_parallel_world_size,
-     tensor_model_parallel_all_reduce,
- )
- from vllm.logger import init_logger
- from vllm.model_executor.custom_op import CustomOp
- from vllm.model_executor.layers.quantization.base_config import (
-     QuantizationConfig,
-     QuantizeMethodBase,
- )
- from vllm.model_executor.layers.quantization.fp8 import Fp8Config
- from vllm.model_executor.utils import set_weight_attrs
-
- from sglang.srt.layers.fused_moe_grok.fused_moe import padding_size
- from sglang.srt.utils import is_hip
-
- logger = init_logger(__name__)
-
-
- class FusedMoEMethodBase(QuantizeMethodBase):
-
-     @abstractmethod
-     def create_weights(
-         self,
-         layer: torch.nn.Module,
-         num_experts: int,
-         hidden_size: int,
-         intermediate_size: int,
-         params_dtype: torch.dtype,
-         **extra_weight_attrs,
-     ):
-         raise NotImplementedError
-
-     @abstractmethod
-     def apply(
-         self,
-         layer: torch.nn.Module,
-         x: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool = True,
-         use_grouped_topk: bool = False,
-         num_expert_group: Optional[int] = None,
-         topk_group: Optional[int] = None,
-     ) -> torch.Tensor:
-         raise NotImplementedError
-
-
- class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
-     """MoE method without quantization."""
-
-     def create_weights(
-         self,
-         layer: torch.nn.Module,
-         num_experts: int,
-         hidden_size: int,
-         intermediate_size: int,
-         params_dtype: torch.dtype,
-         **extra_weight_attrs,
-     ):
-
-         # Fused gate_up_proj (column parallel)
-         w13_weight = torch.nn.Parameter(
-             torch.empty(
-                 num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-             ),
-             requires_grad=False,
-         )
-         layer.register_parameter("w13_weight", w13_weight)
-         set_weight_attrs(w13_weight, extra_weight_attrs)
-
-         # down_proj (row parallel)
-         w2_weight = torch.nn.Parameter(
-             torch.empty(
-                 num_experts, hidden_size, intermediate_size, dtype=params_dtype
-             ),
-             requires_grad=False,
-         )
-         layer.register_parameter("w2_weight", w2_weight)
-         set_weight_attrs(w2_weight, extra_weight_attrs)
-
-     def apply(
-         self,
-         layer: torch.nn.Module,
-         x: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool = True,
-         use_grouped_topk: bool = False,
-         num_expert_group: Optional[int] = None,
-         topk_group: Optional[int] = None,
-     ) -> torch.Tensor:
-         return self.forward(
-             x,
-             layer.w13_weight,
-             layer.w2_weight,
-             router_logits,
-             top_k,
-             renormalize,
-             use_grouped_topk,
-             num_expert_group,
-             topk_group,
-         )
-
-     def forward_cuda(
-         self,
-         x: torch.Tensor,
-         w1: torch.Tensor,
-         w2: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool,
-         num_expert_group: Optional[int],
-         topk_group: Optional[int],
-     ) -> torch.Tensor:
-         from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe
-
-         return fused_moe(
-             x,
-             w1,
-             w2,
-             router_logits,
-             top_k,
-             renormalize=renormalize,
-             inplace=True,
-             use_grouped_topk=use_grouped_topk,
-             num_expert_group=num_expert_group,
-             topk_group=topk_group,
-         )
-
-     def forward_cpu(self, *args, **kwargs):
-         raise NotImplementedError("The CPU backend currently does not support MoE.")
-
-     def forward_tpu(
-         self,
-         x: torch.Tensor,
-         w1: torch.Tensor,
-         w2: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool,
-         num_expert_group: Optional[int],
-         topk_group: Optional[int],
-     ) -> torch.Tensor:
-         raise NotImplementedError("The TPU backend currently does not support MoE.")
-
-
- class FusedMoE(torch.nn.Module):
-     """FusedMoE layer for MoE models.
-
-     This layer contains both MergedColumnParallel weights (gate_up_proj /
-     w13) and RowParallelLinear weights (down_proj/ w2).
-
-     Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
-     copy that naming convention here and handle any remapping in the
-     load_weights function in each model implementation.
-
-     Args:
-         num_experts: Number of experts in the model
-         top_k: Number of experts selected for each token
-         hidden_size: Input hidden state size of the transformer
-         intermediate_size: Intermediate size of the experts
-         params_dtype: Data type for the parameters.
-         reduce_results: Whether to all all_reduce on the output of the layer
-         renomalize: Whether to renormalize the logits in the fused_moe kernel
-         quant_config: Quantization configure.
-     """
-
-     def __init__(
-         self,
-         num_experts: int,
-         top_k: int,
-         hidden_size: int,
-         intermediate_size: int,
-         params_dtype: Optional[torch.dtype] = None,
-         reduce_results: bool = False,
-         renormalize: bool = True,
-         use_grouped_topk: bool = False,
-         num_expert_group: Optional[int] = None,
-         topk_group: Optional[int] = None,
-         quant_config: Optional[QuantizationConfig] = None,
-         tp_size: Optional[int] = None,
-         prefix: str = "",
-     ):
-         super().__init__()
-
-         if params_dtype is None:
-             params_dtype = torch.get_default_dtype()
-
-         self.tp_size = (
-             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
-         )
-         self.top_k = top_k
-         self.num_experts = num_experts
-         self.intermediate_size_per_partition = intermediate_size // self.tp_size
-         self.reduce_results = reduce_results
-         self.renormalize = renormalize
-         self.use_grouped_topk = use_grouped_topk
-         if self.use_grouped_topk:
-             assert num_expert_group is not None and topk_group is not None
-         self.num_expert_group = num_expert_group
-         self.topk_group = topk_group
-
-         if quant_config is None:
-             self.quant_method: Optional[QuantizeMethodBase] = (
-                 UnquantizedFusedMoEMethod()
-             )
-         else:
-             if isinstance(quant_config, Fp8Config):
-                 self.quant_method = Fp8MoEMethod(quant_config)
-             else:
-                 self.quant_method = quant_config.get_quant_method(self, prefix)
-             assert self.quant_method is not None
-
-         self.quant_method.create_weights(
-             layer=self,
-             num_experts=num_experts,
-             hidden_size=hidden_size,
-             intermediate_size=self.intermediate_size_per_partition,
-             params_dtype=params_dtype,
-             weight_loader=self.weight_loader,
-         )
-
-     def weight_loader(
-         self,
-         param: torch.nn.Parameter,
-         loaded_weight: torch.Tensor,
-         weight_name: str,
-         shard_id: int,
-         expert_id: int,
-         use_presharded_weights: bool = False,
-     ):
-         param_data = param.data
-
-         # Input scales can be loaded directly and should be equal.
-         if "input_scale" in weight_name:
-             if (
-                 param_data[expert_id] != 1
-                 and (param_data[expert_id] - loaded_weight).abs() > 1e-5
-             ):
-                 raise ValueError(
-                     "input_scales of w1 and w3 of a layer "
-                     f"must be equal. But got {param_data[expert_id]} "
-                     f"vs. {loaded_weight}"
-                 )
-             param_data[expert_id] = loaded_weight
-         # Weight scales
-         elif "weight_scale" in weight_name:
-             # If we are in merged column case (gate_up_proj)
-             # shard_id 0 == gate_proj / w1
-             # shard_id 2 == up_proj / w3
-             if shard_id == 0 or shard_id == 2:
-                 # We have to keep the weight scales of w1 and w3 because
-                 # we need to re-quantize w1/w3 weights after weight loading.
-                 idx = 0 if shard_id == 0 else 1
-                 param_data[expert_id][idx] = loaded_weight
-             # If we are in the row parallel case (down_proj)
-             # shard_id 1 == down_proj / w2
-             else:
-                 param_data[expert_id] = loaded_weight
-         # Weights
-         else:
-             tp_rank = get_tensor_model_parallel_rank()
-             shard_size = self.intermediate_size_per_partition
-             if use_presharded_weights:
-                 shard = slice(None)
-             else:
-                 shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-
-             # w1, gate_proj case: Load into first shard of w13.
-             if shard_id == 0:
-                 param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-             # w3, up_proj case: Load into second shard of w13.
-             elif shard_id == 2:
-                 param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
-                     shard, :
-                 ]
-             # w2, down_proj case: Load into only shard of w2.
-             elif shard_id == 1:
-                 param_data[expert_id, :, :] = loaded_weight[:, shard]
-             else:
-                 raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}")
-
-     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-         assert self.quant_method is not None
-
-         # Matrix multiply.
-         final_hidden_states = self.quant_method.apply(
-             self,
-             x=hidden_states,
-             router_logits=router_logits,
-             top_k=self.top_k,
-             renormalize=self.renormalize,
-             use_grouped_topk=self.use_grouped_topk,
-             num_expert_group=self.num_expert_group,
-             topk_group=self.topk_group,
-         )
-
-         if self.reduce_results and self.tp_size > 1:
-             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
-         return final_hidden_states
-
-     @classmethod
-     def make_expert_params_mapping(
-         cls,
-         ckpt_gate_proj_name: str,
-         ckpt_down_proj_name: str,
-         ckpt_up_proj_name: str,
-         num_experts: int,
-     ) -> List[Tuple[str, str, int, int]]:
-
-         gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
-         gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name]
-
-         return (
-             [
-                 # These are the weight scales for the experts
-                 # (param_name, weight_name, expert_id, shard_id)
-                 (
-                     (
-                         "experts.w13_scale"
-                         if weight_name in gate_up
-                         else "experts.w2_scale"
-                     ),
-                     f"experts.{expert_id}.{weight_name}.weight_scale",
-                     expert_id,
-                     shard_id,
-                 )
-                 for expert_id in range(num_experts)
-                 for shard_id, weight_name in enumerate(gate_down_up)
-             ]
-             + [
-                 # These are the weights for the experts
-                 # (param_name, weight_name, expert_id, shard_id)
-                 (
-                     (
-                         "experts.w13_weight"
-                         if weight_name in gate_up
-                         else "experts.w2_weight"
-                     ),
-                     f"experts.{expert_id}.{weight_name}.weight",
-                     expert_id,
-                     shard_id,
-                 )
-                 for expert_id in range(num_experts)
-                 for shard_id, weight_name in enumerate(gate_down_up)
-             ]
-             + [
-                 # These are the weight scales for the experts
-                 # (param_name, weight_name, expert_id, shard_id)
-                 (
-                     (
-                         "experts.a13_scale"
-                         if weight_name in gate_up
-                         else "experts.a2_scale"
-                     ),
-                     f"experts.{expert_id}.{weight_name}.input_scale",
-                     expert_id,
-                     shard_id,
-                 )
-                 for expert_id in range(num_experts)
-                 for shard_id, weight_name in enumerate(gate_down_up)
-             ]
-         )
-
-
- import torch
- from torch.nn import Module
- from vllm import _custom_ops as ops
- from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-     all_close_1d,
-     normalize_e4m3fn_to_e4m3fnuz,
-     per_tensor_dequantize,
- )
- from vllm.utils import print_warning_once
-
-
- class Fp8MoEMethod(FusedMoEMethodBase):
-     """MoE method for FP8.
-     Supports loading FP8 checkpoints with static weight scale and
-     dynamic/static activation scale.
-
-     Also supports loading quantized FP16/BF16 model checkpoints with dynamic
-     activation scaling. The weight scaling factor will be initialized after
-     the model weights are loaded.
-
-     Args:
-         quant_config: The quantization config.
-     """
-
-     def __init__(self, quant_config: Fp8Config):
-         self.quant_config = quant_config
-
-     def create_weights(
-         self,
-         layer: Module,
-         num_experts: int,
-         hidden_size: int,
-         intermediate_size: int,
-         params_dtype: torch.dtype,
-         **extra_weight_attrs,
-     ):
-
-         if self.quant_config.is_checkpoint_fp8_serialized:
-             params_dtype = torch.float8_e4m3fn
-
-         # WEIGHTS
-         w13_weight = torch.nn.Parameter(
-             torch.empty(
-                 num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-             ),
-             requires_grad=False,
-         )
-         layer.register_parameter("w13_weight", w13_weight)
-         set_weight_attrs(w13_weight, extra_weight_attrs)
-
-         w2_weight = torch.nn.Parameter(
-             torch.empty(
-                 num_experts, hidden_size, intermediate_size, dtype=params_dtype
-             ),
-             requires_grad=False,
-         )
-         layer.register_parameter("w2_weight", w2_weight)
-         set_weight_attrs(w2_weight, extra_weight_attrs)
-
-         # WEIGHT_SCALES
-         # Allocate 2 scales for w1 and w3 respectively.
-         # They will be combined to a single scale after weight loading.
-         w13_scale = torch.nn.Parameter(
-             torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-         )
-         layer.register_parameter("w13_scale", w13_scale)
-
-         w2_scale = torch.nn.Parameter(
-             torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-         )
-         layer.register_parameter("w2_scale", w2_scale)
-
-         # If loading fp8 checkpoint, pass the weight loaders.
-         # If loading an fp16 checkpoint, do not (we will quantize in
-         # process_weights_after_loading()
-         if self.quant_config.is_checkpoint_fp8_serialized:
-             set_weight_attrs(w13_scale, extra_weight_attrs)
-             set_weight_attrs(w2_scale, extra_weight_attrs)
-
-         # INPUT_SCALES
-         if self.quant_config.activation_scheme == "static":
-             if not self.quant_config.is_checkpoint_fp8_serialized:
-                 raise ValueError(
-                     "Found static activation scheme for checkpoint that "
-                     "was not serialized fp8."
-                 )
-
-             a13_scale = torch.nn.Parameter(
-                 torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-             )
-             layer.register_parameter("a13_scale", a13_scale)
-             set_weight_attrs(a13_scale, extra_weight_attrs)
-
-             a2_scale = torch.nn.Parameter(
-                 torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-             )
-             layer.register_parameter("a2_scale", a2_scale)
-             set_weight_attrs(a2_scale, extra_weight_attrs)
-         else:
-             layer.a13_scale = None
-             layer.a2_scale = None
-
-     def process_weights_after_loading(self, layer: Module) -> None:
-
-         # If checkpoint is fp16 or bfloat16, quantize in place.
-         if not self.quant_config.is_checkpoint_fp8_serialized:
-             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-             fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
-             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
-             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
-
-             # Re-initialize w13_scale because we directly quantize
-             # merged w13 weights and generate a single scaling factor.
-             layer.w13_scale = torch.nn.Parameter(
-                 torch.ones(
-                     layer.num_experts, dtype=torch.float32, device=w13_weight.device
-                 ),
-                 requires_grad=False,
-             )
-             for expert in range(layer.num_experts):
-                 w13_weight[expert, :, :], layer.w13_scale[expert] = (
-                     ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                 )
-                 w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant(
-                     layer.w2_weight.data[expert, :, :]
-                 )
-             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-
-             # If ROCm, apply weight padding (min. Mem channel contention) only if set
-             if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                 layer.w13_weight = torch.nn.Parameter(
-                     F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                     requires_grad=False,
-                 )
-                 torch.cuda.empty_cache()
-                 layer.w2_weight = torch.nn.Parameter(
-                     F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                     requires_grad=False,
-                 )
-                 torch.cuda.empty_cache()
-             return
-
-         # If checkpoint is fp8, we need to handle that the
-         # MoE kernels require single activation scale and single weight
-         # scale for w13 per expert.
-         else:
-             # Fp8 moe kernels require a single activation scale.
-             # We take the max of all the scales in case they differ.
-             if self.quant_config.activation_scheme == "static":
-                 if layer.a13_scale is None or layer.a2_scale is None:
-                     raise ValueError(
-                         "QuantConfig has static quantization, but found "
-                         "activation scales are None."
-                     )
-                 if not all_close_1d(layer.a13_scale) or not all_close_1d(
-                     layer.a2_scale
-                 ):
-                     print_warning_once(
-                         "Found input_scales that are not equal for "
-                         "fp8 MoE layer. Using the maximum across experts "
-                         "for each layer. "
-                     )
-                 layer.a13_scale = torch.nn.Parameter(
-                     layer.a13_scale.max(), requires_grad=False
-                 )
-                 layer.a2_scale = torch.nn.Parameter(
-                     layer.a2_scale.max(), requires_grad=False
-                 )
-
-             # If ROCm, normalize the weights and scales to e4m3fnuz
-             if is_hip():
-                 # Normalize the weights and scales
-                 w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
-                     layer.w13_weight, layer.w13_scale, layer.a13_scale
-                 )
-                 w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
-                     layer.w2_weight, layer.w2_scale, layer.a2_scale
-                 )
-                 # Reset the parameters
-                 layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-                 layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
-                 if a13_scale is not None:
-                     layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
-                 layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-                 layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
-                 if a2_scale is not None:
-                     layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
-
-             # Fp8 moe kernel needs single weight scale for w13 per expert.
-             # We take the max then dequant and requant each expert.
-             assert layer.w13_scale is not None
-             shard_size = layer.intermediate_size_per_partition
-             max_w13_scales = layer.w13_scale.max(dim=1).values
-             for expert_id in range(layer.num_experts):
-                 start = 0
-                 for shard_id in range(2):
-                     dq_weight = per_tensor_dequantize(
-                         layer.w13_weight[expert_id][start : start + shard_size, :],
-                         layer.w13_scale[expert_id][shard_id],
-                     )
-                     layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                         ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                     )
-                     start += shard_size
-
-             layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
-             # If ROCm, apply weight padding (min. Mem channel contention) only if set
-             if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                 layer.w13_weight = torch.nn.Parameter(
-                     F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                     requires_grad=False,
-                 )
-                 torch.cuda.empty_cache()
-                 layer.w2_weight = torch.nn.Parameter(
-                     F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                     requires_grad=False,
-                 )
-                 torch.cuda.empty_cache()
-             return
-
-     def apply(
-         self,
-         layer: torch.nn.Module,
-         x: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool = True,
-         use_grouped_topk: bool = False,
-         num_expert_group: Optional[int] = None,
-         topk_group: Optional[int] = None,
-     ) -> torch.Tensor:
-
-         from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe
-
-         return fused_moe(
-             x,
-             layer.w13_weight,
-             layer.w2_weight,
-             router_logits,
-             top_k,
-             renormalize=renormalize,
-             inplace=True,
-             use_fp8=True,
-             w1_scale=layer.w13_scale,
-             w2_scale=layer.w2_scale,
-             a1_scale=layer.a13_scale,
-             a2_scale=layer.a2_scale,
-             use_grouped_topk=use_grouped_topk,
-             num_expert_group=num_expert_group,
-             topk_group=topk_group,
-         )
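For context on what this removal touches: the deleted module defines the Grok fused-MoE layer stack (`FusedMoEMethodBase`, `UnquantizedFusedMoEMethod`, `FusedMoE`, and `Fp8MoEMethod`). The sketch below is a minimal, hypothetical usage example reconstructed only from the signatures visible in the diff; the import path, shapes, and single-GPU setup (`tp_size=1`) are assumptions, not something this diff documents.

```python
import torch

# Hypothetical import path: the diff does not show the deleted file's name,
# only that it sits alongside sglang.srt.layers.fused_moe_grok.fused_moe.
from sglang.srt.layers.fused_moe_grok.layer import FusedMoE

hidden_size, num_experts, top_k = 4096, 8, 2

# tp_size=1 sidesteps the tensor-parallel process group; quant_config=None
# selects UnquantizedFusedMoEMethod, so weights stay in bf16.
moe = FusedMoE(
    num_experts=num_experts,
    top_k=top_k,
    hidden_size=hidden_size,
    intermediate_size=14336,
    params_dtype=torch.bfloat16,
    reduce_results=False,
    renormalize=True,
    tp_size=1,
).cuda()

# Weights are allocated uninitialized here; in practice each expert's
# gate/up/down projections are filled via FusedMoE.weight_loader during
# checkpoint loading (see make_expert_params_mapping).
x = torch.randn(16, hidden_size, dtype=torch.bfloat16, device="cuda")
router_logits = torch.randn(16, num_experts, dtype=torch.bfloat16, device="cuda")
out = moe(x, router_logits)  # same shape as x
```

The least obvious step in the removed `Fp8MoEMethod.process_weights_after_loading` is collapsing the two per-half `w13_scale` entries (one for w1, one for w3) into a single scale per expert: each half is dequantized with its original scale and re-quantized against the maximum of the two. The plain-PyTorch sketch below mirrors that idea for one expert; it illustrates the logic only and does not use the vllm helpers (`per_tensor_dequantize`, `ops.scaled_fp8_quant`) the file actually calls.

```python
import torch


def requantize_w13_to_single_scale(
    w13_q: torch.Tensor,      # (2 * shard_size, hidden_size), torch.float8_e4m3fn
    w13_scale: torch.Tensor,  # (2,), float32 scales for the w1 and w3 halves
) -> tuple[torch.Tensor, torch.Tensor]:
    """Re-quantize one expert's merged w13 weight onto a single shared scale."""
    fp8 = torch.float8_e4m3fn
    fmax = torch.finfo(fp8).max
    shard_size = w13_q.shape[0] // 2
    new_scale = w13_scale.max()
    out = torch.empty_like(w13_q)
    for half in range(2):
        rows = slice(half * shard_size, (half + 1) * shard_size)
        # Dequantize this half with its original per-half scale ...
        dq = w13_q[rows].to(torch.float32) * w13_scale[half]
        # ... then re-quantize it against the shared (maximum) scale.
        out[rows] = (dq / new_scale).clamp(-fmax, fmax).to(fp8)
    return out, new_scale
```

Taking the maximum of the two scales keeps both halves representable after re-quantization, at the cost of slightly coarser resolution for the half whose original scale was smaller.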