compressed-tensors 0.7.1__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. compressed_tensors/compressors/model_compressors/model_compressor.py +17 -5
  2. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +4 -2
  3. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +2 -0
  4. compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
  5. compressed_tensors/config/base.py +60 -2
  6. compressed_tensors/linear/compressed_linear.py +3 -1
  7. compressed_tensors/quantization/__init__.py +0 -1
  8. compressed_tensors/quantization/lifecycle/__init__.py +0 -2
  9. compressed_tensors/quantization/lifecycle/apply.py +3 -17
  10. compressed_tensors/quantization/lifecycle/forward.py +24 -87
  11. compressed_tensors/quantization/lifecycle/initialize.py +21 -24
  12. compressed_tensors/quantization/quant_args.py +27 -25
  13. compressed_tensors/quantization/quant_config.py +2 -2
  14. compressed_tensors/quantization/quant_scheme.py +17 -24
  15. compressed_tensors/quantization/utils/helpers.py +125 -8
  16. compressed_tensors/registry/registry.py +1 -1
  17. compressed_tensors/utils/helpers.py +33 -1
  18. compressed_tensors/version.py +1 -1
  19. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/METADATA +1 -1
  20. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/RECORD +23 -31
  21. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/WHEEL +1 -1
  22. compressed_tensors/quantization/cache.py +0 -201
  23. compressed_tensors/quantization/lifecycle/calibration.py +0 -70
  24. compressed_tensors/quantization/lifecycle/frozen.py +0 -55
  25. compressed_tensors/quantization/observers/__init__.py +0 -21
  26. compressed_tensors/quantization/observers/base.py +0 -213
  27. compressed_tensors/quantization/observers/helpers.py +0 -149
  28. compressed_tensors/quantization/observers/min_max.py +0 -104
  29. compressed_tensors/quantization/observers/mse.py +0 -162
  30. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/LICENSE +0 -0
  31. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/top_level.txt +0 -0
@@ -13,14 +13,14 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from copy import deepcopy
16
- from typing import List, Optional
16
+ from typing import Any, Dict, List, Optional
17
17
 
18
18
  from compressed_tensors.quantization.quant_args import (
19
19
  QuantizationArgs,
20
20
  QuantizationStrategy,
21
21
  QuantizationType,
22
22
  )
23
- from pydantic import BaseModel
23
+ from pydantic import BaseModel, model_validator
24
24
 
25
25
 
26
26
  __all__ = [
@@ -36,7 +36,7 @@ class QuantizationScheme(BaseModel):
36
36
  of modules should be quantized
37
37
 
38
38
  :param targets: list of modules to apply the QuantizationArgs to, can be layer
39
- names, layer types or a regular expression
39
+ names, layer types or a regular expression, typically ["Linear"]
40
40
  :param weights: quantization config for layer weights
41
41
  :param input_activations: quantization config for layer inputs
42
42
  :param output_activations: quantization config for layer outputs
@@ -47,27 +47,20 @@ class QuantizationScheme(BaseModel):
47
47
  input_activations: Optional[QuantizationArgs] = None
48
48
  output_activations: Optional[QuantizationArgs] = None
49
49
 
50
- @classmethod
51
- def default_scheme(
52
- cls,
53
- targets: Optional[List[str]] = None,
54
- ):
55
-
56
- if targets is None:
57
- # default to quantizing all Linear layers
58
- targets = ["Linear"]
59
-
60
- # by default, activations and weights are left unquantized
61
- weights = None
62
- input_activations = None
63
- output_activations = None
64
-
65
- return cls(
66
- targets=targets,
67
- weights=weights,
68
- input_activations=input_activations,
69
- output_activations=output_activations,
70
- )
50
+ @model_validator(mode="after")
51
+ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
52
+ inputs = model.input_activations
53
+ outputs = model.output_activations
54
+
55
+ if inputs is not None:
56
+ if inputs.actorder is not None:
57
+ raise ValueError("Cannot apply actorder to input activations")
58
+
59
+ if outputs is not None:
60
+ if outputs.actorder is not None:
61
+ raise ValueError("Cannot apply actorder to output activations")
62
+
63
+ return model
71
64
 
72
65
 
73
66
  """
@@ -16,9 +16,14 @@ import logging
16
16
  from typing import Generator, List, Optional, Tuple
17
17
 
18
18
  import torch
19
- from compressed_tensors.quantization.observers.base import Observer
20
- from compressed_tensors.quantization.quant_args import QuantizationArgs
19
+ from compressed_tensors.quantization.quant_args import (
20
+ FP8_DTYPE,
21
+ QuantizationArgs,
22
+ QuantizationStrategy,
23
+ QuantizationType,
24
+ )
21
25
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
26
+ from torch import FloatTensor, IntTensor, Tensor
22
27
  from torch.nn import Module
23
28
  from tqdm import tqdm
24
29
 
@@ -36,6 +41,9 @@ __all__ = [
36
41
  "is_kv_cache_quant_scheme",
37
42
  "iter_named_leaf_modules",
38
43
  "iter_named_quantizable_modules",
44
+ "compute_dynamic_scales_and_zp",
45
+ "calculate_range",
46
+ "calculate_qparams",
39
47
  ]
40
48
 
41
49
  # target the self_attn layer
@@ -45,6 +53,105 @@ KV_CACHE_TARGETS = ["re:.*self_attn$"]
45
53
  _LOGGER: logging.Logger = logging.getLogger(__name__)
46
54
 
47
55
 
56
+ def calculate_qparams(
57
+ min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
58
+ ) -> Tuple[FloatTensor, IntTensor]:
59
+ """
60
+ :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
61
+ from
62
+ :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
63
+ from
64
+ :param quantization_args: settings to quantization
65
+ :return: tuple of the calculated scale(s) and zero point(s)
66
+ """
67
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
68
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
69
+ device = min_vals.device
70
+
71
+ bit_min, bit_max = calculate_range(quantization_args, device)
72
+ bit_range = bit_max - bit_min
73
+ zp_dtype = quantization_args.pytorch_dtype()
74
+
75
+ if quantization_args.symmetric:
76
+ max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
77
+ scales = max_val_pos / (float(bit_range) / 2)
78
+ scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
79
+ zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
80
+ else:
81
+ scales = (max_vals - min_vals) / float(bit_range)
82
+ scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
83
+ zero_points = bit_min - (min_vals / scales)
84
+ zero_points = torch.clamp(zero_points, bit_min, bit_max)
85
+
86
+ # match zero-points to quantized type
87
+ zero_points = zero_points.to(zp_dtype)
88
+
89
+ if scales.ndim == 0:
90
+ scales = scales.reshape(1)
91
+ zero_points = zero_points.reshape(1)
92
+
93
+ return scales, zero_points
94
+
95
+
96
+ def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
97
+ """
98
+ Returns the computed scales and zero points for dynamic activation
99
+ qunatization.
100
+
101
+ :param value: tensor to calculate quantization parameters for
102
+ :param args: quantization args
103
+ :param reduce_dims: optional tuple of dimensions to reduce along,
104
+ returned scale and zero point will be shaped (1,) along the
105
+ reduced dimensions
106
+ :return: tuple of scale and zero point derived from the observed tensor
107
+ """
108
+ if args.strategy == QuantizationStrategy.TOKEN:
109
+ dim = {1, 2}
110
+ reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
111
+ elif args.strategy == QuantizationStrategy.TENSOR:
112
+ reduce_dims = None
113
+ else:
114
+ raise ValueError(
115
+ f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ",
116
+ "must be used for dynamic quantization",
117
+ )
118
+
119
+ if not reduce_dims:
120
+ min_val, max_val = torch.aminmax(value)
121
+ else:
122
+ min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
123
+ max_val = torch.amax(value, dim=reduce_dims, keepdims=True)
124
+
125
+ return calculate_qparams(min_val, max_val, args)
126
+
127
+
128
+ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
129
+ """
130
+ Calculated the effective quantization range for the given Quantization Args
131
+
132
+ :param quantization_args: quantization args to get range of
133
+ :param device: device to store the range to
134
+ :return: tuple endpoints for the given quantization range
135
+ """
136
+ if quantization_args.type == QuantizationType.INT:
137
+ bit_range = 2**quantization_args.num_bits
138
+ q_max = torch.tensor(bit_range / 2 - 1, device=device)
139
+ q_min = torch.tensor(-bit_range / 2, device=device)
140
+ elif quantization_args.type == QuantizationType.FLOAT:
141
+ if quantization_args.num_bits != 8:
142
+ raise ValueError(
143
+ "Floating point quantization is only supported for 8 bits,"
144
+ f"got {quantization_args.num_bits}"
145
+ )
146
+ fp_range_info = torch.finfo(FP8_DTYPE)
147
+ q_max = torch.tensor(fp_range_info.max, device=device)
148
+ q_min = torch.tensor(fp_range_info.min, device=device)
149
+ else:
150
+ raise ValueError(f"Invalid quantization type {quantization_args.type}")
151
+
152
+ return q_min, q_max
153
+
154
+
48
155
  def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa
49
156
  """
50
157
  Checks the quantization status of a model. Assumes all modules in the model have
@@ -118,12 +225,17 @@ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None
118
225
  """
119
226
  for name, submodule in model.named_modules():
120
227
  children = list(submodule.children())
121
- if len(children) == 0 and not isinstance(submodule, Observer):
228
+ # TODO: verify if an observer would ever be attached in this case/remove check
229
+ if len(children) == 0 and "observer" in name:
122
230
  yield name, submodule
123
231
  else:
232
+ if len(children) > 0:
233
+ named_children, children = zip(*list(submodule.named_children()))
124
234
  has_non_observer_children = False
125
- for child in children:
126
- if not isinstance(child, Observer):
235
+ for i in range(len(children)):
236
+ child_name = named_children[i]
237
+
238
+ if "observer" not in child_name:
127
239
  has_non_observer_children = True
128
240
 
129
241
  if not has_non_observer_children:
@@ -144,14 +256,19 @@ def iter_named_quantizable_modules(
144
256
  :returns: generator tuple of (name, submodule)
145
257
  """
146
258
  for name, submodule in model.named_modules():
259
+ # TODO: verify if an observer would ever be attached in this case/remove check
147
260
  if include_children:
148
261
  children = list(submodule.children())
149
- if len(children) == 0 and not isinstance(submodule, Observer):
262
+ if len(children) == 0 and "observer" not in name:
150
263
  yield name, submodule
151
264
  else:
265
+ if len(children) > 0:
266
+ named_children, children = zip(*list(submodule.named_children()))
152
267
  has_non_observer_children = False
153
- for child in children:
154
- if not isinstance(child, Observer):
268
+ for i in range(len(children)):
269
+ child_name = named_children[i]
270
+
271
+ if "observer" not in child_name:
155
272
  has_non_observer_children = True
156
273
 
157
274
  if not has_non_observer_children:
@@ -258,7 +258,7 @@ def get_from_registry(
258
258
  retrieved_value = _import_and_get_value_from_module(module_path, value_name)
259
259
  else:
260
260
  # look up name in alias registry
261
- name = _ALIAS_REGISTRY[parent_class].get(name)
261
+ name = _ALIAS_REGISTRY[parent_class].get(name, name)
262
262
  # look up name in registry
263
263
  retrieved_value = _REGISTRY[parent_class].get(name)
264
264
  if retrieved_value is None:
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Optional
15
+ from typing import Any, Dict, Optional
16
16
 
17
17
  import torch
18
18
  from transformers import AutoConfig
@@ -24,6 +24,7 @@ __all__ = [
24
24
  "tensor_follows_mask_structure",
25
25
  "replace_module",
26
26
  "is_compressed_tensors_config",
27
+ "Aliasable",
27
28
  ]
28
29
 
29
30
  FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -119,3 +120,34 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
119
120
  return isinstance(compression_config, CompressedTensorsConfig)
120
121
  except ImportError:
121
122
  return False
123
+
124
+
125
+ class Aliasable:
126
+ """
127
+ A mixin for enums to allow aliasing of enum members
128
+
129
+ Example:
130
+ >>> class MyClass(Aliasable, int, Enum):
131
+ >>> ...
132
+ """
133
+
134
+ @staticmethod
135
+ def get_aliases() -> Dict[str, str]:
136
+ raise NotImplementedError()
137
+
138
+ def __eq__(self, other):
139
+ if isinstance(other, self.__class__):
140
+ aliases = self.get_aliases()
141
+ return self.value == other.value or (
142
+ aliases.get(self.value, self.value)
143
+ == aliases.get(other.value, other.value)
144
+ )
145
+ else:
146
+ aliases = self.get_aliases()
147
+ self_value = aliases.get(self.value, self.value)
148
+ other_value = aliases.get(other, other)
149
+ return self_value == other_value
150
+
151
+ def __hash__(self):
152
+ canonical_value = self.aliases.get(self.value, self.value)
153
+ return hash(canonical_value)
@@ -17,7 +17,7 @@ Functionality for storing and setting the version info for SparseML
17
17
  """
18
18
 
19
19
 
20
- version_base = "0.7.1"
20
+ version_base = "0.8.1"
21
21
  is_release = True # change to True to set the generated version as a release version
22
22
 
23
23
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: compressed-tensors
3
- Version: 0.7.1
3
+ Version: 0.8.1
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -1,58 +1,50 @@
1
1
  compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
2
2
  compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
3
- compressed_tensors/version.py,sha256=U13sp7AiFBqeNdF8kzErXdcc0TAgy3S096kUMFPSGV0,1585
3
+ compressed_tensors/version.py,sha256=U6bppqc5inOxvcJDHWhDoSXvBrvbH425oJM2WG7TECY,1585
4
4
  compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
5
5
  compressed_tensors/compressors/base.py,sha256=D9TNwQcjanDiAHODPbg8JUqc66e3j50rctY7A708NEs,6743
6
6
  compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
7
7
  compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
8
- compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=XJgPsq8KiDfiR4e8bSI38lmoOd2ApqRk1aPcXS2obqY,15600
8
+ compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=sxh1TvW1Bp9YJE41hW0XZfd0kYYB85nhJvBLVRTDcV0,15886
9
9
  compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=09UJq68Pht6Bf-4iP9xYl3tetKsncNPHD8IAGbePsr4,714
10
10
  compressed_tensors/compressors/quantized_compressors/base.py,sha256=K1KOnS6Y8nUA1-HN7VhyfsDc01nilW0WfXMUhuD-l8w,5954
11
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=Mmfr-hap-4zw7CzE1mXi0UirknqGidNxw38GGWVgTqM,4916
12
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=9H8UrG5v1GRtslLjOEiUM2dnyxJnR-HJmlsFezQs_r0,7706
11
+ compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=MMUya3Iwarm0BkeYXqKTUnEDPiBw98GKF09QiNST45k,4960
12
+ compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=1CLwvBlu4AtGkuo3IisD1-rQzwLiA6hE1bCc-pF_XGo,7758
13
13
  compressed_tensors/compressors/sparse_compressors/__init__.py,sha256=i2TESH27l7KXeOhJ6hShIoI904XX96l-cRQiMR6MAaU,704
14
14
  compressed_tensors/compressors/sparse_compressors/base.py,sha256=Ua4rUSGyucEs-YJI5z3oIUF-zqQLrFsQ9f-qKasEdUM,4410
15
15
  compressed_tensors/compressors/sparse_compressors/dense.py,sha256=lSKNWRx6H7aUqaJj1j4qbXk8Gkm1UohbnvW1Rvq6Ra4,1284
16
16
  compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=4fKwCG7ZM8mUtSnjPvubzEHl-mTnxMzwjmcs7L43WLY,6622
17
17
  compressed_tensors/compressors/sparse_quantized_compressors/__init__.py,sha256=4f_cwcKXB1nVVMoiKgTFAc8jAPjPLElo-Df_EDm1_xw,675
18
- compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py,sha256=akqE7eW8CLTslpWRxERaZ8R0TSm1lS7D1bgZXKL0xi8,9427
18
+ compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py,sha256=BMIQWTLlnUvxy14iEJegtiP75WHJeOVojey9mKOK1hE,9427
19
19
  compressed_tensors/config/__init__.py,sha256=ZBqWn3r6ku1qfmlHHYp0mQueY0i7Pwhr9rbQk9dDlMc,704
20
- compressed_tensors/config/base.py,sha256=BNTFKy12isY7qblwxdi_R1f00EzgrNOXLrfxqLCPT8w,1903
20
+ compressed_tensors/config/base.py,sha256=3bFAdwDZjOt-U3fneOeL8dRci-PS8DqstnXuQVtkfiQ,3435
21
21
  compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74jNbjks,1317
22
22
  compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
23
23
  compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
24
- compressed_tensors/linear/compressed_linear.py,sha256=0jTTf6XxOAjAYs3tvFtgiNMAO4W10sSeR-pdH2M413g,3218
25
- compressed_tensors/quantization/__init__.py,sha256=nWP_fsl6Nn0ksEgZPzerGiETdvF-ZfNwPnwGlRiR5pY,805
26
- compressed_tensors/quantization/cache.py,sha256=vnBB5zasO_XpHomZvzUPVVbzyCz2VgebsHePm0kANzY,6831
27
- compressed_tensors/quantization/quant_args.py,sha256=k7NuZn8OqjgzmAVaN2-jHPQ1bgDkMuUoLJtLnhkvIOI,9085
28
- compressed_tensors/quantization/quant_config.py,sha256=NCiMvUMnnz5kTyAkDylxjtEGQnjgsIYIeNR2zyHEdTQ,10371
29
- compressed_tensors/quantization/quant_scheme.py,sha256=5ggPz5sqEfTUgvJJeiPIINA74QtO-08hb3szsm7UHGE,6000
30
- compressed_tensors/quantization/lifecycle/__init__.py,sha256=MXE2E7GfIfRRfhrdGy2Og3AZOz5N59B0ZGFcsD89y6c,821
31
- compressed_tensors/quantization/lifecycle/apply.py,sha256=czaayvpeUYyWRJhO_klffw6esptOgA9sBKL5TWQcRdw,15805
32
- compressed_tensors/quantization/lifecycle/calibration.py,sha256=IuLeRkVQPrMxkMcIjr4OMFlIUMHkqjH4qAxC2KiUBGw,2673
24
+ compressed_tensors/linear/compressed_linear.py,sha256=MJa-UfoKhIkdUWRD1shrXXri2cOwR5GK0a4t4bNYosM,3268
25
+ compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
26
+ compressed_tensors/quantization/quant_args.py,sha256=jwC__lSmuiJ2qSJYYZGgWgQNbZu6YhhS0e-qugrTNXE,9058
27
+ compressed_tensors/quantization/quant_config.py,sha256=K6kOZ6LDXpFlqsVzR4NEATV6y6Ea83rJWnNyVlvw-pI,10379
28
+ compressed_tensors/quantization/quant_scheme.py,sha256=eQ0JrRZ80GX69fpwW87VzPzzhajhk4mUaJScjk82OY4,6010
29
+ compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
30
+ compressed_tensors/quantization/lifecycle/apply.py,sha256=jCUSgeOBtagE5IhgIbyYMZ4kv8Rm20VGJ4IxXZ5HAnw,15066
33
31
  compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
34
- compressed_tensors/quantization/lifecycle/forward.py,sha256=qy6_3z5YWDIffiAjQxgmBRggZifA7z93F9vk2GajIIU,15703
35
- compressed_tensors/quantization/lifecycle/frozen.py,sha256=NiJw7NP7pcT6idWFa8vksgiLoT8oQ975e57S4QfD2QQ,1874
32
+ compressed_tensors/quantization/lifecycle/forward.py,sha256=QPL6-vKOFuKdKIEsVqMhsw4x552Jpm2sqO0oeChbnrM,12941
36
33
  compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
37
- compressed_tensors/quantization/lifecycle/initialize.py,sha256=2n309DPxeV_nrM5H_yfQOhF5kteu428qBd4CBzocscw,8908
38
- compressed_tensors/quantization/observers/__init__.py,sha256=DYrttzq-8MHLZUzpX-xzzm4hrw6HcXkMkux82KBKb1M,738
39
- compressed_tensors/quantization/observers/base.py,sha256=5ovQicWPYHjIxr6-EkQ4lgOX0PpI9g23iSzKpxjM1Zg,8420
40
- compressed_tensors/quantization/observers/helpers.py,sha256=o9hg4E9b5cCb5PaEAj6jHiUWkNrKtYtv0b1pGg-T9B4,5516
41
- compressed_tensors/quantization/observers/min_max.py,sha256=sQXqU3z-voxIDfR_9mQzwQUflZj2sASm_G8CYaXntFw,3865
42
- compressed_tensors/quantization/observers/mse.py,sha256=Aeh-253Vbab1F8cYuBiGNn4OXWJ67wXQ_JVfl3mu2a8,6034
34
+ compressed_tensors/quantization/lifecycle/initialize.py,sha256=C41hKA5VANyEwkB5FxzEn3Z0Da5tfxF1I07P8rUcyS0,8537
43
35
  compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
44
- compressed_tensors/quantization/utils/helpers.py,sha256=y4LEyC2oUd876ZMdALWKGH3Ct5EgBJZV4id_NUjTGH8,9531
36
+ compressed_tensors/quantization/utils/helpers.py,sha256=DBP-sGRpGAY01K0LFE7qqonNj4hkTYL_mXrMs2LtAD8,14100
45
37
  compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
46
- compressed_tensors/registry/registry.py,sha256=fxjOjh2wklCvJhQxwofdy-zV8q7MkQ85SLG77nml2iA,11890
38
+ compressed_tensors/registry/registry.py,sha256=vRcjVB1ITfSbfYUaGndBBmqhip_5vsS62weorVg0iXo,11896
47
39
  compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
48
- compressed_tensors/utils/helpers.py,sha256=hWGIR0W7ENHwdC7wW2SQJJiCF9-xOu_u3fY2RzLyYg4,4101
40
+ compressed_tensors/utils/helpers.py,sha256=T3p0TbhWbQIRjL6Up2Z7UhZO5jpR6WxBhYPPvrhE6lE,5018
49
41
  compressed_tensors/utils/offload.py,sha256=d9q8LNe8HyF8tOjgjA7QGLD3HRysmNp0d8eBbdqBgIM,4089
50
42
  compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
51
43
  compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
52
44
  compressed_tensors/utils/safetensors_load.py,sha256=m08ANVuTBxQdoa6LufDgcNJ7wCLDJolyZljB8VEybAU,8578
53
45
  compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
54
- compressed_tensors-0.7.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
55
- compressed_tensors-0.7.1.dist-info/METADATA,sha256=ouRYcF6o8A9ilFaWfE51ApA0Z49_KmvTf-KrfnNTxwI,6782
56
- compressed_tensors-0.7.1.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
57
- compressed_tensors-0.7.1.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
58
- compressed_tensors-0.7.1.dist-info/RECORD,,
46
+ compressed_tensors-0.8.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
47
+ compressed_tensors-0.8.1.dist-info/METADATA,sha256=rDPAoGePUI_yRN7LRP23t3vKWhDfxPbeNR1TX6vpPPI,6782
48
+ compressed_tensors-0.8.1.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
49
+ compressed_tensors-0.8.1.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
50
+ compressed_tensors-0.8.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.44.0)
2
+ Generator: bdist_wheel (0.45.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,201 +0,0 @@
1
- # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing,
10
- # software distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from enum import Enum
17
- from typing import Any, Dict, List, Optional, Tuple
18
-
19
- from compressed_tensors.quantization.observers import Observer
20
- from compressed_tensors.quantization.quant_args import QuantizationArgs
21
- from torch import Tensor
22
- from transformers import DynamicCache as HFDyanmicCache
23
-
24
-
25
- class KVCacheScaleType(Enum):
26
- KEY = "k_scale"
27
- VALUE = "v_scale"
28
-
29
-
30
- class QuantizedKVParameterCache(HFDyanmicCache):
31
-
32
- """
33
- Quantized KV cache used in the forward call based on HF's dynamic cache.
34
- Quantization strategy (tensor, group, channel) set from Quantization arg's strategy
35
- Singleton, so that the same cache gets reused in all forward call of self_attn.
36
- Each time forward is called, .update() is called, and ._quantize(), ._dequantize()
37
- gets called appropriately.
38
- The size of tensor is
39
- `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
40
-
41
-
42
- Triggered by adding kv_cache_scheme in the recipe.
43
-
44
- Example:
45
-
46
- ```python3
47
- recipe = '''
48
- quant_stage:
49
- quant_modifiers:
50
- QuantizationModifier:
51
- kv_cache_scheme:
52
- num_bits: 8
53
- type: float
54
- strategy: tensor
55
- dynamic: false
56
- symmetric: true
57
- '''
58
-
59
- """
60
-
61
- _instance = None
62
- _initialized = False
63
-
64
- def __new__(cls, *args, **kwargs):
65
- """Singleton"""
66
- if cls._instance is None:
67
- cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
68
- return cls._instance
69
-
70
- def __init__(self, quantization_args: QuantizationArgs):
71
- if not self._initialized:
72
- super().__init__()
73
-
74
- self.quantization_args = quantization_args
75
-
76
- self.k_observers: List[Observer] = []
77
- self.v_observers: List[Observer] = []
78
-
79
- # each index corresponds to layer_idx of the attention layer
80
- self.k_scales: List[Tensor] = []
81
- self.v_scales: List[Tensor] = []
82
-
83
- self.k_zps: List[Tensor] = []
84
- self.v_zps: List[Tensor] = []
85
-
86
- self._initialized = True
87
-
88
- def update(
89
- self,
90
- key_states: Tensor,
91
- value_states: Tensor,
92
- layer_idx: int,
93
- cache_kwargs: Optional[Dict[str, Any]] = None,
94
- ) -> Tuple[Tensor, Tensor]:
95
- """
96
- Get the k_scale and v_scale and output the
97
- fakequant-ed key_states and value_states
98
- """
99
-
100
- if len(self.k_observers) <= layer_idx:
101
- k_observer = self.quantization_args.get_observer()
102
- v_observer = self.quantization_args.get_observer()
103
-
104
- self.k_observers.append(k_observer)
105
- self.v_observers.append(v_observer)
106
-
107
- q_key_states = self._quantize(
108
- key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
109
- )
110
- q_value_states = self._quantize(
111
- value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
112
- )
113
-
114
- qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
115
- qdq_value_states = self._dequantize(
116
- q_value_states, KVCacheScaleType.VALUE, layer_idx
117
- )
118
-
119
- keys_to_return, values_to_return = qdq_key_states, qdq_value_states
120
-
121
- return keys_to_return, values_to_return
122
-
123
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
124
- """
125
- Returns the sequence length of the cached states.
126
- A layer index can be optionally passed.
127
- """
128
- if len(self.key_cache) <= layer_idx:
129
- return 0
130
- # since we cannot get the seq_length of each layer directly and
131
- # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
132
- # this is a hack to get the actual seq_length for the given layer_idx
133
- # this part of code otherwise fails when used to
134
- # verify attn_weight shape in some models
135
- return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
136
-
137
- def reset_states(self):
138
- """reset the kv states (used in calibration)"""
139
- self.key_cache: List[Tensor] = []
140
- self.value_cache: List[Tensor] = []
141
- # Used in `generate` to keep tally of how many tokens the cache has seen
142
- self._seen_tokens = 0
143
- self._quantized_key_cache: List[Tensor] = []
144
- self._quantized_value_cache: List[Tensor] = []
145
-
146
- def reset(self):
147
- """
148
- Reset the instantiation, create new instance on init
149
- """
150
- QuantizedKVParameterCache._instance = None
151
- QuantizedKVParameterCache._initialized = False
152
-
153
- def _quantize(self, tensor, kv_type, layer_idx):
154
- """Quantizes a key/value using a defined quantization method."""
155
- from compressed_tensors.quantization.lifecycle.forward import quantize
156
-
157
- if kv_type == KVCacheScaleType.KEY: # key type
158
- observer = self.k_observers[layer_idx]
159
- scales = self.k_scales
160
- zps = self.k_zps
161
- else:
162
- assert kv_type == KVCacheScaleType.VALUE
163
- observer = self.v_observers[layer_idx]
164
- scales = self.v_scales
165
- zps = self.v_zps
166
-
167
- scale, zp = observer(tensor)
168
- if len(scales) <= layer_idx:
169
- scales.append(scale)
170
- zps.append(zp)
171
- else:
172
- scales[layer_idx] = scale
173
- zps[layer_idx] = scale
174
-
175
- q_tensor = quantize(
176
- x=tensor,
177
- scale=scale,
178
- zero_point=zp,
179
- args=self.quantization_args,
180
- )
181
- return q_tensor
182
-
183
- def _dequantize(self, qtensor, kv_type, layer_idx):
184
- """Dequantizes back the tensor that was quantized by `self._quantize()`"""
185
- from compressed_tensors.quantization.lifecycle.forward import dequantize
186
-
187
- if kv_type == KVCacheScaleType.KEY:
188
- scale = self.k_scales[layer_idx]
189
- zp = self.k_zps[layer_idx]
190
- else:
191
- assert kv_type == KVCacheScaleType.VALUE
192
- scale = self.v_scales[layer_idx]
193
- zp = self.v_zps[layer_idx]
194
-
195
- qdq_tensor = dequantize(
196
- x_q=qtensor,
197
- scale=scale,
198
- zero_point=zp,
199
- args=self.quantization_args,
200
- )
201
- return qdq_tensor