sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +23 -3
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +5 -16
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +218 -79
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/topk.py +30 -3
- sglang/srt/layers/quantization/__init__.py +134 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +12 -0
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/schedule_batch.py +1 -0
- sglang/srt/managers/scheduler.py +25 -19
- sglang/srt/managers/tokenizer_manager.py +0 -1
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -8
- sglang/srt/model_executor/model_runner.py +9 -6
- sglang/srt/model_loader/loader.py +11 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +151 -26
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +6 -0
- sglang/srt/openai_api/adapter.py +88 -87
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/server_args.py +21 -11
- sglang/srt/speculative/eagle_worker.py +1 -1
- sglang/srt/utils.py +33 -0
- sglang/test/runners.py +27 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +8 -4
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +57 -53
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/dp_attention.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import functools
 import logging
 from contextlib import contextmanager
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, List
 
 import torch
 import triton
@@ -249,3 +249,14 @@ def dp_scatter(
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )
+
+
+def tp_reduce_scatter(
+    output: torch.Tensor,
+    input_list: List[torch.Tensor],
+):
+    return get_attention_tp_group().reduce_scatter(output, input_list)
+
+
+def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
+    return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
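The new tp_reduce_scatter and tp_all_gather helpers are thin wrappers over the attention tensor-parallel group's collectives. A minimal single-process sketch of what those two collectives compute (illustrative only, not the sglang API):

# Simulate reduce-scatter / all-gather across `world_size` ranks in one process.
import torch

world_size, hidden = 4, 8

# reduce-scatter: every rank contributes a full tensor; the sum is split so
# each rank keeps one contiguous shard of it.
inputs = [torch.randn(world_size, hidden) for _ in range(world_size)]
summed = torch.stack(inputs).sum(dim=0)
shards = [summed[r] for r in range(world_size)]  # what each rank would hold

# all-gather: the inverse direction; each rank contributes its shard and
# receives the concatenation of all shards.
gathered = torch.cat([s.unsqueeze(0) for s in shards])
assert torch.allclose(gathered, summed)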
sglang/srt/layers/moe/topk.py
CHANGED
@@ -17,12 +17,12 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F
 
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 
 expert_distribution_recorder = ExpertDistributionRecorder()
 
@@ -129,8 +129,7 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
-
-def biased_grouped_topk(
+def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -171,6 +170,34 @@ def biased_grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+def biased_grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+):
+    biased_grouped_topk_fn = (
+        torch.compile(
+            biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+        )
+        if compiled
+        else biased_grouped_topk_impl
+    )
+    return biased_grouped_topk_fn(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+    )
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
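biased_grouped_topk is now a thin dispatcher that runs biased_grouped_topk_impl either eagerly or through torch.compile. A minimal sketch of that compile-or-eager dispatch pattern, with a toy top-k function standing in for the real implementation:

import torch


def toy_impl(x: torch.Tensor, k: int):
    # Stand-in for biased_grouped_topk_impl.
    return torch.topk(x, k)


def toy_topk(x: torch.Tensor, k: int, compiled: bool = True):
    # Choose the compiled or eager callable, then forward the arguments.
    fn = torch.compile(toy_impl, dynamic=True) if compiled else toy_impl
    return fn(x, k)


values, indices = toy_topk(torch.randn(4, 16), k=2, compiled=False)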
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -9,13 +9,24 @@ import torch
 
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.
-
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinConfig,
+        AWQMoEMethod,
+    )
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+        CompressedTensorsW8A8Fp8MoEMethod,
+        CompressedTensorsWNA16MoEMethod,
+    )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
     from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
     )
@@ -23,33 +34,39 @@ try:
     from vllm.model_executor.layers.quantization.qqq import QQQConfig
     from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
-    from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-
     VLLM_AVAILABLE = True
 except ImportError:
     VLLM_AVAILABLE = False
 
     # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
-
+        def override_quantization_method(self, *args, **kwargs):
+            return None
+
+    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
+        DeepSpeedFPConfig
+    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
+        MarlinConfig
+    ) = QQQConfig = Int8TpuConfig = DummyConfig
 
-    AQLMConfig = AWQConfig = AWQMarlinConfig = BitsAndBytesConfig = (
-        CompressedTensorsConfig
-    ) = DummyConfig
-    DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
-        GPTQMarlin24Config
-    ) = DummyConfig
-    MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
 
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    UnquantizedEmbeddingMethod,
+)
 
 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -61,26 +78,25 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "compressed-tensors": CompressedTensorsConfig,
 }
 
-#
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
+# VLLM-dependent quantization methods
+VLLM_QUANTIZATION_METHODS = {
+    "aqlm": AQLMConfig,
+    "awq": AWQConfig,
+    "deepspeedfp": DeepSpeedFPConfig,
+    "tpu_int8": Int8TpuConfig,
+    "fbgemm_fp8": FBGEMMFp8Config,
+    "marlin": MarlinConfig,
+    "gguf": GGUFConfig,
+    "gptq_marlin_24": GPTQMarlin24Config,
+    "awq_marlin": AWQMarlinConfig,
+    "bitsandbytes": BitsAndBytesConfig,
+    "qqq": QQQConfig,
+    "experts_int8": ExpertsInt8Config,
+    "gptq_marlin": GPTQMarlinConfig,
+    "gptq": GPTQConfig,
+}
+
+QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
 
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
@@ -89,6 +105,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
             f"Invalid quantization method: {quantization}. "
             f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
         )
+    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
+        raise ValueError(
+            f"{quantization} quantization requires some operators from vllm. "
+            "Pleaes install vllm by `pip install vllm==0.7.2`"
+        )
+
     return QUANTIZATION_METHODS[quantization]
 
 
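The registry is now split into base methods and vllm-backed methods merged into one dict, and a lookup of a vllm-backed method fails early when vllm is not installed. A self-contained sketch of that pattern (placeholder names, not the real config classes):

BASE_METHODS = {"fp8": "Fp8Config", "w8a8_int8": "W8A8Int8Config"}
VLLM_METHODS = {"awq": "AWQConfig", "gptq": "GPTQConfig"}
VLLM_AVAILABLE = False

METHODS = {**BASE_METHODS, **VLLM_METHODS}


def get_method(name: str) -> str:
    if name not in METHODS:
        raise ValueError(f"Invalid quantization method: {name}")
    # Reject methods that need vllm when it is missing.
    if name in VLLM_METHODS and not VLLM_AVAILABLE:
        raise ValueError(f"{name} requires vllm; install it first")
    return METHODS[name]


print(get_method("fp8"))  # works without vllm
# get_method("awq")       # would raise here because vllm is unavailable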
@@ -153,13 +175,6 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
-
-    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        ParallelLMHead,
-        UnquantizedEmbeddingMethod,
-    )
-
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
@@ -186,31 +201,17 @@
 
 
 def gptq_get_quant_method(self, layer, prefix):
-    if
-        return
+    if isinstance(layer, FusedMoE):
+        return GPTQMarlinMoEMethod(self)
 
-
-
-
-
-
+    if isinstance(self, GPTQConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+        )
+    elif isinstance(self, GPTQMarlinConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
         )
-
-        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
-        if isinstance(layer, FusedMoE):
-            return GPTQMarlinMoEMethod(self)
-
-        if isinstance(self, GPTQConfig):
-            return get_linear_quant_method(
-                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-            )
-        elif isinstance(self, GPTQMarlinConfig):
-            return get_linear_quant_method(
-                self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-            )
-    except ImportError:
-        pass
     return None
 
 
@@ -229,33 +230,28 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
         builtins.isinstance = original_isinstance
         return
 
-
-
-
-
-
-        )
+    from vllm.model_executor.layers.fused_moe import FusedMoE
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding,
+    )
 
-
-
-
-
-
-            VocabParallelEmbedding as PatchedVocabParallelEmbedding,
-        )
+    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
+    )
 
-
-
-
-
-
-
-
-
-
-
-    except ImportError:
-        return
+    def patched_isinstance(obj, classinfo):
+        if classinfo is LinearBase:
+            return original_isinstance(obj, PatchedLinearBase)
+        if classinfo is FusedMoE:
+            return original_isinstance(obj, PatchedFusedMoE)
+        if classinfo is VocabParallelEmbedding:
+            return original_isinstance(obj, PatchedVocabParallelEmbedding)
+        return original_isinstance(obj, classinfo)
+
+    builtins.isinstance = patched_isinstance
 
 
 def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
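The rewritten monkey_patch_isinstance_for_vllm_base_layer swaps builtins.isinstance so that vllm-internal isinstance checks against vllm layer classes also match the sglang replacements. A generic sketch of that isinstance-redirection trick with toy classes (not the vllm/sglang layer types):

import builtins


class VendorLinear: ...
class PatchedLinear: ...


original_isinstance = builtins.isinstance


def patched_isinstance(obj, classinfo):
    # Treat checks against VendorLinear as checks against PatchedLinear.
    if classinfo is VendorLinear:
        return original_isinstance(obj, PatchedLinear)
    return original_isinstance(obj, classinfo)


builtins.isinstance = patched_isinstance
assert isinstance(PatchedLinear(), VendorLinear)
builtins.isinstance = original_isinstance  # restore, mirroring reverse=True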
@@ -263,91 +259,64 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
     Monkey patch the apply function of vllm's FusedMoEMethodBase.
     Convert sglang arguments to vllm arguments.
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            kwargs["e_score_correction_bias"] = correction_bias
-            return original_apply(**kwargs)
-
-        setattr(class_obj, "apply", new_apply)
-    except (ImportError, AttributeError):
-        return
+    original_apply = class_obj.apply
+    sig = inspect.signature(original_apply)
+    param_names = list(sig.parameters.keys())
+    has_correction_bias = "e_score_correction_bias" in param_names
+
+    def new_apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ):
+        assert activation == "silu"
+        assert inplace and not no_combine
+
+        kwargs = {
+            "self": self,
+            "layer": layer,
+            "x": x,
+            "router_logits": router_logits,
+            "top_k": top_k,
+            "renormalize": renormalize,
+            "use_grouped_topk": use_grouped_topk,
+            "topk_group": topk_group,
+            "num_expert_group": num_expert_group,
+            "custom_routing_function": custom_routing_function,
+        }
+        if correction_bias is not None:
+            if not has_correction_bias:
+                raise ValueError(
+                    "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
+                )
+            kwargs["e_score_correction_bias"] = correction_bias
+        return original_apply(**kwargs)
+
+    setattr(class_obj, "apply", new_apply)
 
 
 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-
-
+    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
 
-
-
-
-
-            CompressedTensorsWNA16MoEMethod,
-        )
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinMoEMethod,
-        )
-
-        setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-        setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
-
-        monkey_patch_moe_apply(AWQMoEMethod)
-        monkey_patch_moe_apply(GPTQMarlinMoEMethod)
-        monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
-        monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
-    except ImportError:
-        return
+    monkey_patch_moe_apply(AWQMoEMethod)
+    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
 
 
 # Only apply monkey patches if vllm is available
 if VLLM_AVAILABLE:
     monkey_patch_quant_configs()
-
-
-__all__ = [
-    "get_quantization_config",
-    "QUANTIZATION_METHODS",
-]
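The new new_apply wrapper inspects the signature of vllm's apply and only forwards e_score_correction_bias when the installed version accepts it. A small sketch of that inspect.signature guard (toy function, not the vllm API):

import inspect


def vendor_apply(x, top_k):  # stands in for an older backend apply()
    return (x, top_k)


has_bias_kwarg = "e_score_correction_bias" in inspect.signature(vendor_apply).parameters


def wrapped(x, top_k, correction_bias=None):
    kwargs = {"x": x, "top_k": top_k}
    if correction_bias is not None:
        if not has_bias_kwarg:
            raise ValueError("installed backend does not accept e_score_correction_bias")
        kwargs["e_score_correction_bias"] = correction_bias
    return vendor_apply(**kwargs)


print(wrapped([1.0, 2.0], top_k=1))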
sglang/srt/layers/quantization/awq.py
ADDED
@@ -0,0 +1,200 @@
+# SPDX-License-Identifier: Apache-2.0
+import logging
+from typing import Any, Dict, List, Optional
+
+import torch
+from sgl_kernel import awq_dequantize
+
+from sglang.srt.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
+from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
+    return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
+class AWQConfig(QuantizationConfig):
+    """Config class for AWQ.
+
+    Reference: https://arxiv.org/abs/2306.00978
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        zero_point: bool,
+        modules_to_not_convert: Optional[List[str]] = None,
+    ) -> None:
+        super().__init__()
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.zero_point = zero_point
+        self.modules_to_not_convert = modules_to_not_convert or []
+
+        if self.weight_bits != 4:
+            raise ValueError(
+                "Currently, only 4-bit weight quantization is supported for "
+                f"AWQ, but got {self.weight_bits} bits."
+            )
+        self.pack_factor = 32 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return (
+            f"AWQConfig(weight_bits={self.weight_bits}, "
+            f"group_size={self.group_size}, "
+            f"zero_point={self.zero_point}, "
+            f"modules_to_not_convert={self.modules_to_not_convert})"
+        )
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    def get_name(self) -> str:
+        return "awq"
+
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # The AWQ kernel only supports Turing or newer GPUs.
+        return 75
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return [
+            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
+            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
+            "quantize_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
+        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None
+        )
+        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["LinearMethodBase"]:
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
+            return AWQLinearMethod(self)
+        return None
+
+
+class AWQLinearMethod(LinearMethodBase):
+    """Linear method for AWQ.
+
+    Args:
+        quant_config: The AWQ quantization config.
+    """
+
+    def __init__(self, quant_config: AWQConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        if input_size_per_partition % self.quant_config.group_size != 0:
+            raise ValueError(
+                "The input size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The output size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        qzeros = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        scales = GroupQuantScaleParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            input_dim=0,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("qweight", qweight)
+        layer.register_parameter("qzeros", qzeros)
+        layer.register_parameter("scales", scales)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qweight = layer.qweight
+        scales = layer.scales
+        qzeros = layer.qzeros
+        pack_factor = self.quant_config.pack_factor
+        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        out = awq_dequantize(qweight, scales, qzeros)
+        out = torch.matmul(reshaped_x, out)
+
+        if bias is not None:
+            out.add_(bias)
+        return out.reshape(out_shape)
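AWQLinearMethod stores 4-bit weights packed eight to an int32 (pack_factor = 32 // 4) and computes dequantize-then-matmul, with the dequantization supplied by sgl_kernel.awq_dequantize. A rough shape-only sketch in plain torch, using zero stand-ins instead of real quantized data:

import torch

weight_bits = 4
pack_factor = 32 // weight_bits  # 8 four-bit values per int32
in_features, out_features, group_size = 128, 256, 32

# Packed parameter shapes mirroring create_weights()
qweight = torch.zeros(in_features, out_features // pack_factor, dtype=torch.int32)
qzeros = torch.zeros(in_features // group_size, out_features // pack_factor, dtype=torch.int32)
scales = torch.ones(in_features // group_size, out_features)

x = torch.randn(2, 7, in_features)
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)

# A dequantized weight has shape (in_features, out_features); the zero
# stand-in keeps the matmul shape arithmetic honest without sgl_kernel.
dequant = torch.zeros(in_features, out_features)
out = torch.matmul(x.reshape(-1, in_features), dequant).reshape(out_shape)
assert out.shape == (2, 7, out_features)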