sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. sglang/bench_serving.py +23 -3
  2. sglang/srt/configs/deepseekvl2.py +10 -1
  3. sglang/srt/configs/model_config.py +5 -16
  4. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  5. sglang/srt/distributed/parallel_state.py +32 -5
  6. sglang/srt/entrypoints/http_server.py +7 -1
  7. sglang/srt/entrypoints/verl_engine.py +2 -0
  8. sglang/srt/function_call_parser.py +0 -1
  9. sglang/srt/layers/attention/flashattention_backend.py +218 -79
  10. sglang/srt/layers/dp_attention.py +12 -1
  11. sglang/srt/layers/moe/topk.py +30 -3
  12. sglang/srt/layers/quantization/__init__.py +134 -165
  13. sglang/srt/layers/quantization/awq.py +200 -0
  14. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  15. sglang/srt/layers/quantization/gptq.py +30 -40
  16. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  17. sglang/srt/layers/rotary_embedding.py +12 -0
  18. sglang/srt/lora/backend/base_backend.py +4 -4
  19. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  20. sglang/srt/lora/backend/triton_backend.py +5 -8
  21. sglang/srt/lora/layers.py +19 -33
  22. sglang/srt/lora/lora_manager.py +20 -7
  23. sglang/srt/lora/mem_pool.py +12 -6
  24. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  25. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  26. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  27. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  28. sglang/srt/lora/utils.py +6 -0
  29. sglang/srt/managers/io_struct.py +4 -2
  30. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  31. sglang/srt/managers/schedule_batch.py +1 -0
  32. sglang/srt/managers/scheduler.py +25 -19
  33. sglang/srt/managers/tokenizer_manager.py +0 -1
  34. sglang/srt/managers/tp_worker.py +3 -0
  35. sglang/srt/model_executor/cuda_graph_runner.py +9 -8
  36. sglang/srt/model_executor/model_runner.py +9 -6
  37. sglang/srt/model_loader/loader.py +11 -1
  38. sglang/srt/model_loader/weight_utils.py +6 -3
  39. sglang/srt/models/clip.py +563 -0
  40. sglang/srt/models/deepseek_janus_pro.py +2 -2
  41. sglang/srt/models/deepseek_v2.py +151 -26
  42. sglang/srt/models/gemma3_causal.py +12 -2
  43. sglang/srt/models/gemma3_mm.py +6 -0
  44. sglang/srt/openai_api/adapter.py +88 -87
  45. sglang/srt/openai_api/protocol.py +10 -5
  46. sglang/srt/patch_torch.py +71 -0
  47. sglang/srt/server_args.py +21 -11
  48. sglang/srt/speculative/eagle_worker.py +1 -1
  49. sglang/srt/utils.py +33 -0
  50. sglang/test/runners.py +27 -2
  51. sglang/test/test_utils.py +1 -1
  52. sglang/version.py +1 -1
  53. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +8 -4
  54. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +57 -53
  55. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +0 -0
  56. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/licenses/LICENSE +0 -0
  57. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/dp_attention.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import functools
 import logging
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, List
 
 import torch
 import triton
@@ -249,3 +249,14 @@ def dp_scatter(
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )
+
+
+def tp_reduce_scatter(
+    output: torch.Tensor,
+    input_list: List[torch.Tensor],
+):
+    return get_attention_tp_group().reduce_scatter(output, input_list)
+
+
+def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
+    return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
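The new tp_reduce_scatter / tp_all_gather helpers simply forward to the attention tensor-parallel group's collectives. The standalone sketch below only mirrors the all_gather semantics with a one-process torch.distributed group (gloo backend, world size 1); it does not touch sglang's attention TP group, which is normally initialized by the model runner.

# Minimal, self-contained illustration of the collective that tp_all_gather wraps.
# Assumes nothing from sglang; uses a single-process gloo group purely for demonstration.
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

input_ = torch.arange(4, dtype=torch.float32)
output_list = [torch.empty_like(input_) for _ in range(dist.get_world_size())]

# tp_all_gather(output_list, input_) delegates to an all_gather like this one,
# just over the attention TP group instead of the default process group.
dist.all_gather(output_list, input_)
print(output_list)

dist.destroy_process_group()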
sglang/srt/layers/moe/topk.py

@@ -17,12 +17,12 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F
 
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 
 expert_distribution_recorder = ExpertDistributionRecorder()
 
@@ -129,8 +129,7 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def biased_grouped_topk(
+def biased_grouped_topk_impl(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
@@ -171,6 +170,34 @@ def biased_grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+def biased_grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+):
+    biased_grouped_topk_fn = (
+        torch.compile(
+            biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+        )
+        if compiled
+        else biased_grouped_topk_impl
+    )
+    return biased_grouped_topk_fn(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+    )
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
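For reference, a minimal sketch of calling the refactored wrapper with compilation disabled; the token count, hidden size, and DeepSeek-style expert layout (256 experts, 8 groups) are illustrative assumptions, not values taken from this diff.

# Hypothetical usage of the new biased_grouped_topk wrapper; compiled=False falls
# back to the eager biased_grouped_topk_impl instead of wrapping it in torch.compile.
import torch

from sglang.srt.layers.moe.topk import biased_grouped_topk

num_tokens, hidden_size, num_experts = 4, 1024, 256
hidden_states = torch.randn(num_tokens, hidden_size)
gating_output = torch.randn(num_tokens, num_experts)
correction_bias = torch.zeros(num_experts)

topk_weights, topk_ids = biased_grouped_topk(
    hidden_states,
    gating_output,
    correction_bias,
    topk=8,
    renormalize=True,
    num_expert_group=8,
    topk_group=4,
    compiled=False,  # skip torch.compile (the default remains True)
)
print(topk_weights.shape, topk_ids.shape)  # torch.Size([4, 8]) torch.Size([4, 8])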
sglang/srt/layers/quantization/__init__.py

@@ -9,13 +9,24 @@ import torch
 
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.awq import AWQConfig
-    from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinConfig,
+        AWQMoEMethod,
+    )
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+        CompressedTensorsW8A8Fp8MoEMethod,
+        CompressedTensorsWNA16MoEMethod,
+    )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
     from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
     )
@@ -23,33 +34,39 @@ try:
     from vllm.model_executor.layers.quantization.qqq import QQQConfig
     from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
-    from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-
     VLLM_AVAILABLE = True
 except ImportError:
     VLLM_AVAILABLE = False
 
     # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
-        pass
+        def override_quantization_method(self, *args, **kwargs):
+            return None
+
+    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
+        DeepSpeedFPConfig
+    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
+        MarlinConfig
+    ) = QQQConfig = Int8TpuConfig = DummyConfig
 
-    AQLMConfig = AWQConfig = AWQMarlinConfig = BitsAndBytesConfig = (
-        CompressedTensorsConfig
-    ) = DummyConfig
-    DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
-        GPTQMarlin24Config
-    ) = DummyConfig
-    MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
 
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    UnquantizedEmbeddingMethod,
+)
 
 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -61,26 +78,25 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "compressed-tensors": CompressedTensorsConfig,
 }
 
-# Add vllm-dependent methods if available
-QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
-if VLLM_AVAILABLE:
-    VLLM_QUANTIZATION_METHODS = {
-        "aqlm": AQLMConfig,
-        "awq": AWQConfig,
-        "deepspeedfp": DeepSpeedFPConfig,
-        "tpu_int8": Int8TpuConfig,
-        "fbgemm_fp8": FBGEMMFp8Config,
-        "marlin": MarlinConfig,
-        "gguf": GGUFConfig,
-        "gptq_marlin_24": GPTQMarlin24Config,
-        "awq_marlin": AWQMarlinConfig,
-        "bitsandbytes": BitsAndBytesConfig,
-        "qqq": QQQConfig,
-        "experts_int8": ExpertsInt8Config,
-        "gptq_marlin": GPTQMarlinConfig,
-        "gptq": GPTQConfig,
-    }
-    QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
+# VLLM-dependent quantization methods
+VLLM_QUANTIZATION_METHODS = {
+    "aqlm": AQLMConfig,
+    "awq": AWQConfig,
+    "deepspeedfp": DeepSpeedFPConfig,
+    "tpu_int8": Int8TpuConfig,
+    "fbgemm_fp8": FBGEMMFp8Config,
+    "marlin": MarlinConfig,
+    "gguf": GGUFConfig,
+    "gptq_marlin_24": GPTQMarlin24Config,
+    "awq_marlin": AWQMarlinConfig,
+    "bitsandbytes": BitsAndBytesConfig,
+    "qqq": QQQConfig,
+    "experts_int8": ExpertsInt8Config,
+    "gptq_marlin": GPTQMarlinConfig,
+    "gptq": GPTQConfig,
+}
+
+QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
 
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
@@ -89,6 +105,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
             f"Invalid quantization method: {quantization}. "
             f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
         )
+    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
+        raise ValueError(
+            f"{quantization} quantization requires some operators from vllm. "
+            "Please install vllm by `pip install vllm==0.7.2`"
+        )
+
     return QUANTIZATION_METHODS[quantization]
 
 
@@ -153,13 +175,6 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
-
-    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        ParallelLMHead,
-        UnquantizedEmbeddingMethod,
-    )
-
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
@@ -186,31 +201,17 @@
 
 
 def gptq_get_quant_method(self, layer, prefix):
-    if not VLLM_AVAILABLE:
-        return None
-
-    try:
-        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinLinearMethod,
-            GPTQMarlinMoEMethod,
-        )
-
-        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
-        if isinstance(layer, FusedMoE):
-            return GPTQMarlinMoEMethod(self)
-
-        if isinstance(self, GPTQConfig):
-            return get_linear_quant_method(
-                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-            )
-        elif isinstance(self, GPTQMarlinConfig):
-            return get_linear_quant_method(
-                self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-            )
-    except ImportError:
-        pass
+    if isinstance(layer, FusedMoE):
+        return GPTQMarlinMoEMethod(self)
+
+    if isinstance(self, GPTQConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+        )
+    elif isinstance(self, GPTQMarlinConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
+        )
     return None
 
 
@@ -229,33 +230,28 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
         builtins.isinstance = original_isinstance
         return
 
-    try:
-        from vllm.model_executor.layers.fused_moe import FusedMoE
-        from vllm.model_executor.layers.linear import LinearBase
-        from vllm.model_executor.layers.vocab_parallel_embedding import (
-            VocabParallelEmbedding,
-        )
+    from vllm.model_executor.layers.fused_moe import FusedMoE
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding,
+    )
 
-        from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-        from sglang.srt.layers.moe.fused_moe_triton.layer import (
-            FusedMoE as PatchedFusedMoE,
-        )
-        from sglang.srt.layers.vocab_parallel_embedding import (
-            VocabParallelEmbedding as PatchedVocabParallelEmbedding,
-        )
+    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
+    )
 
-        def patched_isinstance(obj, classinfo):
-            if classinfo is LinearBase:
-                return original_isinstance(obj, PatchedLinearBase)
-            if classinfo is FusedMoE:
-                return original_isinstance(obj, PatchedFusedMoE)
-            if classinfo is VocabParallelEmbedding:
-                return original_isinstance(obj, PatchedVocabParallelEmbedding)
-            return original_isinstance(obj, classinfo)
-
-        builtins.isinstance = patched_isinstance
-    except ImportError:
-        return
+    def patched_isinstance(obj, classinfo):
+        if classinfo is LinearBase:
+            return original_isinstance(obj, PatchedLinearBase)
+        if classinfo is FusedMoE:
+            return original_isinstance(obj, PatchedFusedMoE)
+        if classinfo is VocabParallelEmbedding:
+            return original_isinstance(obj, PatchedVocabParallelEmbedding)
+        return original_isinstance(obj, classinfo)
+
+    builtins.isinstance = patched_isinstance
 
 
 def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
@@ -263,91 +259,64 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
     Monkey patch the apply function of vllm's FusedMoEMethodBase.
     Convert sglang arguments to vllm arguments.
     """
-    if not VLLM_AVAILABLE:
-        return
-
-    try:
-        original_apply = class_obj.apply
-        sig = inspect.signature(original_apply)
-        param_names = list(sig.parameters.keys())
-        has_correction_bias = "e_score_correction_bias" in param_names
-
-        def new_apply(
-            self,
-            layer: torch.nn.Module,
-            x: torch.Tensor,
-            router_logits: torch.Tensor,
-            top_k: int,
-            renormalize: bool,
-            use_grouped_topk: bool,
-            topk_group: Optional[int] = None,
-            num_expert_group: Optional[int] = None,
-            custom_routing_function: Optional[Callable] = None,
-            correction_bias: Optional[torch.Tensor] = None,
-            activation: str = "silu",
-            inplace: bool = True,
-            no_combine: bool = False,
-        ):
-            assert activation == "silu"
-            assert inplace and not no_combine
-
-            kwargs = {
-                "self": self,
-                "layer": layer,
-                "x": x,
-                "router_logits": router_logits,
-                "top_k": top_k,
-                "renormalize": renormalize,
-                "use_grouped_topk": use_grouped_topk,
-                "topk_group": topk_group,
-                "num_expert_group": num_expert_group,
-                "custom_routing_function": custom_routing_function,
-            }
-            if correction_bias is not None:
-                if not has_correction_bias:
-                    raise ValueError(
-                        "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
-                    )
-                kwargs["e_score_correction_bias"] = correction_bias
-            return original_apply(**kwargs)
-
-        setattr(class_obj, "apply", new_apply)
-    except (ImportError, AttributeError):
-        return
+    original_apply = class_obj.apply
+    sig = inspect.signature(original_apply)
+    param_names = list(sig.parameters.keys())
+    has_correction_bias = "e_score_correction_bias" in param_names
+
+    def new_apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ):
+        assert activation == "silu"
+        assert inplace and not no_combine
+
+        kwargs = {
+            "self": self,
+            "layer": layer,
+            "x": x,
+            "router_logits": router_logits,
+            "top_k": top_k,
+            "renormalize": renormalize,
+            "use_grouped_topk": use_grouped_topk,
+            "topk_group": topk_group,
+            "num_expert_group": num_expert_group,
+            "custom_routing_function": custom_routing_function,
+        }
+        if correction_bias is not None:
+            if not has_correction_bias:
+                raise ValueError(
+                    "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
+                )
+            kwargs["e_score_correction_bias"] = correction_bias
+        return original_apply(**kwargs)
+
+    setattr(class_obj, "apply", new_apply)
 
 
 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    if not VLLM_AVAILABLE:
-        return
+    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
 
-    try:
-        from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
-        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
-            CompressedTensorsW8A8Fp8MoEMethod,
-            CompressedTensorsWNA16MoEMethod,
-        )
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinMoEMethod,
-        )
-
-        setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-        setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
-
-        monkey_patch_moe_apply(AWQMoEMethod)
-        monkey_patch_moe_apply(GPTQMarlinMoEMethod)
-        monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
-        monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
-    except ImportError:
-        return
+    monkey_patch_moe_apply(AWQMoEMethod)
+    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
 
 
 # Only apply monkey patches if vllm is available
 if VLLM_AVAILABLE:
     monkey_patch_quant_configs()
-
-
-__all__ = [
-    "get_quantization_config",
-    "QUANTIZATION_METHODS",
-]
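One behavioral consequence worth noting: the method registry is now built unconditionally, so requesting a vllm-backed method on an installation without vllm raises the new, explicit ValueError instead of an unknown-method error. A small hedged sketch follows (the method name is only an example):

# "gptq_marlin" is registered in VLLM_QUANTIZATION_METHODS, so this either
# returns the config class (vllm installed) or raises the new
# "requires some operators from vllm" error (vllm missing).
from sglang.srt.layers.quantization import get_quantization_config

try:
    config_cls = get_quantization_config("gptq_marlin")
    print(config_cls.__name__)  # GPTQMarlinConfig
except ValueError as err:
    print(err)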
sglang/srt/layers/quantization/awq.py (new file)

@@ -0,0 +1,200 @@
+# SPDX-License-Identifier: Apache-2.0
+import logging
+from typing import Any, Dict, List, Optional
+
+import torch
+from sgl_kernel import awq_dequantize
+
+from sglang.srt.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
+from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
+    return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
+class AWQConfig(QuantizationConfig):
+    """Config class for AWQ.
+
+    Reference: https://arxiv.org/abs/2306.00978
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        zero_point: bool,
+        modules_to_not_convert: Optional[List[str]] = None,
+    ) -> None:
+        super().__init__()
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.zero_point = zero_point
+        self.modules_to_not_convert = modules_to_not_convert or []
+
+        if self.weight_bits != 4:
+            raise ValueError(
+                "Currently, only 4-bit weight quantization is supported for "
+                f"AWQ, but got {self.weight_bits} bits."
+            )
+        self.pack_factor = 32 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return (
+            f"AWQConfig(weight_bits={self.weight_bits}, "
+            f"group_size={self.group_size}, "
+            f"zero_point={self.zero_point}, "
+            f"modules_to_not_convert={self.modules_to_not_convert})"
+        )
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    def get_name(self) -> str:
+        return "awq"
+
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # The AWQ kernel only supports Turing or newer GPUs.
+        return 75
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return [
+            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
+            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
+            "quantize_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
+        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None
+        )
+        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["LinearMethodBase"]:
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
+            return AWQLinearMethod(self)
+        return None
+
+
+class AWQLinearMethod(LinearMethodBase):
+    """Linear method for AWQ.
+
+    Args:
+        quant_config: The AWQ quantization config.
+    """
+
+    def __init__(self, quant_config: AWQConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        if input_size_per_partition % self.quant_config.group_size != 0:
+            raise ValueError(
+                "The input size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The output size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        qzeros = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        scales = GroupQuantScaleParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            input_dim=0,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("qweight", qweight)
+        layer.register_parameter("qzeros", qzeros)
+        layer.register_parameter("scales", scales)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qweight = layer.qweight
+        scales = layer.scales
+        qzeros = layer.qzeros
+        pack_factor = self.quant_config.pack_factor
+        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        out = awq_dequantize(qweight, scales, qzeros)
+        out = torch.matmul(reshaped_x, out)
+
+        if bias is not None:
+            out.add_(bias)
+        return out.reshape(out_shape)
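To close, a hedged sketch of constructing the new sglang-native AWQConfig from an AutoAWQ-style config dict; the field values below are illustrative and not taken from any particular checkpoint.

# Hypothetical quantize_config.json contents for a 4-bit AWQ checkpoint.
from sglang.srt.layers.quantization.awq import AWQConfig

quant_config = {
    "w_bit": 4,  # from_config also accepts "bits"
    "q_group_size": 128,  # from_config also accepts "group_size"
    "zero_point": True,
    "modules_to_not_convert": ["visual"],  # prefixes left unquantized
}

awq_config = AWQConfig.from_config(quant_config)
print(awq_config)
# AWQConfig(weight_bits=4, group_size=128, zero_point=True,
#           modules_to_not_convert=['visual'])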