sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between two publicly available package versions as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +1 -0
- sglang/bench_serving.py +9 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +51 -5
- sglang/srt/layers/attention/triton_backend.py +16 -25
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
- sglang/srt/layers/moe/topk.py +191 -0
- sglang/srt/layers/quantization/__init__.py +5 -50
- sglang/srt/layers/quantization/fp8.py +221 -36
- sglang/srt/layers/quantization/fp8_kernel.py +278 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/radix_attention.py +8 -1
- sglang/srt/layers/sampler.py +27 -5
- sglang/srt/layers/torchao_utils.py +31 -0
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +54 -34
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +171 -136
- sglang/srt/managers/tokenizer_manager.py +184 -133
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -11
- sglang/srt/model_executor/model_runner.py +28 -14
- sglang/srt/model_parallel.py +66 -5
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +73 -9
- sglang/srt/models/llama.py +22 -0
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +8 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/server.py +2 -1
- sglang/srt/server_args.py +19 -9
- sglang/srt/utils.py +40 -54
- sglang/test/test_block_fp8.py +341 -0
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py  (+221 -36)

@@ -1,12 +1,15 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

 import logging
+import os
 from typing import Any, Callable, Dict, List, Optional

 import torch
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -24,17 +27,17 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

-from sglang.srt.layers.fused_moe_triton import (
-    FusedMoE,
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.layers.quantization.fp8_utils import (
+    BlockQuantScaleParameter,
+    apply_w8a8_block_fp8_linear,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
 from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
@@ -55,6 +58,7 @@ class Fp8Config(QuantizationConfig):
         is_checkpoint_fp8_serialized: bool = False,
         activation_scheme: str = "dynamic",
         ignored_layers: Optional[List[str]] = None,
+        weight_block_size: List[int] = None,
     ) -> None:
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         if is_checkpoint_fp8_serialized:
@@ -66,6 +70,20 @@ class Fp8Config(QuantizationConfig):
             raise ValueError(f"Unsupported activation scheme {activation_scheme}")
         self.activation_scheme = activation_scheme
         self.ignored_layers = ignored_layers or []
+        if weight_block_size is not None:
+            if not is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    f"The block-wise quantization only supports fp8-serialized checkpoint for now."
+                )
+            if len(weight_block_size) != 2:
+                raise ValueError(
+                    f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
+                )
+            if activation_scheme != "dynamic":
+                raise ValueError(
+                    f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
+                )
+        self.weight_block_size = weight_block_size

     @classmethod
     def get_name(cls) -> str:
@@ -89,10 +107,12 @@ class Fp8Config(QuantizationConfig):
         is_checkpoint_fp8_serialized = "fp8" in quant_method
         activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
         ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
         return cls(
             is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
             activation_scheme=activation_scheme,
             ignored_layers=ignored_layers,
+            weight_block_size=weight_block_size,
         )

     def get_quant_method(
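
For context on the new weight_block_size field parsed above: a block-quantized fp8 checkpoint would carry it in its quantization config next to quant_method and activation_scheme. A minimal sketch of how such a config maps onto the keys that from_config reads; the concrete values below are illustrative, not taken from this diff:

# Hypothetical quantization config, as it might appear in a checkpoint's config.json.
hf_quant_config = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "weight_block_size": [128, 128],  # [block_n, block_k], assumed values
}

# Rough re-statement of the parsing logic above (not the real classmethod).
is_checkpoint_fp8_serialized = "fp8" in hf_quant_config["quant_method"]
activation_scheme = hf_quant_config["activation_scheme"]
weight_block_size = hf_quant_config.get("weight_block_size", None)

# Mirrors the validation added to Fp8Config.__init__ in this release.
assert weight_block_size is None or (
    len(weight_block_size) == 2 and activation_scheme == "dynamic"
)
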
@@ -100,6 +120,8 @@ class Fp8Config(QuantizationConfig):
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import

+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
@@ -143,6 +165,11 @@ class Fp8LinearMethod(LinearMethodBase):
         if is_hip():
             self.use_marlin = False

+        self.block_quant = self.quant_config.weight_block_size is not None
+        if self.block_quant:
+            # Marlin doesn't support block-wise fp8
+            self.use_marlin = False
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -153,10 +180,35 @@ class Fp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
         weight_loader = extra_weight_attrs.get("weight_loader")

+        tp_size = get_tensor_model_parallel_world_size()
+        if self.block_quant:
+            block_n, block_k = (
+                self.quant_config.weight_block_size[0],
+                self.quant_config.weight_block_size[1],
+            )
+            # Required by row parallel
+            if tp_size > 1 and input_size // input_size_per_partition == tp_size:
+                if input_size_per_partition % block_k != 0:
+                    raise ValueError(
+                        f"Weight input_size_per_partition = "
+                        f"{input_size_per_partition} is not divisible by "
+                        f"weight quantization block_k = {block_k}."
+                    )
+            # Required by collum parallel or enabling merged weights
+            if (
+                tp_size > 1 and output_size // output_size_per_partition == tp_size
+            ) or len(output_partition_sizes) > 1:
+                for output_partition_size in output_partition_sizes:
+                    if output_partition_size % block_n != 0:
+                        raise ValueError(
+                            f"Weight output_partition_size = "
+                            f"{output_partition_size} is not divisible by "
+                            f"weight quantization block_n = {block_n}."
+                        )
+
         layer.logical_widths = output_partition_sizes

         layer.input_size_per_partition = input_size_per_partition
@@ -184,13 +236,27 @@ class Fp8LinearMethod(LinearMethodBase):
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            scale = PerTensorScaleParameter(
-                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
-                weight_loader=weight_loader,
-            )
-
-            scale[:] = torch.finfo(torch.float32).min
-            layer.register_parameter("weight_scale", scale)
+            if self.block_quant:
+                assert self.quant_config.activation_scheme == "dynamic"
+                scale = BlockQuantScaleParameter(
+                    data=torch.empty(
+                        (output_size_per_partition + block_n - 1) // block_n,
+                        (input_size_per_partition + block_k - 1) // block_k,
+                        dtype=torch.float32,
+                    ),
+                    input_dim=1,
+                    output_dim=0,
+                    weight_loader=weight_loader,
+                )
+                scale[:] = torch.finfo(torch.float32).min
+                layer.register_parameter("weight_scale_inv", scale)
+            else:
+                scale = PerTensorScaleParameter(
+                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                    weight_loader=weight_loader,
+                )
+                scale[:] = torch.finfo(torch.float32).min
+                layer.register_parameter("weight_scale", scale)

             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
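
The ceil-divisions above size one fp32 scale per (block_n x block_k) tile of the partitioned weight. A quick arithmetic sketch with hypothetical partition sizes and an assumed [128, 128] block size:

# Illustrative only: the shape of the weight_scale_inv parameter created above.
block_n, block_k = 128, 128                  # assumed weight_block_size
output_size_per_partition = 4608             # hypothetical sharded output dim
input_size_per_partition = 7168              # hypothetical sharded input dim

scale_rows = (output_size_per_partition + block_n - 1) // block_n  # ceil(4608 / 128) = 36
scale_cols = (input_size_per_partition + block_k - 1) // block_k   # ceil(7168 / 128) = 56
print(scale_rows, scale_cols)  # 36 56 -> one scale per 128x128 weight block
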
@@ -205,6 +271,9 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.register_parameter("input_scale", None)

     def process_weights_after_loading(self, layer: Module) -> None:
+        # Block quant doesn't need to process weights after loading
+        if self.block_quant:
+            return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -295,6 +364,16 @@ class Fp8LinearMethod(LinearMethodBase):
                 bias=bias,
             )

+        if self.block_quant:
+            return apply_w8a8_block_fp8_linear(
+                input=x,
+                weight=layer.weight,
+                block_size=self.quant_config.weight_block_size,
+                weight_scale=layer.weight_scale_inv,
+                input_scale=layer.input_scale,
+                bias=bias,
+            )
+
         return apply_fp8_linear(
             input=x,
             weight=layer.weight,
@@ -306,7 +385,7 @@ class Fp8LinearMethod(LinearMethodBase):
         )


-class Fp8MoEMethod(FusedMoEMethodBase):
+class Fp8MoEMethod:
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -319,8 +398,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         quant_config: The quantization config.
     """

-    def __init__(self, quant_config: Fp8Config):
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
         self.quant_config = quant_config
+        self.block_quant = self.quant_config.weight_block_size is not None

     def create_weights(
         self,
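
The __new__ hook added above defers the FusedMoEMethodBase import to instantiation time and rebases Fp8MoEMethod onto it with type(), which sidesteps a circular import at module load. A stripped-down sketch of the same pattern with stand-in names (LateBase and Method are made up for illustration):

class LateBase:
    # Stands in for a base class that would otherwise cause a circular import.
    def hello(self):
        return "from LateBase"


class Method:
    def __new__(cls, *args, **kwargs):
        # The real code imports FusedMoEMethodBase here instead of at module top.
        if not hasattr(cls, "_initialized"):
            new_cls = type(
                cls.__name__,
                (LateBase,),
                {"__init__": cls.__init__,
                 **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}},
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)  # called manually; type() changed the class
            return obj
        return super().__new__(cls)

    def __init__(self, tag):
        self.tag = tag


m = Method("fp8")
print(isinstance(m, LateBase), m.hello(), m.tag)  # True from LateBase fp8
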
@@ -331,9 +429,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
+        tp_size = get_tensor_model_parallel_world_size()
+        if self.block_quant:
+            block_n, block_k = (
+                self.quant_config.weight_block_size[0],
+                self.quant_config.weight_block_size[1],
+            )
+            # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+            # Required by collum parallel or enabling merged weights
+            if intermediate_size % block_n != 0:
+                raise ValueError(
+                    f"The output_size of gate's and up's weight = "
+                    f"{intermediate_size} is not divisible by "
+                    f"weight quantization block_n = {block_n}."
+                )
+            if tp_size > 1:
+                # Required by row parallel
+                if intermediate_size % block_k != 0:
+                    raise ValueError(
+                        f"The input_size of down's weight = "
+                        f"{intermediate_size} is not divisible by "
+                        f"weight quantization block_k = {block_k}."
+                    )

         # WEIGHTS
         w13_weight = torch.nn.Parameter(
@@ -355,21 +476,45 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight, extra_weight_attrs)

         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-
+        if self.block_quant:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    2 * ((intermediate_size + block_n - 1) // block_n),
+                    (hidden_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    (hidden_size + block_n - 1) // block_n,
+                    (intermediate_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+            assert self.quant_config.activation_scheme == "dynamic"
+        else:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
         extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+            if self.block_quant
+            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
         )
         # If loading fp8 checkpoint, pass the weight loaders.
         # If loading an fp16 checkpoint, do not (we will quantize in
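
As a shape check on the block-wise MoE scales allocated above (gate and up projections are stacked in w13, hence the factor of 2), a sketch with hypothetical sizes and an assumed [128, 128] block size:

# Illustrative only; the sizes are hypothetical, block size assumed to be [128, 128].
num_experts, hidden_size, intermediate_size = 64, 7168, 2048
block_n = block_k = 128

w13_scale_shape = (
    num_experts,
    2 * ((intermediate_size + block_n - 1) // block_n),  # gate + up: 2 * 16 = 32
    (hidden_size + block_k - 1) // block_k,              # 56
)
w2_scale_shape = (
    num_experts,
    (hidden_size + block_n - 1) // block_n,              # 56
    (intermediate_size + block_k - 1) // block_k,        # 16
)
print(w13_scale_shape, w2_scale_shape)  # (64, 32, 56) (64, 56, 16)
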
@@ -403,8 +548,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-
-        # If checkpoint is fp16, quantize in place.
+        # Block quant doesn't need to process weights after loading
+        if self.block_quant:
+            return
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
             fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
@@ -428,6 +575,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return

         # If checkpoint is fp8, we need to handle that the
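
The ROCm branch above right-pads the last dimension of both expert weights by padding_size elements when the MOE_PADDING environment variable is set. A minimal sketch of the F.pad call's effect on shapes; the padding_size value here is assumed, the real constant comes from fused_moe.py:

import torch
import torch.nn.functional as F

padding_size = 128                       # assumed for illustration
w = torch.randn(8, 512, 256)             # hypothetical (num_experts, out_dim, in_dim)
w_padded = F.pad(w, (0, padding_size), "constant", 0)  # pads only the last dim, on the right
print(tuple(w.shape), "->", tuple(w_padded.shape))     # (8, 512, 256) -> (8, 512, 384)
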
@@ -456,6 +616,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.w2_input_scale = torch.nn.Parameter(
                     layer.w2_input_scale.max(), requires_grad=False
                 )
+
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if is_hip():
                 # Normalize the weights and scales
@@ -486,7 +647,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     layer.w2_input_scale = torch.nn.Parameter(
                         w2_input_scale, requires_grad=False
                     )
-
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
             assert layer.w13_weight_scale is not None
@@ -507,6 +667,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight_scale = torch.nn.Parameter(
                 max_w13_scales, requires_grad=False
             )
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return

     def apply(
@@ -520,11 +693,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts

-        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
-
-        topk_weights, topk_ids = FusedMoE.select_experts(
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             use_grouped_topk=use_grouped_topk,
@@ -533,8 +709,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
         )

+        # Expert fusion with FP8 quantization
         return fused_experts(
             x,
             layer.w13_weight,
@@ -543,10 +721,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             topk_ids=topk_ids,
             inplace=True,
             use_fp8_w8a8=True,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
+            w1_scale=(
+                layer.w13_weight_scale_inv
+                if self.block_quant
+                else layer.w13_weight_scale
+            ),
+            w2_scale=(
+                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+            ),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
+            block_shape=self.quant_config.weight_block_size,
         )
