compressed-tensors 0.7.1__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
  2. compressed_tensors/config/base.py +60 -2
  3. compressed_tensors/quantization/__init__.py +0 -1
  4. compressed_tensors/quantization/lifecycle/__init__.py +0 -2
  5. compressed_tensors/quantization/lifecycle/apply.py +1 -16
  6. compressed_tensors/quantization/lifecycle/forward.py +24 -87
  7. compressed_tensors/quantization/lifecycle/initialize.py +21 -24
  8. compressed_tensors/quantization/quant_args.py +11 -22
  9. compressed_tensors/quantization/utils/helpers.py +125 -8
  10. compressed_tensors/registry/registry.py +1 -1
  11. compressed_tensors/version.py +1 -1
  12. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.0.dist-info}/METADATA +1 -1
  13. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.0.dist-info}/RECORD +16 -24
  14. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.0.dist-info}/WHEEL +1 -1
  15. compressed_tensors/quantization/cache.py +0 -201
  16. compressed_tensors/quantization/lifecycle/calibration.py +0 -70
  17. compressed_tensors/quantization/lifecycle/frozen.py +0 -55
  18. compressed_tensors/quantization/observers/__init__.py +0 -21
  19. compressed_tensors/quantization/observers/base.py +0 -213
  20. compressed_tensors/quantization/observers/helpers.py +0 -149
  21. compressed_tensors/quantization/observers/min_max.py +0 -104
  22. compressed_tensors/quantization/observers/mse.py +0 -162
  23. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.0.dist-info}/LICENSE +0 -0
  24. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.0.dist-info}/top_level.txt +0 -0
@@ -258,7 +258,7 @@ def get_from_registry(
         retrieved_value = _import_and_get_value_from_module(module_path, value_name)
     else:
         # look up name in alias registry
-        name = _ALIAS_REGISTRY[parent_class].get(name)
+        name = _ALIAS_REGISTRY[parent_class].get(name, name)
         # look up name in registry
         retrieved_value = _REGISTRY[parent_class].get(name)
         if retrieved_value is None:
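The one functional change in registry.py above: `dict.get` now receives the original `name` as its fallback, so a name registered directly (with no alias) resolves instead of being mapped to `None` and failing the follow-up `_REGISTRY` lookup. A minimal sketch of the before/after behavior, using stand-in registry contents rather than the library's real entries:

```python
_ALIAS_REGISTRY = {"Compressor": {"dense_sparsity": "dense"}}
_REGISTRY = {"Compressor": {"dense": "DenseCompressor", "marlin-24": "Marlin24Compressor"}}

name = "marlin-24"  # registered directly; it has no alias entry

# 0.7.1 behavior: .get() with no default maps unaliased names to None,
# so the subsequent _REGISTRY lookup fails even for valid names.
assert _ALIAS_REGISTRY["Compressor"].get(name) is None

# 0.8.0 behavior: the second argument keeps the original name as the fallback.
resolved = _ALIAS_REGISTRY["Compressor"].get(name, name)
assert _REGISTRY["Compressor"][resolved] == "Marlin24Compressor"
```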
@@ -17,7 +17,7 @@ Functionality for storing and setting the version info for SparseML
 """


-version_base = "0.7.1"
+version_base = "0.8.0"
 is_release = True  # change to True to set the generated version as a release version


@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: compressed-tensors
-Version: 0.7.1
+Version: 0.8.0
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
@@ -1,6 +1,6 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=U13sp7AiFBqeNdF8kzErXdcc0TAgy3S096kUMFPSGV0,1585
+compressed_tensors/version.py,sha256=Z9w80ldLHldBZrnrRolznhe-AZsAg5ftvHw17kgPs10,1585
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=D9TNwQcjanDiAHODPbg8JUqc66e3j50rctY7A708NEs,6743
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
@@ -15,35 +15,27 @@ compressed_tensors/compressors/sparse_compressors/base.py,sha256=Ua4rUSGyucEs-YJ
 compressed_tensors/compressors/sparse_compressors/dense.py,sha256=lSKNWRx6H7aUqaJj1j4qbXk8Gkm1UohbnvW1Rvq6Ra4,1284
 compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=4fKwCG7ZM8mUtSnjPvubzEHl-mTnxMzwjmcs7L43WLY,6622
 compressed_tensors/compressors/sparse_quantized_compressors/__init__.py,sha256=4f_cwcKXB1nVVMoiKgTFAc8jAPjPLElo-Df_EDm1_xw,675
-compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py,sha256=akqE7eW8CLTslpWRxERaZ8R0TSm1lS7D1bgZXKL0xi8,9427
+compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py,sha256=BMIQWTLlnUvxy14iEJegtiP75WHJeOVojey9mKOK1hE,9427
 compressed_tensors/config/__init__.py,sha256=ZBqWn3r6ku1qfmlHHYp0mQueY0i7Pwhr9rbQk9dDlMc,704
-compressed_tensors/config/base.py,sha256=BNTFKy12isY7qblwxdi_R1f00EzgrNOXLrfxqLCPT8w,1903
+compressed_tensors/config/base.py,sha256=3bFAdwDZjOt-U3fneOeL8dRci-PS8DqstnXuQVtkfiQ,3435
 compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74jNbjks,1317
 compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
 compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
 compressed_tensors/linear/compressed_linear.py,sha256=0jTTf6XxOAjAYs3tvFtgiNMAO4W10sSeR-pdH2M413g,3218
-compressed_tensors/quantization/__init__.py,sha256=nWP_fsl6Nn0ksEgZPzerGiETdvF-ZfNwPnwGlRiR5pY,805
-compressed_tensors/quantization/cache.py,sha256=vnBB5zasO_XpHomZvzUPVVbzyCz2VgebsHePm0kANzY,6831
-compressed_tensors/quantization/quant_args.py,sha256=k7NuZn8OqjgzmAVaN2-jHPQ1bgDkMuUoLJtLnhkvIOI,9085
+compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
+compressed_tensors/quantization/quant_args.py,sha256=osjNwCSB6tcyH9Qeg5sHEiB-bHyi3XJ8TzkGVJuGTc4,8711
 compressed_tensors/quantization/quant_config.py,sha256=NCiMvUMnnz5kTyAkDylxjtEGQnjgsIYIeNR2zyHEdTQ,10371
 compressed_tensors/quantization/quant_scheme.py,sha256=5ggPz5sqEfTUgvJJeiPIINA74QtO-08hb3szsm7UHGE,6000
-compressed_tensors/quantization/lifecycle/__init__.py,sha256=MXE2E7GfIfRRfhrdGy2Og3AZOz5N59B0ZGFcsD89y6c,821
-compressed_tensors/quantization/lifecycle/apply.py,sha256=czaayvpeUYyWRJhO_klffw6esptOgA9sBKL5TWQcRdw,15805
-compressed_tensors/quantization/lifecycle/calibration.py,sha256=IuLeRkVQPrMxkMcIjr4OMFlIUMHkqjH4qAxC2KiUBGw,2673
+compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
+compressed_tensors/quantization/lifecycle/apply.py,sha256=pdCqxXnVw7HoDDanaOtek13g8x_nb54CBUlfuMdhFG4,14993
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
-compressed_tensors/quantization/lifecycle/forward.py,sha256=qy6_3z5YWDIffiAjQxgmBRggZifA7z93F9vk2GajIIU,15703
-compressed_tensors/quantization/lifecycle/frozen.py,sha256=NiJw7NP7pcT6idWFa8vksgiLoT8oQ975e57S4QfD2QQ,1874
+compressed_tensors/quantization/lifecycle/forward.py,sha256=QPL6-vKOFuKdKIEsVqMhsw4x552Jpm2sqO0oeChbnrM,12941
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
-compressed_tensors/quantization/lifecycle/initialize.py,sha256=2n309DPxeV_nrM5H_yfQOhF5kteu428qBd4CBzocscw,8908
-compressed_tensors/quantization/observers/__init__.py,sha256=DYrttzq-8MHLZUzpX-xzzm4hrw6HcXkMkux82KBKb1M,738
-compressed_tensors/quantization/observers/base.py,sha256=5ovQicWPYHjIxr6-EkQ4lgOX0PpI9g23iSzKpxjM1Zg,8420
-compressed_tensors/quantization/observers/helpers.py,sha256=o9hg4E9b5cCb5PaEAj6jHiUWkNrKtYtv0b1pGg-T9B4,5516
-compressed_tensors/quantization/observers/min_max.py,sha256=sQXqU3z-voxIDfR_9mQzwQUflZj2sASm_G8CYaXntFw,3865
-compressed_tensors/quantization/observers/mse.py,sha256=Aeh-253Vbab1F8cYuBiGNn4OXWJ67wXQ_JVfl3mu2a8,6034
+compressed_tensors/quantization/lifecycle/initialize.py,sha256=C41hKA5VANyEwkB5FxzEn3Z0Da5tfxF1I07P8rUcyS0,8537
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
-compressed_tensors/quantization/utils/helpers.py,sha256=y4LEyC2oUd876ZMdALWKGH3Ct5EgBJZV4id_NUjTGH8,9531
+compressed_tensors/quantization/utils/helpers.py,sha256=DBP-sGRpGAY01K0LFE7qqonNj4hkTYL_mXrMs2LtAD8,14100
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
-compressed_tensors/registry/registry.py,sha256=fxjOjh2wklCvJhQxwofdy-zV8q7MkQ85SLG77nml2iA,11890
+compressed_tensors/registry/registry.py,sha256=vRcjVB1ITfSbfYUaGndBBmqhip_5vsS62weorVg0iXo,11896
 compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
 compressed_tensors/utils/helpers.py,sha256=hWGIR0W7ENHwdC7wW2SQJJiCF9-xOu_u3fY2RzLyYg4,4101
 compressed_tensors/utils/offload.py,sha256=d9q8LNe8HyF8tOjgjA7QGLD3HRysmNp0d8eBbdqBgIM,4089
@@ -51,8 +43,8 @@ compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVy
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=m08ANVuTBxQdoa6LufDgcNJ7wCLDJolyZljB8VEybAU,8578
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.7.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors-0.7.1.dist-info/METADATA,sha256=ouRYcF6o8A9ilFaWfE51ApA0Z49_KmvTf-KrfnNTxwI,6782
-compressed_tensors-0.7.1.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-compressed_tensors-0.7.1.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors-0.7.1.dist-info/RECORD,,
+compressed_tensors-0.8.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.8.0.dist-info/METADATA,sha256=lRjH5wempREQ2lTFNqzMusIW95YHN4rF8yd73MVvOe0,6782
+compressed_tensors-0.8.0.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+compressed_tensors-0.8.0.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.8.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: bdist_wheel (0.44.0)
+Generator: bdist_wheel (0.45.0)
 Root-Is-Purelib: true
 Tag: py3-none-any

@@ -1,201 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
-
-from compressed_tensors.quantization.observers import Observer
-from compressed_tensors.quantization.quant_args import QuantizationArgs
-from torch import Tensor
-from transformers import DynamicCache as HFDyanmicCache
-
-
-class KVCacheScaleType(Enum):
-    KEY = "k_scale"
-    VALUE = "v_scale"
-
-
-class QuantizedKVParameterCache(HFDyanmicCache):
-
-    """
-    Quantized KV cache used in the forward call based on HF's dynamic cache.
-    Quantization strategy (tensor, group, channel) set from Quantization arg's strategy
-    Singleton, so that the same cache gets reused in all forward call of self_attn.
-    Each time forward is called, .update() is called, and ._quantize(), ._dequantize()
-    gets called appropriately.
-    The size of tensor is
-    `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
-
-
-    Triggered by adding kv_cache_scheme in the recipe.
-
-    Example:
-
-    ```python3
-    recipe = '''
-    quant_stage:
-        quant_modifiers:
-            QuantizationModifier:
-                kv_cache_scheme:
-                    num_bits: 8
-                    type: float
-                    strategy: tensor
-                    dynamic: false
-                    symmetric: true
-    '''
-
-    """
-
-    _instance = None
-    _initialized = False
-
-    def __new__(cls, *args, **kwargs):
-        """Singleton"""
-        if cls._instance is None:
-            cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
-        return cls._instance
-
-    def __init__(self, quantization_args: QuantizationArgs):
-        if not self._initialized:
-            super().__init__()
-
-            self.quantization_args = quantization_args
-
-            self.k_observers: List[Observer] = []
-            self.v_observers: List[Observer] = []
-
-            # each index corresponds to layer_idx of the attention layer
-            self.k_scales: List[Tensor] = []
-            self.v_scales: List[Tensor] = []
-
-            self.k_zps: List[Tensor] = []
-            self.v_zps: List[Tensor] = []
-
-            self._initialized = True
-
-    def update(
-        self,
-        key_states: Tensor,
-        value_states: Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[Tensor, Tensor]:
-        """
-        Get the k_scale and v_scale and output the
-        fakequant-ed key_states and value_states
-        """
-
-        if len(self.k_observers) <= layer_idx:
-            k_observer = self.quantization_args.get_observer()
-            v_observer = self.quantization_args.get_observer()
-
-            self.k_observers.append(k_observer)
-            self.v_observers.append(v_observer)
-
-        q_key_states = self._quantize(
-            key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
-        )
-        q_value_states = self._quantize(
-            value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
-        )
-
-        qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
-        qdq_value_states = self._dequantize(
-            q_value_states, KVCacheScaleType.VALUE, layer_idx
-        )
-
-        keys_to_return, values_to_return = qdq_key_states, qdq_value_states
-
-        return keys_to_return, values_to_return
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """
-        Returns the sequence length of the cached states.
-        A layer index can be optionally passed.
-        """
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        # since we cannot get the seq_length of each layer directly and
-        # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
-        # this is a hack to get the actual seq_length for the given layer_idx
-        # this part of code otherwise fails when used to
-        # verify attn_weight shape in some models
-        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
-
-    def reset_states(self):
-        """reset the kv states (used in calibration)"""
-        self.key_cache: List[Tensor] = []
-        self.value_cache: List[Tensor] = []
-        # Used in `generate` to keep tally of how many tokens the cache has seen
-        self._seen_tokens = 0
-        self._quantized_key_cache: List[Tensor] = []
-        self._quantized_value_cache: List[Tensor] = []
-
-    def reset(self):
-        """
-        Reset the instantiation, create new instance on init
-        """
-        QuantizedKVParameterCache._instance = None
-        QuantizedKVParameterCache._initialized = False
-
-    def _quantize(self, tensor, kv_type, layer_idx):
-        """Quantizes a key/value using a defined quantization method."""
-        from compressed_tensors.quantization.lifecycle.forward import quantize
-
-        if kv_type == KVCacheScaleType.KEY:  # key type
-            observer = self.k_observers[layer_idx]
-            scales = self.k_scales
-            zps = self.k_zps
-        else:
-            assert kv_type == KVCacheScaleType.VALUE
-            observer = self.v_observers[layer_idx]
-            scales = self.v_scales
-            zps = self.v_zps
-
-        scale, zp = observer(tensor)
-        if len(scales) <= layer_idx:
-            scales.append(scale)
-            zps.append(zp)
-        else:
-            scales[layer_idx] = scale
-            zps[layer_idx] = zp
-
-        q_tensor = quantize(
-            x=tensor,
-            scale=scale,
-            zero_point=zp,
-            args=self.quantization_args,
-        )
-        return q_tensor
-
-    def _dequantize(self, qtensor, kv_type, layer_idx):
-        """Dequantizes back the tensor that was quantized by `self._quantize()`"""
-        from compressed_tensors.quantization.lifecycle.forward import dequantize
-
-        if kv_type == KVCacheScaleType.KEY:
-            scale = self.k_scales[layer_idx]
-            zp = self.k_zps[layer_idx]
-        else:
-            assert kv_type == KVCacheScaleType.VALUE
-            scale = self.v_scales[layer_idx]
-            zp = self.v_zps[layer_idx]
-
-        qdq_tensor = dequantize(
-            x_q=qtensor,
-            scale=scale,
-            zero_point=zp,
-            args=self.quantization_args,
-        )
-        return qdq_tensor
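As the docstring above notes, `update()` returns fakequant-ed key/value states: each tensor is quantized and immediately dequantized, so attention still runs in the original dtype but sees the error the quantized cache would introduce. A toy int8 round trip illustrating the idea with one scale per tensor, as in the `strategy: tensor` recipe; these helpers are simplified stand-ins, not the library's `quantize`/`dequantize`:

```python
import torch


def quantize(x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    # Affine quantization onto the int8 grid (values kept in float for brevity).
    return torch.clamp(torch.round(x / scale) + zero_point, -128, 127)


def dequantize(x_q: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    return (x_q - zero_point) * scale


key_states = torch.randn(1, 8, 32, 64)    # [batch, heads, seq, head_dim]
scale = key_states.abs().amax() / 127.0   # one scale for the whole tensor
zero_point = torch.zeros(())              # symmetric: true

qdq_keys = dequantize(quantize(key_states, scale, zero_point), scale, zero_point)
assert torch.allclose(qdq_keys, key_states, atol=scale.item())  # error within one step
```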
@@ -1,70 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-
-from compressed_tensors.quantization.quant_config import QuantizationStatus
-from compressed_tensors.utils import is_module_offloaded, update_parameter_data
-from torch.nn import Module
-
-
-__all__ = [
-    "set_module_for_calibration",
-]
-
-
-_LOGGER = logging.getLogger(__name__)
-
-
-def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = True):
-    """
-    marks a layer as ready for calibration which activates observers
-    to update scales and zero points on each forward pass
-
-    apply to full model with `model.apply(set_module_for_calibration)`
-
-    :param module: module to set for calibration
-    :param quantize_weights_upfront: whether to automatically
-        run weight quantization at the start of calibration
-    """
-    if not getattr(module, "quantization_scheme", None):
-        # no quantization scheme nothing to do
-        return
-    status = getattr(module, "quantization_status", None)
-    if not status or status != QuantizationStatus.INITIALIZED:
-        _LOGGER.warning(
-            f"Attempting set module with status {status} to calibration mode. "
-            f"but status is not {QuantizationStatus.INITIALIZED} - you may "
-            "be calibrating an uninitialized module which may fail or attempting "
-            "to re-calibrate a frozen module"
-        )
-
-    if quantize_weights_upfront and module.quantization_scheme.weights is not None:
-        # set weight scale and zero_point up front, calibration data doesn't affect it
-        observer = module.weight_observer
-        g_idx = getattr(module, "weight_g_idx", None)
-
-        offloaded = is_module_offloaded(module)
-        if offloaded:
-            module._hf_hook.pre_forward(module)
-
-        scale, zero_point = observer(module.weight, g_idx=g_idx)
-        update_parameter_data(module, scale, "weight_scale")
-        update_parameter_data(module, zero_point, "weight_zero_point")
-
-        if offloaded:
-            module._hf_hook.post_forward(module, None)
-
-    module.quantization_status = QuantizationStatus.CALIBRATION
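The `model.apply(set_module_for_calibration)` usage in the docstring leans on `torch.nn.Module.apply`, which invokes a callback on the root module and every submodule. A self-contained sketch of that mechanism; the `mark` callback is a hypothetical stand-in for the removed helper:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))


def mark(module: nn.Module) -> None:
    # Module.apply calls this once per submodule (and once for the root),
    # which is how the lifecycle hooks reached every quantized layer.
    module.calibration_marked = True


model.apply(mark)
assert all(getattr(m, "calibration_marked", False) for m in model.modules())
```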
@@ -1,55 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from compressed_tensors.quantization.quant_config import QuantizationStatus
-from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from torch.nn import Module
-
-
-__all__ = [
-    "freeze_module_quantization",
-]
-
-
-def freeze_module_quantization(module: Module):
-    """
-    deletes observers so static quantization is completed.
-
-    apply to full model with `model.apply(freeze_module_quantization)`
-
-    :param module: module to freeze quantization for
-    """
-    scheme = getattr(module, "quantization_scheme", None)
-    if not scheme:
-        # no quantization scheme nothing to do
-        return
-
-    if module.quantization_status == QuantizationStatus.FROZEN:
-        # nothing to do, already frozen
-        return
-
-    # delete observers from module if not dynamic
-    if scheme.input_activations and not scheme.input_activations.dynamic:
-        delattr(module, "input_observer")
-    if scheme.weights and not scheme.weights.dynamic:
-        delattr(module, "weight_observer")
-    if (
-        scheme.output_activations
-        and not is_kv_cache_quant_scheme(scheme)
-        and not scheme.output_activations.dynamic
-    ):
-        delattr(module, "output_observer")
-
-    module.quantization_status = QuantizationStatus.FROZEN
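The observer cleanup above is a conditional-`delattr` idiom: an observer is removed only when its scheme is static (`dynamic` is false), since dynamic schemes compute their parameters on the fly. A standalone sketch of just that guard, with hypothetical stand-in objects:

```python
from types import SimpleNamespace

module = SimpleNamespace(input_observer="obs", weight_observer="obs")
scheme = SimpleNamespace(
    input_activations=SimpleNamespace(dynamic=False),  # static: observer is now redundant
    weights=SimpleNamespace(dynamic=True),             # dynamic: the guard skips the delete
)

if scheme.input_activations and not scheme.input_activations.dynamic:
    delattr(module, "input_observer")
if scheme.weights and not scheme.weights.dynamic:
    delattr(module, "weight_observer")

assert not hasattr(module, "input_observer")
assert hasattr(module, "weight_observer")
```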
@@ -1,21 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# flake8: noqa
-# isort: skip_file
-
-from .helpers import *
-from .base import *
-from .min_max import *
-from .mse import *
@@ -1,213 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from math import ceil
-from typing import Any, Iterable, Optional, Tuple, Union
-
-import torch
-from compressed_tensors.quantization.quant_args import (
-    QuantizationArgs,
-    QuantizationStrategy,
-)
-from compressed_tensors.registry.registry import RegistryMixin
-from compressed_tensors.utils import safe_permute
-from torch import FloatTensor, IntTensor, Tensor
-from torch.nn import Module
-
-
-_LOGGER = logging.getLogger(__name__)
-
-
-__all__ = ["Observer"]
-
-
-class Observer(Module, RegistryMixin):
-    """
-    Base Observer class to be subclassed for specific implementation.
-    Subclasses should override `calculate_qparams` to return a scale, zero_point
-    pair
-    """
-
-    def __init__(self, quantization_args: QuantizationArgs):
-        self.quantization_args: QuantizationArgs = quantization_args
-        super().__init__()
-        self._scale = None
-        self._zero_point = None
-        self._num_observed_tokens = None
-
-    @torch.no_grad()
-    def forward(
-        self, observed: Tensor, g_idx: Optional[Tensor] = None
-    ) -> Tuple[FloatTensor, IntTensor]:
-        """
-        maps directly to get_qparams
-        :param observed: optional observed tensor from which to calculate
-            quantization parameters
-        :param g_idx: optional mapping from column index to group index
-        :return: tuple of scale and zero point based on last observed value
-        """
-        self.record_observed_tokens(observed)
-        return self.get_qparams(observed=observed, g_idx=g_idx)
-
-    def calculate_qparams(
-        self,
-        observed: Tensor,
-        reduce_dims: Optional[Tuple[int]] = None,
-    ) -> Tuple[FloatTensor, IntTensor]:
-        """
-        :param observed: observed tensor to calculate quantization parameters for
-        :param reduce_dims: optional tuple of dimensions to reduce along,
-            returned scale and zero point will be shaped (1,) along the
-            reduced dimensions
-        :return: tuple of scale and zero point derived from the observed tensor
-        """
-        raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
-
-    def post_calculate_qparams(self) -> None:
-        """
-        Run any logic specific to its observers after running calculate_qparams
-        """
-        ...
-
-    def get_qparams(
-        self,
-        observed: Optional[Tensor] = None,
-        g_idx: Optional[Tensor] = None,
-    ) -> Tuple[FloatTensor, IntTensor]:
-        """
-        Convenience function to wrap overwritten calculate_qparams
-        adds support to make observed tensor optional and support for tracking latest
-        calculated scale and zero point
-
-        :param observed: optional observed tensor to calculate quantization parameters
-            from
-        :param g_idx: optional mapping from column index to group index
-        :return: tuple of scale and zero point based on last observed value
-        """
-        if observed is not None:
-            group_size = self.quantization_args.group_size
-
-            if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
-
-                # re-calculate scale and zero point, update the stored value
-                self._scale, self._zero_point = self.calculate_qparams(observed)
-
-            elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
-                rows = observed.shape[0]
-                columns = observed.shape[1]
-                num_groups = int(ceil(columns / group_size))
-                self._scale = torch.empty(
-                    (rows, num_groups), dtype=observed.dtype, device=observed.device
-                )
-                zp_dtype = self.quantization_args.pytorch_dtype()
-                self._zero_point = torch.empty(
-                    (rows, num_groups), dtype=zp_dtype, device=observed.device
-                )
-
-                # support column-order (default) quantization as well as other orderings
-                # such as activation ordering. Below checks if g_idx has initialized
-                is_column_order = g_idx is None or -1 in g_idx
-                if is_column_order:
-                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
-                else:
-                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
-                    group_sizes = group_sizes[torch.argsort(group_indices)]
-
-                    perm = torch.argsort(g_idx)
-                    observed = safe_permute(observed, perm, dim=1)
-
-                # TODO: experiment with vectorizing for loop for performance
-                end = 0
-                for group_index, group_count in enumerate(group_sizes):
-                    start = end
-                    end = start + group_count
-                    scale, zero_point = self.get_qparams_along_dim(
-                        observed[:, start:end],
-                        0,
-                        tensor_id=group_index,
-                    )
-
-                    self._scale[:, group_index] = scale.squeeze(1)
-                    self._zero_point[:, group_index] = zero_point.squeeze(1)
-
-            elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                # assume observed is transposed, because its the output, hence use dim 0
-                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
-
-            elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
-                # use dim 1, assume the obsersed.shape = [batch, token, hidden]
-                # should be batch, token
-                self._scale, self._zero_point = self.get_qparams_along_dim(
-                    observed,
-                    dim={0, 1},
-                )
-
-        return self._scale, self._zero_point
-
-    def get_qparams_along_dim(
-        self,
-        observed,
-        dim: Union[int, Iterable[int]],
-        tensor_id: Optional[Any] = None,
-    ):
-        if isinstance(dim, int):
-            dim = [dim]
-        dim = set(dim)
-
-        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
-        return self.calculate_qparams(
-            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
-        )
-
-    def record_observed_tokens(self, batch_tensor: Tensor):
-        """
-        Counts the number of tokens observed during the
-        forward passes. The count is aggregated in the
-        _num_observed_tokens attribute of the class.
-
-        Note: The batch_tensor is expected to have two dimensions
-        (batch_size * sequence_length, num_features). This is the
-        general shape expected by the forward pass of the expert
-        layers in a MOE model. If the input tensor does not have
-        two dimensions, the _num_observed_tokens attribute will be set
-        to None.
-        """
-        if not isinstance(batch_tensor, Tensor):
-            raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")
-
-        if batch_tensor.ndim != 2:
-            _LOGGER.debug(
-                "The input tensor is expected to have two dimensions "
-                "(batch_size * sequence_length, num_features). "
-                f"The input tensor has {batch_tensor.ndim} dimensions."
-            )
-            return
-
-        if self._num_observed_tokens is None:
-            # initialize the count
-            self._num_observed_tokens = 0
-
-        # batch_tensor (batch_size * sequence_length, num_features)
-        # observed_tokens (batch_size * sequence_length)
-        observed_tokens, _ = batch_tensor.shape
-        self._num_observed_tokens += observed_tokens
-
-    def reset(self):
-        """
-        Reset the state of the observer
-        """
-        self._num_observed_tokens = None
-        self._scale = None
-        self._zero_point = None
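`get_qparams_along_dim` above turns a "keep these dims" request into a "reduce over the rest" tuple that `calculate_qparams` consumes. A standalone sketch of that computation, paired with a toy min-max-style `calculate_qparams`; both functions are illustrative, not the deleted `MinMaxObserver`:

```python
from typing import Iterable, Optional, Tuple, Union

import torch
from torch import Tensor


def reduce_dims_for(observed: Tensor, dim: Union[int, Iterable[int]]) -> Tuple[int, ...]:
    # Keep the requested dims; reduce over everything else,
    # mirroring Observer.get_qparams_along_dim.
    keep = {dim} if isinstance(dim, int) else set(dim)
    return tuple(i for i in range(observed.ndim) if i not in keep)


def minmax_qparams(observed: Tensor, reduce_dims: Optional[Tuple[int, ...]] = None):
    # Toy symmetric int8 qparams: scale from the per-slice absmax, zero point 0.
    if reduce_dims is None:
        absmax = observed.abs().amax()
    else:
        absmax = observed.abs().amax(dim=reduce_dims, keepdim=True)
    scale = absmax / 127.0
    return scale, torch.zeros_like(scale, dtype=torch.int8)


x = torch.randn(8, 16)        # e.g. [tokens, hidden]
dims = reduce_dims_for(x, 0)  # keep dim 0 -> reduce over dim 1, i.e. per-token params
scale, zero_point = minmax_qparams(x, dims)
assert scale.shape == (8, 1)
```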