sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
@@ -0,0 +1,652 @@
+ # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+ # SPDX-License-Identifier: Apache-2.0
+
+ import logging
+ from contextlib import suppress
+ from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, cast
+
+ import torch
+ from compressed_tensors.config import (
+     CompressionFormat,
+     SparsityCompressionConfig,
+     SparsityStructure,
+ )
+ from compressed_tensors.quantization import (
+     QuantizationArgs,
+     QuantizationStrategy,
+     QuantizationType,
+ )
+ from pydantic import BaseModel
+
+ from sglang.srt.layers.linear import (
+     LinearBase,
+     LinearMethodBase,
+     UnquantizedLinearMethod,
+ )
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+ from sglang.srt.layers.quantization.base_config import (
+     QuantizationConfig,
+     QuantizeMethodBase,
+ )
+ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
+     CompressedTensorsMoEMethod,
+ )
+ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+     CompressedTensorsScheme,
+     CompressedTensorsW8A8Fp8,
+ )
+ from sglang.srt.layers.quantization.compressed_tensors.utils import (
+     find_matched_target,
+     is_activation_quantization_format,
+     should_ignore_layer,
+ )
+ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+
+ logger = logging.getLogger(__name__)
+
+ __all__ = ["CompressedTensorsLinearMethod"]
+
+ SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
+ QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]
+
+
+ class DeviceCapability(NamedTuple):
+     major: int
+     minor: int
+
+     def as_version_str(self) -> str:
+         return f"{self.major}.{self.minor}"
+
+     def to_int(self) -> int:
+         """
+         Express device capability as an integer ``<major><minor>``.
+
+         It is assumed that the minor version is always a single digit.
+         """
+         assert 0 <= self.minor < 10
+         return self.major * 10 + self.minor
+
+
+ class CompressedTensorsConfig(QuantizationConfig):
+
+     def __init__(
+         self,
+         target_scheme_map: Dict[str, Any],
+         ignore: List[str],
+         quant_format: str,
+         sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
+         sparsity_ignore_list: List[str],
+         kv_cache_scheme: Optional[Dict[str, Any]] = None,
+         config: Optional[Dict[str, Any]] = None,
+     ):
+         super().__init__()
+         self.ignore = ignore
+         self.quant_format = quant_format
+         # Map from [target -> scheme]
+         self.target_scheme_map = target_scheme_map
+         self.kv_cache_scheme = kv_cache_scheme
+         self.sparsity_scheme_map = sparsity_scheme_map
+         self.sparsity_ignore_list = sparsity_ignore_list
+         self.config = config
+
+     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
+         return CompressedTensorsLinearMethod(self)
+
+     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+         return [torch.float16, torch.bfloat16]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 70
+
+     def get_name(self) -> str:
+         return "compressed_tensors"
+
+     def get_scaled_act_names(self) -> List[str]:
+         return []
+
+     def get_quant_method(
+         self,
+         layer: torch.nn.Module,
+         prefix: str,
+     ) -> Optional["QuantizeMethodBase"]:
+
+         # Check if the layer is skipped for quantization.
+         # TODO (@robertgshaw2): support module names
+         if should_ignore_layer(
+             prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+         ):
+             return UnquantizedLinearMethod()
+         if isinstance(layer, LinearBase):
+             scheme = self.get_scheme(layer=layer, layer_name=prefix)
+             if scheme is None:
+                 return UnquantizedLinearMethod()
+             layer.scheme = scheme
+             return CompressedTensorsLinearMethod(self)
+         if isinstance(layer, FusedMoE):
+             return CompressedTensorsMoEMethod.get_moe_method(self)
+         return None
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
+         ignore: List[str] = cast(List[str], config.get("ignore", []))
+         quant_format = cast(str, config.get("format"))
+         target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
+         sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
+             config=config
+         )
+
+         return cls(
+             target_scheme_map=target_scheme_map,
+             ignore=ignore,
+             quant_format=quant_format,
+             sparsity_scheme_map=sparsity_scheme_map,
+             sparsity_ignore_list=sparsity_ignore_list,
+             config=config,
+         )
+
+     @classmethod
+     def _parse_sparsity_config(
+         cls, config: Dict[str, Any]
+     ) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]:
+         """
+         :param config: The `quantization_config` dictionary from config.json
+         :return: A tuple with two elements
+             1. A dictionary mapping target layer names to their corresponding
+                 sparsity_config
+             2. A list of layer names to ignore for sparsity
+         """
+         if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)):
+             return dict(), []
+
+         sparsity_config = SparsityCompressionConfig.model_validate(sparsity_config)
+         sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
+             target: sparsity_config for target in sparsity_config.targets or list()
+         }
+         sparsity_ignore_list = sparsity_config.ignore or list()
+         return sparse_scheme_map, sparsity_ignore_list
+
+     @classmethod
+     def _quantization_scheme_map_from_config(
+         cls, config: Dict[str, Any]
+     ) -> QUANTIZATION_SCHEME_MAP_TYPE:
+         """
+         :param config: The `quantization_config` dictionary from config.json
+         :return: A dictionary mapping target layer names to their corresponding
+             quantization_args for weights and input activations
+         """
+         target_scheme_map: Dict[str, Any] = dict()
+         quant_format = cast(str, config.get("format"))
+
+         # The quant_config has multiple config_groups, each containing
+         # an input_activations key with details about how the activations are
+         # quantized, a weights key indicating how the weights are quantized,
+         # and a list of targets under the `targets` key, dictating which
+         # layers are impacted by the quantization details. The quantization
+         # details follow the structure defined by the QuantizationArgs
+         # pydantic model, which is used to verify the structure of the
+         # quant_config and also store the details for later use.
+
+         config_groups = config.get("config_groups", dict())
+         for _, quant_config in config_groups.items():
+             targets = quant_config.get("targets")
+             for target in targets:
+                 target_scheme_map[target] = {}
+                 target_scheme_map[target]["weights"] = QuantizationArgs.model_validate(
+                     quant_config.get("weights")
+                 )
+
+                 target_scheme_map[target]["input_activations"] = None
+                 if is_activation_quantization_format(quant_format):
+                     input_activations = quant_config.get("input_activations")
+                     # The only case where we have activation quant supported
+                     # but no input_activations provided in the config
+                     # should be w8a16fp8 w8a16fp8 can also run for cases where
+                     # there is an input_quant but it is ignored
+                     if not input_activations:
+                         assert (
+                             target_scheme_map[target]["weights"].type
+                             == QuantizationType.FLOAT
+                         )
+                     else:
+                         target_scheme_map[target]["input_activations"] = (
+                             QuantizationArgs.model_validate(  # noqa: E501
+                                 quant_config.get("input_activations")
+                             )
+                         )
+         return target_scheme_map
+
+     @classmethod
+     def get_config_filenames(cls) -> List[str]:
+         return []
+
+     def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool:
+         capability_tuple = DeviceCapability(*torch.cuda.get_device_capability())
+
+         if capability_tuple is not None:
+             capability = capability_tuple.to_int()
+             supported = capability >= min_capability
+             if error and not supported:
+                 raise RuntimeError(
+                     "Quantization scheme is not supported for ",
+                     f"the current GPU. Min capability: {min_capability}. ",
+                     f"Current capability: {capability}.",
+                 )
+             return supported
+         else:
+             return False
+
+     def _is_static_tensor_w8a8(
+         self, weight_quant: BaseModel, input_quant: BaseModel
+     ) -> bool:
+         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
+         weight_strategy = (
+             weight_quant.strategy == QuantizationStrategy.TENSOR.value
+             or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
+         )
+         is_tensor = (
+             weight_strategy
+             and input_quant.strategy == QuantizationStrategy.TENSOR.value
+         )
+         is_static = not weight_quant.dynamic and not input_quant.dynamic
+
+         # Both symmetric and asymmetric input quantization supported.
+         # Only symmetric weight quantization supported.
+         return is_8_bits and is_tensor and weight_quant.symmetric and is_static
+
+     def _is_dynamic_token_w8a8(
+         self, weight_quant: BaseModel, input_quant: BaseModel
+     ) -> bool:
+         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
+         weight_strategy = (
+             weight_quant.strategy == QuantizationStrategy.TENSOR.value
+             or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
+         )
+         is_token = (
+             weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
+         )
+         is_dynamic = not weight_quant.dynamic and input_quant.dynamic
+
+         # Both symmetric and asymmetric input quantization supported.
+         # Only symmetric weight quantization supported.
+         return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
+
+     def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool:
+         # Confirm weights and activations quantized.
+         if weight_quant is None or input_quant is None:
+             return False
+
+         # Confirm weight scheme is supported.
+         is_floating_point = (
+             weight_quant.type == QuantizationType.FLOAT
+             and input_quant.type == QuantizationType.FLOAT
+         )
+         is_symmetric_weight = weight_quant.symmetric
+         is_static_weight = not weight_quant.dynamic
+         is_per_tensor_or_channel_weight = weight_quant.strategy in [
+             QuantizationStrategy.TENSOR,
+             QuantizationStrategy.CHANNEL,
+         ]
+         if not (
+             is_floating_point
+             and is_symmetric_weight
+             and is_static_weight
+             and is_per_tensor_or_channel_weight
+         ):
+             return False
+
+         # Dynamic quantization is always supported if weights supported.
+         if input_quant.dynamic:
+             return True
+
+         # Confirm activation scheme is supported.
+         is_symmetric_activation = input_quant.symmetric
+         is_per_tensor_activation = input_quant.strategy == QuantizationStrategy.TENSOR
+         return is_symmetric_activation and is_per_tensor_activation
+
+     def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool:
+         # Confirm weights quantized.
+         if weight_quant is None:
+             return False
+
+         # Confirm we have floating points.
+         if weight_quant.type != QuantizationType.FLOAT:
+             return False
+
+         # Confirm weight scheme is supported.
+         is_symmetric_weight = weight_quant.symmetric
+         is_static_weight = not weight_quant.dynamic
+         is_per_tensor_or_channel_weight = weight_quant.strategy in [
+             QuantizationStrategy.TENSOR,
+             QuantizationStrategy.CHANNEL,
+         ]
+         if not (
+             is_symmetric_weight
+             and is_static_weight  # noqa: SIM103
+             and is_per_tensor_or_channel_weight
+         ):
+             return False
+
+         # All conditions satisfied.
+         return True
+
+     def _is_wNa16_group_channel(
+         self, weight_quant: BaseModel, input_quant: BaseModel
+     ) -> bool:
+         input_quant_none = input_quant is None
+         is_symmetric = weight_quant.symmetric
+         is_channel_group = (
+             weight_quant.strategy == QuantizationStrategy.CHANNEL.value
+             or weight_quant.strategy == QuantizationStrategy.GROUP.value
+         )
+         is_static = not weight_quant.dynamic
+
+         return is_channel_group and input_quant_none and is_symmetric and is_static
+
+     def _get_scheme_from_parts(
+         self, weight_quant: BaseModel, input_quant: BaseModel
+     ) -> "CompressedTensorsScheme":
+
+         # Detect If Mixed Precision
+         if self._is_wNa16_group_channel(weight_quant, input_quant):
+             if not VLLM_AVAILABLE:
+                 raise ImportError(
+                     "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
+                 )
+             if (
+                 self.quant_format == CompressionFormat.marlin_24.value
+                 and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS
+             ):
+                 return CompressedTensorsW4A16Sparse24(
+                     strategy=weight_quant.strategy,
+                     num_bits=weight_quant.num_bits,
+                     group_size=weight_quant.group_size,
+                 )
+             if (
+                 self.quant_format == CompressionFormat.pack_quantized.value
+                 and weight_quant.num_bits in WNA16_SUPPORTED_BITS
+             ):
+                 return CompressedTensorsWNA16(
+                     num_bits=weight_quant.num_bits,
+                     strategy=weight_quant.strategy,
+                     group_size=weight_quant.group_size,
+                     actorder=weight_quant.actorder,
+                 )
+
+         if is_activation_quantization_format(self.quant_format):
+             if self._is_fp8_w8a8(weight_quant, input_quant):
+                 is_fp8_w8a8_supported = self._check_scheme_supported(
+                     CompressedTensorsW8A8Fp8.get_min_capability(), error=False
+                 )
+                 if is_fp8_w8a8_supported:
+                     return CompressedTensorsW8A8Fp8(
+                         strategy=weight_quant.strategy,
+                         is_static_input_scheme=(
+                             input_quant and not input_quant.dynamic
+                         ),
+                     )
+                 else:
+                     # note: input_quant will be present for converted models;
+                     # will be ignored during inference post loading
+                     return CompressedTensorsW8A16Fp8(
+                         strategy=weight_quant.strategy,
+                         is_static_input_scheme=not input_quant.dynamic,
+                     )
+
+             # note: input_quant can be None
+             if self._is_fp8_w8a16(weight_quant, input_quant):
+                 if not VLLM_AVAILABLE:
+                     raise ImportError(
+                         "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
+                     )
+                 is_static_input_scheme = input_quant and not input_quant.dynamic
+                 return CompressedTensorsW8A16Fp8(
+                     strategy=weight_quant.strategy,
+                     is_static_input_scheme=is_static_input_scheme,
+                 )
+
+             if self._is_static_tensor_w8a8(weight_quant, input_quant):
+                 return CompressedTensorsW8A8Int8(
+                     strategy=weight_quant.strategy,
+                     is_static_input_scheme=True,
+                     input_symmetric=input_quant.symmetric,
+                 )
+
+             if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+                 return CompressedTensorsW8A8Int8(
+                     strategy=weight_quant.strategy,
+                     is_static_input_scheme=False,
+                     input_symmetric=input_quant.symmetric,
+                 )
+
+         raise NotImplementedError("No compressed-tensors compatible scheme was found.")
+
+     def get_scheme(
+         self, layer: torch.nn.Module, layer_name: Optional[str] = None
+     ) -> Optional["CompressedTensorsScheme"]:
+         """
+         compressed-tensors supports non uniform in the following way:
+
+         targets of config_groups: There can be N config_groups which each
+         have a quantization scheme. Each config_group has a list of targets
+         which can be a full layer_name, a regex for a layer_name, or
+         an nn.Module name.
+
+         Detect whether a layer_name is found in any target and
+         use the quantization scheme corresponding to the matched target
+         to select the CompressedTensorsScheme used for infernece.
+         """
+
+         # Find the "target" in the compressed-tensors config
+         # that our layer conforms to.
+         # TODO (@robertgshaw): add compressed-tensors as dep
+         # so we do not have to re-write these functions
+         # need to make accelerate optional in ct to do this
+
+         # Will be empty for models with only sparsity
+         weight_quant = input_quant = None
+         if self.target_scheme_map:
+             matched_target = find_matched_target(
+                 layer_name=layer_name,
+                 module=layer,
+                 targets=self.target_scheme_map.keys(),
+                 fused_mapping=self.packed_modules_mapping,
+             )
+
+             scheme_dict = self.target_scheme_map[matched_target]
+             weight_quant = scheme_dict.get("weights")
+             input_quant = scheme_dict.get("input_activations")
+
+         # Find the sparsity scheme of the layer
+         # assume that fused layers inerhit first component's sparsity scheme
+         sparsity_targets = self.sparsity_scheme_map.keys() - set(
+             self.sparsity_ignore_list
+         )
+         sparsity_scheme: Optional[SparsityCompressionConfig] = None
+         with suppress(ValueError):
+             matched_target = find_matched_target(
+                 layer_name=layer_name,
+                 module=layer,
+                 targets=sparsity_targets,
+                 fused_mapping=self.packed_modules_mapping,
+             )
+             sparsity_scheme = self.sparsity_scheme_map[matched_target]
+
+         if self.supports_cutlass_24(
+             weight_quant=weight_quant,
+             input_quant=input_quant,
+             sparsity_scheme=sparsity_scheme,
+         ):
+             if not VLLM_AVAILABLE:
+                 raise ImportError(
+                     "vllm is not installed, to use CompressedTensors24, please install vllm"
+                 )
+             # Have a valid sparsity scheme
+             # Validate layer is supported by Cutlass 2:4 Kernel
+             model_compression_config = (
+                 None
+                 if sparsity_scheme is None or sparsity_scheme.format == "dense"
+                 else self.config
+             )
+
+             scheme = CompressedTensors24(
+                 quantized=weight_quant is not None or input_quant is not None,
+                 weight_quant=weight_quant,
+                 input_quant=input_quant,
+                 model_compression_config=model_compression_config,
+             )
+         elif weight_quant is None:
+             logger.warning_once(
+                 "Acceleration for non-quantized schemes is "
+                 "not supported by Compressed Tensors. "
+                 "Falling back to UnquantizedLinearMethod"
+             )
+             return None
+
+         else:
+             # Find the quant_scheme
+             scheme = self._get_scheme_from_parts(  # type: ignore
+                 weight_quant=weight_quant,
+                 input_quant=input_quant,
+             )
+
+         # Raise error if device does not support the scheme
+         # (e.g. fp8 needs ada lovelace)
+         self._check_scheme_supported(scheme.get_min_capability())
+         logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name)
+         return scheme
+
+     def get_cache_scale(self, name: str) -> Optional[str]:
+         """
+         Check whether the param name matches the format for k/v cache scales
+         in compressed-tensors. If this is the case, return its equivalent
+         param name expected by vLLM
+
+         :param name: param name
+         :return: matching param name for KV cache scale in vLLM
+         """
+         if name.endswith(".output_scale") and ".k_proj" in name:
+             return name.replace(".k_proj.output_scale", ".attn.k_scale")
+         if name.endswith(".output_scale") and ".v_proj" in name:
+             return name.replace(".v_proj.output_scale", ".attn.v_scale")
+         # If no matches, return None
+         return None
+
+     @staticmethod
+     def supports_cutlass_24(
+         weight_quant: Optional[QuantizationArgs],
+         input_quant: Optional[QuantizationArgs],
+         sparsity_scheme: Optional[SparsityCompressionConfig] = None,
+     ) -> bool:
+         """
+         Check if the layer is supported by the Cutlass 2:4 Kernel
+         Conditions:
+             - Overarching condition: Sparsity Structure is 2:4
+             - Unquantized cases are supported
+             - Weight only quantization is not-supported
+             - Supported weight quantization strategies are TENSOR and CHANNEL
+             - Supported input quantization strategies are TENSOR and TOKEN
+             - Only 8 bit quantization is supported
+
+         :return: True if the layer is supported by the Cutlass 2:4 Kernel
+             False otherwise
+         """
+         if sparsity_scheme is None:
+             return False
+
+         is_valid_sparsity_structure: bool = (
+             sparsity_scheme.sparsity_structure == SparsityStructure.TWO_FOUR.value
+         )
+
+         valid_compressors = {
+             CompressionFormat.dense.value,
+             CompressionFormat.sparse_24_bitmask.value,
+         }
+
+         is_valid_sparsity = (
+             is_valid_sparsity_structure and sparsity_scheme.format in valid_compressors
+         )
+
+         if not is_valid_sparsity:
+             return False
+
+         # Unquantized cases are supported
+         if weight_quant is None and input_quant is None:
+             return True
+
+         # Weight only quantization is not-supported
+         if weight_quant is not None and input_quant is None:
+             return False
+
+         supported_weight_quant_strategies = [
+             QuantizationStrategy.TENSOR.value,
+             QuantizationStrategy.CHANNEL.value,
+         ]
+
+         assert weight_quant is not None
+         assert input_quant is not None
+         if weight_quant.strategy not in supported_weight_quant_strategies:
+             return False
+
+         supported_input_quant_strategies = [
+             QuantizationStrategy.TENSOR.value,
+             QuantizationStrategy.TOKEN.value,
+         ]
+
+         if input_quant.strategy not in supported_input_quant_strategies:
+             return False
+
+         return weight_quant.num_bits == input_quant.num_bits == 8
+
+
+ class CompressedTensorsLinearMethod(LinearMethodBase):
+
+     def __init__(self, quantization_config: CompressedTensorsConfig):
+         self.quantization_config = quantization_config
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         layer.scheme.process_weights_after_loading(layer)
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         input_size_per_partition: int,
+         output_partition_sizes: List[int],
+         input_size: int,
+         output_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         """
+         Use the CompressedTensorsScheme associated with each layer to create
+         the necessary parameters for the layer. See LinearMethodBase for param
+         details
+         """
+         weight_loader = extra_weight_attrs.get("weight_loader")
+         layer.scheme.create_weights(
+             layer=layer,
+             input_size=input_size,
+             input_size_per_partition=input_size_per_partition,
+             output_partition_sizes=output_partition_sizes,
+             output_size=output_size,
+             params_dtype=params_dtype,
+             weight_loader=weight_loader,
+         )
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+     ):
+         """
+         Use the output of create_weights and the CompressedTensorsScheme
+         associated with the layer to apply the forward pass with the
+         layer input. See LinearMethodBase for param details
+
+         """
+
+         scheme = layer.scheme
+         if scheme is None:
+             raise ValueError("A scheme must be defined for each layer")
+         return scheme.apply_weights(layer, x, bias=bias)
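
For context on the new module above, here is a minimal sketch of how the `CompressedTensorsConfig.from_config` entry point is typically driven. The `quantization_config` dictionary below is hand-written for illustration only (it loosely follows the compressed-tensors FP8 per-channel-weight / dynamic per-token-activation recipe shape and is not taken from this package), and it assumes sglang 0.4.4.post2 and the compressed-tensors package are importable.

# Illustrative sketch, not part of the sglang package.
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig,
)

# Hypothetical `quantization_config` as it might appear in a model's config.json.
# Field names follow the compressed-tensors QuantizationArgs pydantic model;
# exact values vary per model and are assumptions here.
quantization_config = {
    "format": "float-quantized",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {
                "num_bits": 8,
                "type": "float",
                "strategy": "channel",
                "symmetric": True,
                "dynamic": False,
            },
            "input_activations": {
                "num_bits": 8,
                "type": "float",
                "strategy": "token",
                "symmetric": True,
                "dynamic": True,
            },
        }
    },
}

cfg = CompressedTensorsConfig.from_config(quantization_config)

# Each target in a config_group maps to a weights / input_activations
# QuantizationArgs pair; for this recipe, get_quant_method() should later
# select CompressedTensorsW8A8Fp8 for matching linear layers (on a GPU that
# supports FP8) and UnquantizedLinearMethod for ignored ones such as lm_head.
print(cfg.target_scheme_map["Linear"]["weights"].num_bits)  # 8
print(cfg.target_scheme_map["Linear"]["input_activations"].dynamic)  # True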