sglang 0.4.0__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +20 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/managers/schedule_batch.py +1 -0
- sglang/srt/managers/scheduler.py +69 -65
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +15 -1
- sglang/srt/model_executor/model_runner.py +11 -4
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/grok.py +0 -5
- sglang/srt/models/llama.py +0 -5
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +3 -3
- sglang/srt/server_args.py +43 -4
- sglang/srt/utils.py +50 -0
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/RECORD +43 -38
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py ADDED
@@ -0,0 +1,559 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear,
+    prepare_fp8_layer_for_marlin,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    all_close_1d,
+    apply_fp8_linear,
+    convert_to_channelwise,
+    cutlass_fp8_supported,
+    per_tensor_dequantize,
+    requantize_with_max_scale,
+)
+from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
+
+from sglang.srt.layers.fused_moe_triton import (
+    FusedMoE,
+    FusedMoEMethodBase,
+    FusedMoeWeightScaleSupported,
+)
+from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_hip,
+    print_warning_once,
+    set_weight_attrs,
+)
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+logger = logging.getLogger(__name__)
+
+
+class Fp8Config(QuantizationConfig):
+    """Config class for FP8."""
+
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = False,
+        activation_scheme: str = "dynamic",
+        ignored_layers: Optional[List[str]] = None,
+    ) -> None:
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        if is_checkpoint_fp8_serialized:
+            logger.warning(
+                "Detected fp8 checkpoint. Please note that the "
+                "format is experimental and subject to change."
+            )
+        if activation_scheme not in ACTIVATION_SCHEMES:
+            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
+        self.activation_scheme = activation_scheme
+        self.ignored_layers = ignored_layers or []
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "fp8"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = "fp8" in quant_method
+        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+        return cls(
+            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
+            activation_scheme=activation_scheme,
+            ignored_layers=ignored_layers,
+        )
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return UnquantizedLinearMethod()
+            return Fp8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return Fp8MoEMethod(self)
+        elif isinstance(layer, Attention):
+            return Fp8KVCacheMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class Fp8LinearMethod(LinearMethodBase):
+    """Linear method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Limitations:
+    1. Only support per-tensor quantization due to torch._scaled_mm support.
+    2. Only support float8_e4m3fn data type due to the limitation of
+       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+        # Disable marlin for ROCm
+        if is_hip():
+            self.use_marlin = False
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del input_size, output_size
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        layer.logical_widths = output_partition_sizes
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+
+        # WEIGHT
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # If checkpoint is serialized fp8, load them.
+        # Otherwise, wait until process_weights_after_loading.
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALE
+            scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+
+            scale[:] = torch.finfo(torch.float32).min
+            layer.register_parameter("weight_scale", scale)
+
+            # INPUT ACTIVATION SCALE
+            if self.quant_config.activation_scheme == "static":
+                scale = PerTensorScaleParameter(
+                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                    weight_loader=weight_loader,
+                )
+
+                scale[:] = torch.finfo(torch.float32).min
+                layer.register_parameter("input_scale", scale)
+            else:
+                layer.register_parameter("input_scale", None)
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+        # If checkpoint not serialized fp8, quantize the weights.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
+
+            # If using marlin (w8a16), kernel uses channelwise weights,
+            # so extend the weight scales to be channelwise.
+            if self.use_marlin:
+                assert weight_scale.numel() == 1
+                weight_scale = convert_to_channelwise(
+                    weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
+                )
+
+            # Update the layer with the new values.
+            layer.weight = Parameter(qweight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            layer.input_scale = None
+
+        # If checkpoint is fp8, handle that there are N scales for N
+        # shards in a fused module
+        else:
+            layer.weight_scale = torch.nn.Parameter(
+                layer.weight_scale.data, requires_grad=False
+            )
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = torch.nn.Parameter(
+                    layer.input_scale.data, requires_grad=False
+                )
+            # If using marlin (w8a16), kernel uses channelwise weights,
+            # so extend the weight scales to be channelwise.
+            if self.use_marlin:
+                weight = layer.weight
+                weight_scale = convert_to_channelwise(
+                    layer.weight_scale, layer.logical_widths
+                )
+
+            # If using w8a8, torch._scaled_mm needs per tensor, so
+            # requantize the logical shards as a single weight.
+            else:
+                # Dequant -> Quant with max scale so we can run per tensor.
+                weight = layer.weight
+                weight_scale = layer.weight_scale
+
+                # If ROCm, normalize the weights and scales to e4m3fnuz
+                if is_hip():
+                    weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight,
+                        weight_scale=weight_scale,
+                        input_scale=layer.input_scale,
+                    )
+                    if input_scale is not None:
+                        layer.input_scale = Parameter(input_scale, requires_grad=False)
+
+                weight_scale, weight = requantize_with_max_scale(
+                    weight=weight,
+                    weight_scale=weight_scale,
+                    logical_widths=layer.logical_widths,
+                )
+
+            # Update layer with new values.
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = Parameter(
+                    layer.input_scale.max(), requires_grad=False
+                )
+
+        if self.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        if self.use_marlin:
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias,
+            )
+
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=False,
+        )
+
+
+class Fp8MoEMethod(FusedMoEMethodBase):
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+        # If loading fp8 checkpoint, pass the weight loaders.
+        # If loading an fp16 checkpoint, do not (we will quantize in
+        # process_weights_after_loading()
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.quant_config.activation_scheme == "static":
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    "Found static activation scheme for checkpoint that "
+                    "was not serialized fp8."
+                )
+
+            w13_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+            w2_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+            set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+        else:
+            layer.w13_input_scale = None
+            layer.w2_input_scale = None
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+
+        # If checkpoint is fp16, quantize in place.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
+
+            # Re-initialize w13_scale because we directly quantize
+            # merged w13 weights and generate a single scaling factor.
+            layer.w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    layer.num_experts, dtype=torch.float32, device=w13_weight.device
+                ),
+                requires_grad=False,
+            )
+            for expert in range(layer.num_experts):
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
+            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            return
+
+        # If checkpoint is fp8, we need to handle that the
+        # MoE kernels require single activation scale and single weight
+        # scale for w13 per expert.
+        else:
+            # Fp8 moe kernels require a single activation scale.
+            # We take the max of all the scales in case they differ.
+            if self.quant_config.activation_scheme == "static":
+                if layer.w13_input_scale is None or layer.w2_input_scale is None:
+                    raise ValueError(
+                        "QuantConfig has static quantization, but found "
+                        "activation scales are None."
+                    )
+                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
+                    layer.w2_input_scale
+                ):
+                    print_warning_once(
+                        "Found input_scales that are not equal for "
+                        "fp8 MoE layer. Using the maximum across experts "
+                        "for each layer. "
+                    )
+                layer.w13_input_scale = torch.nn.Parameter(
+                    layer.w13_input_scale.max(), requires_grad=False
+                )
+                layer.w2_input_scale = torch.nn.Parameter(
+                    layer.w2_input_scale.max(), requires_grad=False
+                )
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_weight_scale, w13_input_scale = (
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
+                    )
+                )
+                w2_weight, w2_weight_scale, w2_input_scale = (
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
+                    )
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False
+                    )
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False
+                    )
+
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id],
+                    )
+                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
+                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    )
+                    start += shard_size
+
+            layer.w13_weight_scale = torch.nn.Parameter(
+                max_w13_scales, requires_grad=False
+            )
+            return
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_fp8_w8a8=True,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+        )
+
+
+class Fp8KVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from FP8 checkpoints.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        super().__init__(quant_config)
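For orientation, here is a minimal sketch of how the new Fp8Config is driven by a checkpoint's quantization section (for example the quantization_config block of a Hugging Face config.json). The dictionary values below are illustrative, not taken from this diff:

from sglang.srt.layers.quantization.fp8 import Fp8Config

# Illustrative quantization section; the key names mirror the ones read by
# Fp8Config.from_config above.
hf_quant_config = {
    "quant_method": "fp8",
    "activation_scheme": "static",
    "ignored_layers": ["lm_head"],
}

cfg = Fp8Config.from_config(hf_quant_config)
assert cfg.is_checkpoint_fp8_serialized  # because "fp8" appears in quant_method
assert cfg.activation_scheme == "static"
assert cfg.ignored_layers == ["lm_head"]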
sglang/srt/layers/quantization/fp8_utils.py ADDED
@@ -0,0 +1,27 @@
+from typing import Optional, Tuple
+
+import torch
+
+
+def normalize_e4m3fn_to_e4m3fnuz(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    assert weight.dtype == torch.float8_e4m3fn
+    # The bits pattern 10000000(-128) represents zero in e4m3fn
+    # but NaN in e4m3fnuz. So here we set it to 0.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_as_int8 = weight.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # For the same bits representation, e4m3fnuz value is half of
+    # the e4m3fn value, so we should double the scaling factor to
+    # get the same dequantized value.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
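A minimal round-trip sketch of the conversion above, assuming a PyTorch build that exposes both torch.float8_e4m3fn and torch.float8_e4m3fnuz; the shapes and the 448.0 constant (the largest finite e4m3fn value) are illustrative, not part of the diff:

import torch

from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz

# Per-tensor quantize a toy weight to e4m3fn.
w_hp = torch.randn(4, 8)
weight_scale = w_hp.abs().max() / 448.0
w_fp8 = (w_hp / weight_scale).to(torch.float8_e4m3fn)

w_fnuz, scale_fnuz, _ = normalize_e4m3fn_to_e4m3fnuz(w_fp8, weight_scale)

# The same bit pattern decodes to half the value in e4m3fnuz, and the scale
# was doubled to compensate, so the dequantized tensors should match.
assert torch.allclose(
    w_fp8.to(torch.float32) * weight_scale,
    w_fnuz.to(torch.float32) * scale_fnuz,
)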
sglang/srt/layers/radix_attention.py CHANGED
@@ -48,11 +48,13 @@ class RadixAttention(nn.Module):
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
 
-    def forward(self, q, k, v, forward_batch: ForwardBatch):
+    def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
             k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
-        return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
+        return forward_batch.attn_backend.forward(
+            q, k, v, self, forward_batch, save_kv_cache
+        )
sglang/srt/layers/sampler.py CHANGED
@@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    # int32 range is enough to represent the token ids
+    probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
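A self-contained sketch of the gather-back step this change touches: sampling happens on the sorted distribution, and the sampled positions are mapped back to vocabulary ids through probs_idx; casting probs_idx to int32 makes the returned token ids int32, which is wide enough for any realistic vocabulary. The batch and vocabulary sizes below are illustrative:

import torch

probs = torch.softmax(torch.randn(2, 32000), dim=-1)       # (batch, vocab)
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

# Sample a position in the sorted distribution, then map it back to a token id.
sampled_index = torch.multinomial(probs_sort, num_samples=1)
probs_idx = probs_idx.to(torch.int32)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
assert batch_next_token_ids.dtype == torch.int32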