sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/__init__.py

@@ -6,53 +6,98 @@ from copy import deepcopy
 from typing import Callable, Dict, Optional, Type, Union
 
 import torch
-from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
-from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
-    CompressedTensorsConfig,
-)
-from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
-from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.layers.quantization.qqq import QQQConfig
-from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
+try:
+    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinConfig,
+        AWQMoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+        CompressedTensorsW8A8Fp8MoEMethod,
+        CompressedTensorsWNA16MoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
+    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
+    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
+    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+        GPTQMarlin24Config,
+    )
+    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+    from vllm.model_executor.layers.quantization.qqq import QQQConfig
+    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
+    # Define empty classes as placeholders when vllm is not available
+    class DummyConfig:
+        def override_quantization_method(self, *args, **kwargs):
+            return None
+
+    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
+        DeepSpeedFPConfig
+    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
+        MarlinConfig
+    ) = QQQConfig = Int8TpuConfig = DummyConfig
+
+
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
+    CompressedTensorsConfig,
+)
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    UnquantizedEmbeddingMethod,
+)
 
-QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
+# Base quantization methods that don't depend on vllm
+BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
+    "fp8": Fp8Config,
+    "blockwise_int8": BlockInt8Config,
+    "modelopt": ModelOptFp8Config,
+    "w8a8_int8": W8A8Int8Config,
+    "w8a8_fp8": W8A8Fp8Config,
+    "compressed-tensors": CompressedTensorsConfig,
+}
+
+# VLLM-dependent quantization methods
+VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
     "awq": AWQConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
-    "fp8": Fp8Config,
-    "blockwise_int8": BlockInt8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
-    "modelopt": ModelOptFp8Config,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
-    "gptq_marlin": GPTQMarlinConfig,
     "awq_marlin": AWQMarlinConfig,
-    "gptq": GPTQConfig,
-    "compressed-tensors": CompressedTensorsConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
-    "w8a8_int8": W8A8Int8Config,
-    "w8a8_fp8": W8A8Fp8Config,
+    "gptq_marlin": GPTQMarlinConfig,
+    "gptq": GPTQConfig,
 }
 
+QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
+
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:
@@ -60,6 +105,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
             f"Invalid quantization method: {quantization}. "
             f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
         )
+    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
+        raise ValueError(
+            f"{quantization} quantization requires some operators from vllm. "
+            "Pleaes install vllm by `pip install vllm==0.7.2`"
+        )
+
     return QUANTIZATION_METHODS[quantization]
 
 
@@ -124,13 +175,6 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
-
-    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        ParallelLMHead,
-        UnquantizedEmbeddingMethod,
-    )
-
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
@@ -157,14 +201,6 @@
 
 
 def gptq_get_quant_method(self, layer, prefix):
-    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-    from vllm.model_executor.layers.quantization.gptq_marlin import (
-        GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
-    )
-
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
     if isinstance(layer, FusedMoE):
         return GPTQMarlinMoEMethod(self)
 
@@ -187,6 +223,8 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
     Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
     can recognize sglang layers
     """
+    if not VLLM_AVAILABLE:
+        return
 
     if reverse:
         builtins.isinstance = original_isinstance
@@ -270,13 +308,6 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
 
 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
-    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
-        CompressedTensorsW8A8Fp8MoEMethod,
-        CompressedTensorsWNA16MoEMethod,
-    )
-    from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinMoEMethod
-
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
 
@@ -286,10 +317,6 @@ def monkey_patch_quant_configs():
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
 
 
-monkey_patch_quant_configs()
-
-
-__all__ = [
-    "get_quantization_config",
-    "QUANTIZATION_METHODS",
-]
+# Only apply monkey patches if vllm is available
+if VLLM_AVAILABLE:
+    monkey_patch_quant_configs()
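The practical effect of this refactor is that importing sglang.srt.layers.quantization no longer hard-requires vllm; only lookups of vllm-backed methods do. Below is a minimal illustrative sketch of the new lookup behavior, using only identifiers that appear in the diff above and assuming an environment with this wheel installed.

# Sketch of the new BASE vs. VLLM method split (names from the diff above).
from sglang.srt.layers.quantization import (
    BASE_QUANTIZATION_METHODS,
    VLLM_QUANTIZATION_METHODS,
    get_quantization_config,
)

# Methods implemented natively in sglang resolve whether or not vllm is installed.
fp8_config_cls = get_quantization_config("fp8")
assert "fp8" in BASE_QUANTIZATION_METHODS

# vllm-backed methods are still registered, but when vllm is missing the lookup
# now raises a ValueError suggesting `pip install vllm==0.7.2` instead of the
# whole module failing at import time.
assert "awq_marlin" in VLLM_QUANTIZATION_METHODS
try:
    get_quantization_config("awq_marlin")
except ValueError as err:
    print(err)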
sglang/srt/layers/quantization/awq.py (new file)

@@ -0,0 +1,200 @@
+# SPDX-License-Identifier: Apache-2.0
+import logging
+from typing import Any, Dict, List, Optional
+
+import torch
+from sgl_kernel import awq_dequantize
+
+from sglang.srt.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
+from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
+    return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
+class AWQConfig(QuantizationConfig):
+    """Config class for AWQ.
+
+    Reference: https://arxiv.org/abs/2306.00978
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        zero_point: bool,
+        modules_to_not_convert: Optional[List[str]] = None,
+    ) -> None:
+        super().__init__()
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.zero_point = zero_point
+        self.modules_to_not_convert = modules_to_not_convert or []
+
+        if self.weight_bits != 4:
+            raise ValueError(
+                "Currently, only 4-bit weight quantization is supported for "
+                f"AWQ, but got {self.weight_bits} bits."
+            )
+        self.pack_factor = 32 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return (
+            f"AWQConfig(weight_bits={self.weight_bits}, "
+            f"group_size={self.group_size}, "
+            f"zero_point={self.zero_point}, "
+            f"modules_to_not_convert={self.modules_to_not_convert})"
+        )
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    def get_name(self) -> str:
+        return "awq"
+
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # The AWQ kernel only supports Turing or newer GPUs.
+        return 75
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return [
+            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
+            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
+            "quantize_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
+        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None
+        )
+        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["LinearMethodBase"]:
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
+            return AWQLinearMethod(self)
+        return None
+
+
+class AWQLinearMethod(LinearMethodBase):
+    """Linear method for AWQ.
+
+    Args:
+        quant_config: The AWQ quantization config.
+    """
+
+    def __init__(self, quant_config: AWQConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        if input_size_per_partition % self.quant_config.group_size != 0:
+            raise ValueError(
+                "The input size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The output size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        qzeros = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        scales = GroupQuantScaleParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            input_dim=0,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("qweight", qweight)
+        layer.register_parameter("qzeros", qzeros)
+        layer.register_parameter("scales", scales)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qweight = layer.qweight
+        scales = layer.scales
+        qzeros = layer.qzeros
+        pack_factor = self.quant_config.pack_factor
+        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        out = awq_dequantize(qweight, scales, qzeros)
+        out = torch.matmul(reshaped_x, out)
+
+        if bias is not None:
+            out.add_(bias)
+        return out.reshape(out_shape)
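For orientation, apply() above dequantizes the packed 4-bit weights with sgl_kernel's awq_dequantize and then does a plain half-precision matmul. The following is a shape-bookkeeping sketch of the packed layout only; the layer sizes are hypothetical and no kernel is called.

import torch

# Hypothetical sizes; the packing rules mirror AWQConfig/AWQLinearMethod above.
weight_bits, group_size = 4, 128
pack_factor = 32 // weight_bits  # 8 4-bit values per int32 word, as in AWQConfig
in_features, out_features = 4096, 11008

# qweight/qzeros store int32 words packed along the output dimension;
# scales are kept per (group, output column) in the activation dtype.
qweight = torch.empty(in_features, out_features // pack_factor, dtype=torch.int32)
qzeros = torch.empty(in_features // group_size, out_features // pack_factor, dtype=torch.int32)
scales = torch.empty(in_features // group_size, out_features, dtype=torch.half)

# apply() recovers the logical output width from the packed shape:
x = torch.empty(2, 16, in_features, dtype=torch.half)
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
assert out_shape == (2, 16, out_features)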
sglang/srt/layers/quantization/base_config.py

@@ -38,6 +38,11 @@ class QuantizeMethodBase(ABC):
 class QuantizationConfig(ABC):
     """Base class for quantization configs."""
 
+    def __init__(self):
+        super().__init__()
+        # mapping is updated by models as they initialize
+        self.packed_modules_mapping: Dict[str, List[str]] = dict()
+
     @abstractmethod
     def get_name(self) -> str:
         """Name of the quantization method."""
sglang/srt/layers/quantization/blockwise_int8.py

@@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from torch.nn import Module
-from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
@@ -19,6 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
+from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]