compressed-tensors-nightly 0.3.3.20240519__py3-none-any.whl → 0.3.3.20240521__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/quantization/lifecycle/forward.py +1 -24
- compressed_tensors/quantization/observers/base.py +11 -20
- compressed_tensors/quantization/observers/memoryless.py +15 -9
- compressed_tensors/quantization/observers/min_max.py +0 -4
- {compressed_tensors_nightly-0.3.3.20240519.dist-info → compressed_tensors_nightly-0.3.3.20240521.dist-info}/METADATA +1 -1
- {compressed_tensors_nightly-0.3.3.20240519.dist-info → compressed_tensors_nightly-0.3.3.20240521.dist-info}/RECORD +9 -9
- {compressed_tensors_nightly-0.3.3.20240519.dist-info → compressed_tensors_nightly-0.3.3.20240521.dist-info}/LICENSE +0 -0
- {compressed_tensors_nightly-0.3.3.20240519.dist-info → compressed_tensors_nightly-0.3.3.20240521.dist-info}/WHEEL +0 -0
- {compressed_tensors_nightly-0.3.3.20240519.dist-info → compressed_tensors_nightly-0.3.3.20240521.dist-info}/top_level.txt +0 -0
@@ -150,7 +150,6 @@ def _process_quantization(
|
|
150
150
|
q_min = torch.tensor(-bit_range / 2, device=x.device)
|
151
151
|
group_size = args.group_size
|
152
152
|
|
153
|
-
# group
|
154
153
|
if args.strategy == QuantizationStrategy.GROUP:
|
155
154
|
|
156
155
|
if do_dequantize: # if dequantizing the output should be a fp type
|
@@ -195,29 +194,7 @@ def _process_quantization(
|
|
195
194
|
)
|
196
195
|
output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
|
197
196
|
|
198
|
-
# channel
|
199
|
-
elif args.strategy == QuantizationStrategy.CHANNEL: # group_size == -1
|
200
|
-
if do_quantize:
|
201
|
-
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
|
202
|
-
if do_dequantize:
|
203
|
-
output = _dequantize(output if do_quantize else x, scale, zero_point)
|
204
|
-
|
205
|
-
# per-token
|
206
|
-
elif args.strategy == QuantizationStrategy.TOKEN:
|
207
|
-
# before: scale shape = [num_tokens]
|
208
|
-
# after: scale shape = [num_tokens, 1]
|
209
|
-
# x.shape = 1, num_tokens, 1]
|
210
|
-
# scale gets broadcasted as expected withput having [1, num_tokens, 1] shape
|
211
|
-
|
212
|
-
scale = scale.unsqueeze(1)
|
213
|
-
zero_point = zero_point.unsqueeze(1)
|
214
|
-
|
215
|
-
if do_quantize:
|
216
|
-
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
|
217
|
-
if do_dequantize:
|
218
|
-
output = _dequantize(output if do_quantize else x, scale, zero_point)
|
219
|
-
|
220
|
-
else:
|
197
|
+
else: # covers channel, token and tensor strategies
|
221
198
|
if do_quantize:
|
222
199
|
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
|
223
200
|
if do_dequantize:
|
@@ -50,9 +50,16 @@ class Observer(Module, RegistryMixin):
|
|
50
50
|
"""
|
51
51
|
return self.get_qparams(observed=observed)
|
52
52
|
|
53
|
-
def calculate_qparams(
|
53
|
+
def calculate_qparams(
|
54
|
+
self,
|
55
|
+
observed: Tensor,
|
56
|
+
reduce_dims: Optional[Tuple[int]] = None,
|
57
|
+
) -> Tuple[FloatTensor, IntTensor]:
|
54
58
|
"""
|
55
59
|
:param observed: observed tensor to calculate quantization parameters for
|
60
|
+
:param reduce_dims: optional tuple of dimensions to reduce along,
|
61
|
+
returned scale and zero point will be shaped (1,) along the
|
62
|
+
reduced dimensions
|
56
63
|
:return: tuple of scale and zero point derived from the observed tensor
|
57
64
|
"""
|
58
65
|
raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
|
@@ -70,6 +77,7 @@ class Observer(Module, RegistryMixin):
|
|
70
77
|
Convenience function to wrap overwritten calculate_qparams
|
71
78
|
adds support to make observed tensor optional and support for tracking latest
|
72
79
|
calculated scale and zero point
|
80
|
+
|
73
81
|
:param observed: optional observed tensor to calculate quantization parameters
|
74
82
|
from
|
75
83
|
:return: tuple of scale and zero point based on last observed value
|
@@ -100,10 +108,8 @@ class Observer(Module, RegistryMixin):
|
|
100
108
|
self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
|
101
109
|
|
102
110
|
elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
|
103
|
-
|
104
111
|
# use dim 1, assume the obsersed.shape = [batch, token, hidden]
|
105
112
|
# should be batch, token
|
106
|
-
|
107
113
|
self._scale, self._zero_point = self.get_qparams_along_dim(
|
108
114
|
observed, dim=1
|
109
115
|
)
|
@@ -111,20 +117,5 @@ class Observer(Module, RegistryMixin):
|
|
111
117
|
return self._scale, self._zero_point
|
112
118
|
|
113
119
|
def get_qparams_along_dim(self, observed, dim: int):
|
114
|
-
|
115
|
-
|
116
|
-
# TODO: generalize the logic for reduce_dims
|
117
|
-
scales, zero_points = [], []
|
118
|
-
|
119
|
-
# TODO: make a more generic way to get the channel
|
120
|
-
num_dims = observed.shape[dim]
|
121
|
-
|
122
|
-
for dim_idx in range(num_dims):
|
123
|
-
scale, zero_point = self.calculate_qparams(
|
124
|
-
observed.select(dim=dim, index=dim_idx)
|
125
|
-
)
|
126
|
-
|
127
|
-
scales.append(scale)
|
128
|
-
zero_points.append(zero_point)
|
129
|
-
# breakpoint()
|
130
|
-
return torch.stack(scales), torch.stack(zero_points)
|
120
|
+
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
|
121
|
+
return self.calculate_qparams(observed, reduce_dims=reduce_dims)
|
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import Tuple
|
15
|
+
from typing import Optional, Tuple
|
16
16
|
|
17
17
|
import torch
|
18
18
|
from compressed_tensors.quantization.observers.base import Observer
|
@@ -30,19 +30,25 @@ class MemorylessObserver(Observer):
|
|
30
30
|
zero point based on the latest observed value without tracking state
|
31
31
|
"""
|
32
32
|
|
33
|
-
def calculate_qparams(
|
33
|
+
def calculate_qparams(
|
34
|
+
self,
|
35
|
+
observed: Tensor,
|
36
|
+
reduce_dims: Optional[Tuple[int]] = None,
|
37
|
+
) -> Tuple[FloatTensor, IntTensor]:
|
34
38
|
"""
|
35
|
-
Returns the min and max values of observed
|
39
|
+
Returns the min and max values of observed tensor
|
36
40
|
|
37
41
|
:param observed: observed tensor to calculate quantization parameters for
|
42
|
+
:param reduce_dims: optional tuple of dimensions to reduce along,
|
43
|
+
returned scale and zero point will be shaped (1,) along the
|
44
|
+
reduced dimensions
|
38
45
|
:return: tuple of scale and zero point derived from the observed tensor
|
39
46
|
"""
|
40
|
-
# TODO: Add support for full range of quantization Args, only supports 8bit
|
41
|
-
# per tensor
|
42
|
-
min_val, max_val = torch.aminmax(observed)
|
43
47
|
|
44
|
-
|
45
|
-
|
46
|
-
|
48
|
+
if not reduce_dims:
|
49
|
+
min_val, max_val = torch.aminmax(observed)
|
50
|
+
else:
|
51
|
+
min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
|
52
|
+
max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
|
47
53
|
|
48
54
|
return calculate_qparams(min_val, max_val, self.quantization_args)
|
@@ -74,7 +74,3 @@ class MovingAverageMinMaxObserver(Observer):
|
|
74
74
|
)
|
75
75
|
|
76
76
|
return calculate_qparams(self.min_val, self.max_val, self.quantization_args)
|
77
|
-
|
78
|
-
def get_qparams_along_dim(self, observed, dim: int):
|
79
|
-
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
|
80
|
-
return self.calculate_qparams(observed, reduce_dims=reduce_dims)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: compressed-tensors-nightly
|
3
|
-
Version: 0.3.3.
|
3
|
+
Version: 0.3.3.20240521
|
4
4
|
Summary: Library for utilization of compressed safetensors of neural network models
|
5
5
|
Home-page: https://github.com/neuralmagic/compressed-tensors
|
6
6
|
Author: Neuralmagic, Inc.
|
@@ -21,14 +21,14 @@ compressed_tensors/quantization/lifecycle/__init__.py,sha256=ggRGWRqhCxCaTTDWRcg
|
|
21
21
|
compressed_tensors/quantization/lifecycle/apply.py,sha256=whKfNGC_EZm0BC23AP7qWfjRe5OJVWmcZOpX7lryZZc,7625
|
22
22
|
compressed_tensors/quantization/lifecycle/calibration.py,sha256=mLns4jlaWmBwOW8Jtlm5bMX-JET1AiZYUBO7qa-XuxI,1776
|
23
23
|
compressed_tensors/quantization/lifecycle/compressed.py,sha256=VreB10xPwgSLQQlTu20UCrFpRS--cA7-lx5s7nrPPrg,2247
|
24
|
-
compressed_tensors/quantization/lifecycle/forward.py,sha256=
|
24
|
+
compressed_tensors/quantization/lifecycle/forward.py,sha256=x9JaIX3TK7cb_-0aCOTTYtA4At9l6v5YOY_70GzIeFU,10520
|
25
25
|
compressed_tensors/quantization/lifecycle/frozen.py,sha256=h1XYt89MouBTf3jTYLG_6OdFxIu5q2N8tPjsy6J4E6Y,1726
|
26
26
|
compressed_tensors/quantization/lifecycle/initialize.py,sha256=U6g9qifSF6pagQZQZEwd-rwWC6uQ_dZXn1wg6nr1Abg,3697
|
27
27
|
compressed_tensors/quantization/observers/__init__.py,sha256=DNH31NQYrIBBcmHsMyFA6whh4pbRsLwuNa6L8AeXaGc,745
|
28
|
-
compressed_tensors/quantization/observers/base.py,sha256=
|
28
|
+
compressed_tensors/quantization/observers/base.py,sha256=yIV2bd9PKPZwodgiBTZEco2ARbD3B0rOKDC0MOFluZs,4900
|
29
29
|
compressed_tensors/quantization/observers/helpers.py,sha256=JwALNfBYY9Eyl8Q180t0lGh8szumQj8TygfNl-isErs,2166
|
30
|
-
compressed_tensors/quantization/observers/memoryless.py,sha256=
|
31
|
-
compressed_tensors/quantization/observers/min_max.py,sha256=
|
30
|
+
compressed_tensors/quantization/observers/memoryless.py,sha256=Gach22cZLhDms6ueKF56XOiLhyWVIEYIEXRRXP5Nu8I,2045
|
31
|
+
compressed_tensors/quantization/observers/min_max.py,sha256=OGrtyn6_sWuTSx5QgUPVKRIiarfWrK9QqXeRXoJQynw,2861
|
32
32
|
compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
|
33
33
|
compressed_tensors/quantization/utils/helpers.py,sha256=NzAH18Cn_-mTAR87y6IlcQU5gC393XSjgNKC9CRkr78,6017
|
34
34
|
compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
|
@@ -36,8 +36,8 @@ compressed_tensors/registry/registry.py,sha256=fxjOjh2wklCvJhQxwofdy-zV8q7MkQ85S
|
|
36
36
|
compressed_tensors/utils/__init__.py,sha256=5DrYjoZbaEvSkJcC-GRSbM_RBHVF4tG9gMd3zsJnjLw,665
|
37
37
|
compressed_tensors/utils/helpers.py,sha256=h0jfl9drs5FAx40tCHRcVtJqXixB5hT5yq_IG2aY_-w,1735
|
38
38
|
compressed_tensors/utils/safetensors_load.py,sha256=wo9UirGrGlenBqZeqotvpCT7D5MEdjCo2J3HeRaIFoU,8502
|
39
|
-
compressed_tensors_nightly-0.3.3.
|
40
|
-
compressed_tensors_nightly-0.3.3.
|
41
|
-
compressed_tensors_nightly-0.3.3.
|
42
|
-
compressed_tensors_nightly-0.3.3.
|
43
|
-
compressed_tensors_nightly-0.3.3.
|
39
|
+
compressed_tensors_nightly-0.3.3.20240521.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
40
|
+
compressed_tensors_nightly-0.3.3.20240521.dist-info/METADATA,sha256=DTxrrkh-4Wr9G5MAOS_2ILUsgrOIT-RDYi2IiVc13xg,5633
|
41
|
+
compressed_tensors_nightly-0.3.3.20240521.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
42
|
+
compressed_tensors_nightly-0.3.3.20240521.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
|
43
|
+
compressed_tensors_nightly-0.3.3.20240521.dist-info/RECORD,,
|
File without changes
|
File without changes
|