compressed-tensors-nightly 0.3.3.20240612__py3-none-any.whl → 0.4.0.20240613__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -17,7 +17,7 @@ import logging
 import operator
 import os
 from copy import deepcopy
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 from compressed_tensors.base import (
     COMPRESSION_CONFIG_NAME,
@@ -88,20 +88,41 @@ class ModelCompressor:
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
         compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+        return cls.from_compression_config(compression_config)
+
+    @classmethod
+    def from_compression_config(cls, compression_config: Dict[str, Any]):
+        """
+        :param compression_config: compression/quantization config dictionary
+            found under key "quantization_config" in HF model config
+        :return: compressor for the extracted configs
+        """
         if compression_config is None:
             return None
 
+        try:
+            from transformers.utils.quantization_config import CompressedTensorsConfig
+
+            if isinstance(compression_config, CompressedTensorsConfig):
+                compression_config = compression_config.to_dict()
+        except ImportError:
+            pass
+
         sparsity_config = cls.parse_sparsity_config(compression_config)
         quantization_config = cls.parse_quantization_config(compression_config)
         if sparsity_config is None and quantization_config is None:
             return None
 
-        if sparsity_config is not None:
+        if sparsity_config is not None and not isinstance(
+            sparsity_config, SparsityCompressionConfig
+        ):
             format = sparsity_config.get("format")
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 format, **sparsity_config
             )
-        if quantization_config is not None:
+        if quantization_config is not None and not isinstance(
+            quantization_config, QuantizationConfig
+        ):
             quantization_config = QuantizationConfig.parse_obj(quantization_config)
 
         return cls(
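
The change above splits config extraction out of from_pretrained into a new from_compression_config classmethod, which accepts either the plain dictionary stored under "quantization_config" in a Hugging Face model config or a transformers CompressedTensorsConfig (converted via to_dict when available). A minimal sketch of building a compressor directly from a model config, assuming a hypothetical checkpoint path that actually carries such an entry:

    from transformers import AutoConfig

    from compressed_tensors.base import COMPRESSION_CONFIG_NAME
    from compressed_tensors.compressors import ModelCompressor

    # hypothetical checkpoint; from_compression_config returns None
    # when no compression config is present
    config = AutoConfig.from_pretrained("org/quantized-model")
    compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
    compressor = ModelCompressor.from_compression_config(compression_config)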
@@ -146,15 +167,29 @@ class ModelCompressor:
     def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]:
         if compression_config is None:
             return None
+        if SPARSITY_CONFIG_NAME not in compression_config:
+            return None
+        if hasattr(compression_config, SPARSITY_CONFIG_NAME):
+            # for loaded HFQuantizer config
+            return getattr(compression_config, SPARSITY_CONFIG_NAME)
+
+        # SparseAutoModel format
         return compression_config.get(SPARSITY_CONFIG_NAME, None)
 
     @staticmethod
     def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
+        if compression_config is None:
+            return None
+
+        if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
+            # for loaded HFQuantizer config
+            return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
+
+        # SparseAutoModel format
         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
         if len(quantization_config) == 0:
             quantization_config = None
-
         return quantization_config
 
     def __init__(
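
As the comments in the hunk above note, both parsers now accept two shapes of input: the "SparseAutoModel format", where the sub-configs are nested under dictionary keys, and a loaded HFQuantizer config, where they are exposed as attributes. A minimal sketch of the two paths; the config contents are illustrative only:

    from types import SimpleNamespace

    from compressed_tensors.compressors import ModelCompressor

    # SparseAutoModel format: sub-configs nested under dictionary keys
    as_dict = {"sparsity_config": {"format": "sparse-bitmask"}}
    ModelCompressor.parse_sparsity_config(as_dict)       # -> {"format": "sparse-bitmask"}

    # loaded HFQuantizer config: sub-configs exposed as attributes
    as_attrs = SimpleNamespace(quantization_config={"config_groups": {}})
    ModelCompressor.parse_quantization_config(as_attrs)  # -> {"config_groups": {}}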
@@ -123,11 +123,14 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
         if target is not None:
             # target matched - add layer and scheme to target list
             submodule.quantization_scheme = target_to_scheme[target]
-    if set(config.ignore) - set(ignored_submodules):
-        _LOGGER.warning(
-            "Some layers that were to be ignored were "
-            f"not found in the model: {set(config.ignore) - set(ignored_submodules)}"
-        )
+
+    if config.ignore is not None and ignored_submodules is not None:
+        if set(config.ignore) - set(ignored_submodules):
+            _LOGGER.warning(
+                "Some layers that were to be ignored were "
+                "not found in the model: "
+                f"{set(config.ignore) - set(ignored_submodules)}"
+            )
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
 
@@ -146,7 +149,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
 
     if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
         model.apply(set_module_for_calibration)
-
     if current_status < status >= QuantizationStatus.FROZEN > current_status:
         model.apply(freeze_module_quantization)
 
@@ -160,9 +162,10 @@ def find_first_name_or_class_match(
     # first element of targets that matches the given name
     # if no name matches returns first target that matches the class name
     # returns None otherwise
-    return _find_first_match(name, targets) or _find_first_match(
-        module.__class__.__name__, targets, check_contains
-    )
+    if isinstance(targets, Iterable):
+        return _find_first_match(name, targets) or _find_first_match(
+            module.__class__.__name__, targets, check_contains
+        )
 
 
 def _find_first_match(
@@ -212,7 +215,12 @@ def _load_quant_args_from_state_dict(
     scale = getattr(module, scale_name, None)
     zp = getattr(module, zp_name, None)
     if scale is not None:
-        scale.data = state_dict[f"{module_name}.{scale_name}"].to(device)
+        state_dict_scale = state_dict.get(f"{module_name}.{scale_name}")
+        if state_dict_scale is not None:
+            scale.data = state_dict_scale.to(device).to(scale.dtype)
+        else:
+            scale.data = scale.data.to(device)
+
     if zp is not None:
         zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None)
         if zp_from_state is not None:  # load the non-zero zero points
@@ -94,7 +94,7 @@ def dequantize(
     :return: dequantized float tensor
     """
     if args is None:
-        if scale.ndim == 0:
+        if scale.ndim == 0 or scale.ndim == 1:
             args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
         elif scale.ndim == 2:
             if scale.shape[1] == 1:
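
With calculate_qparams now returning 1-dimensional scales (see the observers/helpers.py hunk further down), dequantize treats both 0-d and 1-d scales as per-tensor quantization when no QuantizationArgs are supplied. A small sketch of that inference rule, using a hypothetical helper name:

    import torch

    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        QuantizationStrategy,
    )

    def infer_per_tensor_args(scale: torch.Tensor) -> QuantizationArgs:
        # mirrors only the branch changed above: 0-d and 1-d scales are per-tensor
        if scale.ndim == 0 or scale.ndim == 1:
            return QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
        raise NotImplementedError("higher-rank scales map to other strategies")

    infer_per_tensor_args(torch.tensor(0.1))    # 0-d scale: per-tensor
    infer_per_tensor_args(torch.tensor([0.1]))  # 1-d scale: also per-tensor after this change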
@@ -20,7 +20,10 @@ import torch
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module, Parameter
@@ -58,7 +61,12 @@ def initialize_module_for_quantization(
         _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
     if scheme.weights is not None:
         if hasattr(module, "weight"):
-            _initialize_scale_zero_point_observer(module, "weight", scheme.weights)
+            weight_shape = None
+            if isinstance(module, torch.nn.Linear):
+                weight_shape = module.weight.shape
+            _initialize_scale_zero_point_observer(
+                module, "weight", scheme.weights, weight_shape=weight_shape
+            )
         else:
             _LOGGER.warning(
                 f"module type {type(module)} targeted for weight quantization but "
@@ -78,7 +86,10 @@ def initialize_module_for_quantization(
 
 
 def _initialize_scale_zero_point_observer(
-    module: Module, base_name: str, quantization_args: QuantizationArgs
+    module: Module,
+    base_name: str,
+    quantization_args: QuantizationArgs,
+    weight_shape: Optional[torch.Size] = None,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
@@ -89,13 +100,28 @@ def _initialize_scale_zero_point_observer(
 
     device = next(module.parameters()).device
 
+    # infer expected scale/zero point shape
+    expected_shape = 1  # per tensor
+
+    if base_name == "weight" and weight_shape is not None:
+        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            # (output_channels, 1)
+            expected_shape = (weight_shape[0], 1)
+        elif quantization_args.strategy == QuantizationStrategy.GROUP:
+            expected_shape = (
+                weight_shape[0],
+                weight_shape[1] // quantization_args.group_size,
+            )
+
     # initializes empty scale and zero point parameters for the module
     init_scale = Parameter(
-        torch.empty(0, dtype=torch.float16, device=device), requires_grad=False
+        torch.empty(expected_shape, dtype=module.weight.dtype, device=device),
+        requires_grad=False,
     )
     module.register_parameter(f"{base_name}_scale", init_scale)
 
     init_zero_point = Parameter(
-        torch.empty(0, device=device, dtype=int), requires_grad=False
+        torch.empty(expected_shape, device=device, dtype=int),
+        requires_grad=False,
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
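
The shape inference above sizes the scale and zero-point parameters up front instead of registering empty tensors. For a hypothetical Linear layer with weight shape (out_features=256, in_features=512) and a hypothetical group_size of 128, the expected shapes work out as follows:

    import torch

    weight_shape = torch.Size((256, 512))  # hypothetical (out_features, in_features)
    group_size = 128                       # hypothetical group size

    per_tensor = 1                                                # scalar scale/zero point
    per_channel = (weight_shape[0], 1)                            # (256, 1)
    per_group = (weight_shape[0], weight_shape[1] // group_size)  # (256, 4)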
@@ -51,4 +51,8 @@ def calculate_qparams(
         zero_points = bit_min - torch.round(min_vals / scales)
         zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)
 
+    if scales.ndim == 0:
+        scales = scales.reshape(1)
+        zero_points = zero_points.reshape(1)
+
     return scales, zero_points
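
The reshape above normalizes the per-tensor case: a scalar observer result is a 0-dimensional tensor, and giving it an explicit length of 1 keeps downstream shape checks (such as the scale.ndim test in the forward.py hunk earlier) on the 1-d path. A quick illustration of the difference:

    import torch

    scale = torch.tensor(0.1)   # 0-dim scalar, scale.shape == torch.Size([])
    scale = scale.reshape(1)    # 1-dim, scale.shape == torch.Size([1])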
@@ -144,6 +144,10 @@ class QuantizationConfig(BaseModel):
             targets=targets_or_scheme,
         )
 
+    def to_dict(self):
+        # for compatibility with HFQuantizer
+        return self.dict()
+
     @staticmethod
     def from_pretrained(
         model: Module, format: Optional[str] = None
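
QuantizationConfig is a pydantic BaseModel, so the new to_dict simply delegates to dict() to hand HFQuantizer a plain dictionary. A minimal round-trip sketch, assuming a deliberately empty config_groups for brevity:

    from compressed_tensors.quantization import QuantizationConfig

    # hypothetical minimal config
    config = QuantizationConfig(config_groups={}, ignore=["lm_head"])
    as_dict = config.to_dict()                        # plain dict for HFQuantizer / JSON
    restored = QuantizationConfig.parse_obj(as_dict)  # same parse path used earlier in the diff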
@@ -19,7 +19,7 @@ Functionality for storing and setting the version info for SparseML
 from datetime import date
 
 
-version_base = "0.3.3"
+version_base = "0.4.0"
 is_release = False  # change to True to set the generated version as a release version
 
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: compressed-tensors-nightly
-Version: 0.3.3.20240612
+Version: 0.4.0.20240613
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
@@ -1,13 +1,13 @@
 compressed_tensors/__init__.py,sha256=SV1csvHUVCd8kHXz6UDZim1HZ_fAVG3vfk-j_4Bb6hY,789
 compressed_tensors/base.py,sha256=OA2TOLP1gP3LSH7gp508eqr2ZtDQ-pqRHElCp-aB0vs,755
-compressed_tensors/version.py,sha256=V8krJZctm43D4AGQhJY6dB0MvP1-T9TJ8BcGa8kESrI,1512
+compressed_tensors/version.py,sha256=7shEvInzCEXAScJ2akpiQpgv_IjveX6mAfvi2D_wQDE,1512
 compressed_tensors/compressors/__init__.py,sha256=rhqPp3YXFxCJRLZs1KRNSHTIxK2rNU--sYwDI8MW47w,1061
 compressed_tensors/compressors/base.py,sha256=LWEgbpgTxzmoqQ7Xhq2OQszUgWoDtFuGCiV1Y8nlBGw,2134
 compressed_tensors/compressors/dense.py,sha256=G_XHbvuENyupIKlXSITOQgvPkNkcMEOLcLWQr70V9EE,1257
 compressed_tensors/compressors/helpers.py,sha256=k9avlkmeYj6vkOAvl-MgcixtP7ib24SCfhzZ-RusXfw,5403
 compressed_tensors/compressors/int_quantized.py,sha256=Ct2vCK0yoPm6vkIFlzDMGQ7m14xT1GyURsSwH9DP770,5242
 compressed_tensors/compressors/marlin_24.py,sha256=X_BjtFB3Mn0hqiLz56UM3jGX2eNmGLnvEIPfbg7di6U,9444
-compressed_tensors/compressors/model_compressor.py,sha256=jUktyujYdd9KqkA9IyZK6EMi09iEw4_itwhzSh805Jk,11150
+compressed_tensors/compressors/model_compressor.py,sha256=h3ixQtfzt6HxSNtdnB9OVdpCucTmIo4paDoaM7XYZXE,12559
 compressed_tensors/compressors/pack_quantized.py,sha256=VPiLlgJlDgARrn7YmiQoLqUfxErKBfj54epMYWRsF8k,8451
 compressed_tensors/compressors/sparse_bitmask.py,sha256=H9oZSTYI1oRCzAMbd4zThUnZd1h2rfs8DmA3tPcvuNE,8637
 compressed_tensors/compressors/utils/__init__.py,sha256=-mbGDZh1hd9T6u62Ht_iBIK255UmMg0f5bLkSs1f9Cc,731
@@ -20,18 +20,18 @@ compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74j
 compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
 compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
 compressed_tensors/quantization/quant_args.py,sha256=Z9Zu20ooAwEWlliAdUw1f1zwSrheuD6vqm3YXgJ1Lws,4388
-compressed_tensors/quantization/quant_config.py,sha256=Nv9rvWNrlbeJgNZhQf-cPAEWJ9NU75ATWHCacWaiQ_s,8189
+compressed_tensors/quantization/quant_config.py,sha256=hL42sXp1wAZxyrkHarw7tAMRcwSVEr0MT3wmrmL3NhE,8285
 compressed_tensors/quantization/quant_scheme.py,sha256=-hAK1-C67_wJl10eaVLUvbslPBTV04WyzL_J-u9f1ck,3571
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=ggRGWRqhCxCaTTDWRcgTVX3axnS2xV6rc5YvdzK7fSg,798
-compressed_tensors/quantization/lifecycle/apply.py,sha256=disclMUDaz2MLPvcTwGQ1oo1clhTTBkAeNz5J9NRxVw,8552
+compressed_tensors/quantization/lifecycle/apply.py,sha256=aZrglJ5mR3Xaxwj51-1BVVB1JGVkKQEeHxGfBaVmsHI,8881
 compressed_tensors/quantization/lifecycle/calibration.py,sha256=mLns4jlaWmBwOW8Jtlm5bMX-JET1AiZYUBO7qa-XuxI,1776
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=VreB10xPwgSLQQlTu20UCrFpRS--cA7-lx5s7nrPPrg,2247
-compressed_tensors/quantization/lifecycle/forward.py,sha256=_1TwffkyaaXL5QpFgXH1gvueUivOLpuRkoXY7vRXktY,11094
+compressed_tensors/quantization/lifecycle/forward.py,sha256=0T817yzYqFR1wUjk2XCtOISwr4u7cdkKqAv13jjfu24,11113
 compressed_tensors/quantization/lifecycle/frozen.py,sha256=h1XYt89MouBTf3jTYLG_6OdFxIu5q2N8tPjsy6J4E6Y,1726
-compressed_tensors/quantization/lifecycle/initialize.py,sha256=pFfcu-pxdQKzlnn-18-RlkEktt2yDi6woNXJsiv1A2c,3732
+compressed_tensors/quantization/lifecycle/initialize.py,sha256=9xgPzHejQUO_AkZcc_SH5kqFeieG-9uo0fMRYV51i7Y,4577
 compressed_tensors/quantization/observers/__init__.py,sha256=DNH31NQYrIBBcmHsMyFA6whh4pbRsLwuNa6L8AeXaGc,745
 compressed_tensors/quantization/observers/base.py,sha256=z_JC-CRz-PY7WlpSoyOoSQQWz5ekTEd5LbXt0iHQRes,5239
-compressed_tensors/quantization/observers/helpers.py,sha256=JwALNfBYY9Eyl8Q180t0lGh8szumQj8TygfNl-isErs,2166
+compressed_tensors/quantization/observers/helpers.py,sha256=FUyYUNd-3LbXt0-8Lwr7EPI2m-LXXBTXW1l5iOajNhA,2272
 compressed_tensors/quantization/observers/memoryless.py,sha256=jH_c6K3gxf4W3VNXQ7tbnP-J_86QTrEfjBn6Kh1C-H8,2165
 compressed_tensors/quantization/observers/min_max.py,sha256=UK7zCMzxv9GGn6BflBxdajV20RiWaCY2RHcvZodCP1w,3669
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
@@ -41,8 +41,8 @@ compressed_tensors/registry/registry.py,sha256=fxjOjh2wklCvJhQxwofdy-zV8q7MkQ85S
 compressed_tensors/utils/__init__.py,sha256=5DrYjoZbaEvSkJcC-GRSbM_RBHVF4tG9gMd3zsJnjLw,665
 compressed_tensors/utils/helpers.py,sha256=5ull5yFT31M2zVxKeFvpvvlvX5f1Sk1LGuj_wrfZWCY,2267
 compressed_tensors/utils/safetensors_load.py,sha256=0MheXwx1jeY12PeISppiSIZHs6rmN2YddwPpFb9V67I,8527
-compressed_tensors_nightly-0.3.3.20240612.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors_nightly-0.3.3.20240612.dist-info/METADATA,sha256=GjdOve1sMxN8qOUPu3EjXTNRFvnX0jrjA8lYwmq9CCY,5668
-compressed_tensors_nightly-0.3.3.20240612.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-compressed_tensors_nightly-0.3.3.20240612.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors_nightly-0.3.3.20240612.dist-info/RECORD,,
+compressed_tensors_nightly-0.4.0.20240613.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors_nightly-0.4.0.20240613.dist-info/METADATA,sha256=l2A-QS0UgDFrWK2qdHCPcButIksWNOnW6G5UjCUAFPY,5668
+compressed_tensors_nightly-0.4.0.20240613.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+compressed_tensors_nightly-0.4.0.20240613.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors_nightly-0.4.0.20240613.dist-info/RECORD,,