compressed-tensors-nightly 0.3.3.20240519__py3-none-any.whl → 0.3.3.20240521__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry, and is provided for informational purposes only.
@@ -150,7 +150,6 @@ def _process_quantization(
     q_min = torch.tensor(-bit_range / 2, device=x.device)
     group_size = args.group_size
 
-    # group
     if args.strategy == QuantizationStrategy.GROUP:
 
         if do_dequantize:  # if dequantizing the output should be a fp type
@@ -195,29 +194,7 @@ def _process_quantization(
                 )
                 output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
 
-    # channel-wise
-    elif args.strategy == QuantizationStrategy.CHANNEL:  # group_size == -1
-        if do_quantize:
-            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
-        if do_dequantize:
-            output = _dequantize(output if do_quantize else x, scale, zero_point)
-
-    # per-token
-    elif args.strategy == QuantizationStrategy.TOKEN:
-        # before: scale shape = [num_tokens]
-        # after: scale shape = [num_tokens, 1]
-        # x.shape = 1, num_tokens, 1]
-        # scale gets broadcasted as expected withput having [1, num_tokens, 1] shape
-
-        scale = scale.unsqueeze(1)
-        zero_point = zero_point.unsqueeze(1)
-
-        if do_quantize:
-            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
-        if do_dequantize:
-            output = _dequantize(output if do_quantize else x, scale, zero_point)
-
-    else:
+    else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
         if do_dequantize:
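
The two hunks above (per the RECORD entries further down, this file appears to be compressed_tensors/quantization/lifecycle/forward.py) delete the dedicated CHANNEL and TOKEN branches of _process_quantization; those strategies now share the generic else branch with the plain tensor strategy. That only works if the observers already hand back scale and zero_point shaped to broadcast against x, which the observer changes below appear to arrange by keeping size-1 axes on the reduced dimensions. A minimal sketch of the broadcasting, using stand-in fake-quantization math rather than the library's _quantize/_dequantize helpers:

import torch

# Illustrative only: if the observer already returns a per-token scale of
# shape [num_tokens, 1], plain tensor arithmetic broadcasts it over the
# hidden dimension with no extra unsqueeze needed in the quantize path.
x = torch.randn(4, 8)                                  # [num_tokens, hidden]
scale = x.abs().amax(dim=1, keepdim=True) / 127.0      # [num_tokens, 1]
zero_point = torch.zeros_like(scale)

q = torch.clamp(torch.round(x / scale + zero_point), -128, 127)
dq = (q - zero_point) * scale                          # same shape as x
print(q.shape, dq.shape)                               # torch.Size([4, 8]) twice
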
@@ -50,9 +50,16 @@ class Observer(Module, RegistryMixin):
         """
         return self.get_qparams(observed=observed)
 
-    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
@@ -70,6 +77,7 @@ class Observer(Module, RegistryMixin):
         Convenience function to wrap overwritten calculate_qparams
         adds support to make observed tensor optional and support for tracking latest
         calculated scale and zero point
+
         :param observed: optional observed tensor to calculate quantization parameters
             from
         :return: tuple of scale and zero point based on last observed value
@@ -100,10 +108,8 @@ class Observer(Module, RegistryMixin):
             self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
 
         elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
-
             # use dim 1, assume the obsersed.shape = [batch, token, hidden]
             # should be batch, token
-
             self._scale, self._zero_point = self.get_qparams_along_dim(
                 observed, dim=1
             )
@@ -111,20 +117,5 @@ class Observer(Module, RegistryMixin):
         return self._scale, self._zero_point
 
     def get_qparams_along_dim(self, observed, dim: int):
-        # TODO: add documentation that specifies the shape must
-        # be padded with 1-dims so the scales are along the right channel
-        # TODO: generalize the logic for reduce_dims
-        scales, zero_points = [], []
-
-        # TODO: make a more generic way to get the channel
-        num_dims = observed.shape[dim]
-
-        for dim_idx in range(num_dims):
-            scale, zero_point = self.calculate_qparams(
-                observed.select(dim=dim, index=dim_idx)
-            )
-
-            scales.append(scale)
-            zero_points.append(zero_point)
-        # breakpoint()
-        return torch.stack(scales), torch.stack(zero_points)
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        return self.calculate_qparams(observed, reduce_dims=reduce_dims)
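
In the Observer base class (likely compressed_tensors/quantization/observers/base.py), get_qparams_along_dim no longer loops over observed.select(...) and stacks per-slice results; it builds the tuple of every axis except dim and hands it to the now reduce_dims-aware calculate_qparams, so the reduction happens in one vectorized call. A quick shape sketch, assuming the [batch, token, hidden] layout named in the TOKEN comment above:

import torch

observed = torch.randn(2, 5, 16)        # assumed [batch, token, hidden]
dim = 1                                 # keep qparams along the token axis
reduce_dims = tuple(i for i in range(observed.ndim) if i != dim)
print(reduce_dims)                      # (0, 2)

# A reduce_dims-aware reduction keeps size-1 placeholders on the reduced
# axes, so the resulting scale/zero_point broadcast back against observed.
min_val = torch.amin(observed, dim=reduce_dims, keepdim=True)
max_val = torch.amax(observed, dim=reduce_dims, keepdim=True)
print(min_val.shape, max_val.shape)     # torch.Size([1, 5, 1]) twice
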
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -30,19 +30,25 @@ class MemorylessObserver(Observer):
     zero point based on the latest observed value without tracking state
     """
 
-    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
-        Returns the min and max values of observed
+        Returns the min and max values of observed tensor
 
         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
-        # TODO: Add support for full range of quantization Args, only supports 8bit
-        # per tensor
-        min_val, max_val = torch.aminmax(observed)
 
-        # ensure zero is in the range
-        min_val = torch.min(min_val, torch.zeros_like(min_val))
-        max_val = torch.max(max_val, torch.zeros_like(max_val))
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
 
         return calculate_qparams(min_val, max_val, self.quantization_args)
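
MemorylessObserver.calculate_qparams gains the same optional reduce_dims argument: with no reduction it keeps the single torch.aminmax call over the whole tensor; otherwise it reduces with torch.amin/torch.amax over the given dims, keeping size-1 placeholder axes. A small standalone sketch of the two branches (minmax_over is an assumed name, and the sketch uses the documented keepdim= spelling):

import torch

# Hypothetical standalone version of the new branch, not the library's code.
def minmax_over(observed, reduce_dims=None):
    if not reduce_dims:  # covers both None and an empty tuple
        return torch.aminmax(observed)
    return (
        torch.amin(observed, dim=reduce_dims, keepdim=True),
        torch.amax(observed, dim=reduce_dims, keepdim=True),
    )

x = torch.randn(2, 5, 16)
print(minmax_over(x)[0].shape)          # torch.Size([]) -- one scalar pair
print(minmax_over(x, (0, 2))[0].shape)  # torch.Size([1, 5, 1]) -- per token
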
@@ -74,7 +74,3 @@ class MovingAverageMinMaxObserver(Observer):
         )
 
         return calculate_qparams(self.min_val, self.max_val, self.quantization_args)
-
-    def get_qparams_along_dim(self, observed, dim: int):
-        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
-        return self.calculate_qparams(observed, reduce_dims=reduce_dims)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: compressed-tensors-nightly
-Version: 0.3.3.20240519
+Version: 0.3.3.20240521
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
@@ -21,14 +21,14 @@ compressed_tensors/quantization/lifecycle/__init__.py,sha256=ggRGWRqhCxCaTTDWRcg
 compressed_tensors/quantization/lifecycle/apply.py,sha256=whKfNGC_EZm0BC23AP7qWfjRe5OJVWmcZOpX7lryZZc,7625
 compressed_tensors/quantization/lifecycle/calibration.py,sha256=mLns4jlaWmBwOW8Jtlm5bMX-JET1AiZYUBO7qa-XuxI,1776
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=VreB10xPwgSLQQlTu20UCrFpRS--cA7-lx5s7nrPPrg,2247
-compressed_tensors/quantization/lifecycle/forward.py,sha256=sXo7ReS2ehHFwbtwUbhPnsnnj-CZ3iyAZKmUzHxjTKc,11373
+compressed_tensors/quantization/lifecycle/forward.py,sha256=x9JaIX3TK7cb_-0aCOTTYtA4At9l6v5YOY_70GzIeFU,10520
 compressed_tensors/quantization/lifecycle/frozen.py,sha256=h1XYt89MouBTf3jTYLG_6OdFxIu5q2N8tPjsy6J4E6Y,1726
 compressed_tensors/quantization/lifecycle/initialize.py,sha256=U6g9qifSF6pagQZQZEwd-rwWC6uQ_dZXn1wg6nr1Abg,3697
 compressed_tensors/quantization/observers/__init__.py,sha256=DNH31NQYrIBBcmHsMyFA6whh4pbRsLwuNa6L8AeXaGc,745
-compressed_tensors/quantization/observers/base.py,sha256=X7zeeFj42JxP_5dX2XbEGHcqLrkiV53-nJN3qhW2NA8,5156
+compressed_tensors/quantization/observers/base.py,sha256=yIV2bd9PKPZwodgiBTZEco2ARbD3B0rOKDC0MOFluZs,4900
 compressed_tensors/quantization/observers/helpers.py,sha256=JwALNfBYY9Eyl8Q180t0lGh8szumQj8TygfNl-isErs,2166
-compressed_tensors/quantization/observers/memoryless.py,sha256=ZHTPh4aURE8LvHBFaP--HIC2JanMX5-VRdIkE2JHthw,1859
-compressed_tensors/quantization/observers/min_max.py,sha256=s2I40pzTXrVAjIsavNt6TLAl7-qDUmdc43Xd5rb4XAY,3071
+compressed_tensors/quantization/observers/memoryless.py,sha256=Gach22cZLhDms6ueKF56XOiLhyWVIEYIEXRRXP5Nu8I,2045
+compressed_tensors/quantization/observers/min_max.py,sha256=OGrtyn6_sWuTSx5QgUPVKRIiarfWrK9QqXeRXoJQynw,2861
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
 compressed_tensors/quantization/utils/helpers.py,sha256=NzAH18Cn_-mTAR87y6IlcQU5gC393XSjgNKC9CRkr78,6017
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
@@ -36,8 +36,8 @@ compressed_tensors/registry/registry.py,sha256=fxjOjh2wklCvJhQxwofdy-zV8q7MkQ85S
 compressed_tensors/utils/__init__.py,sha256=5DrYjoZbaEvSkJcC-GRSbM_RBHVF4tG9gMd3zsJnjLw,665
 compressed_tensors/utils/helpers.py,sha256=h0jfl9drs5FAx40tCHRcVtJqXixB5hT5yq_IG2aY_-w,1735
 compressed_tensors/utils/safetensors_load.py,sha256=wo9UirGrGlenBqZeqotvpCT7D5MEdjCo2J3HeRaIFoU,8502
-compressed_tensors_nightly-0.3.3.20240519.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors_nightly-0.3.3.20240519.dist-info/METADATA,sha256=daj8jdjy8w_qoL30d2ExbyhkV59pFSdJmgGF_GTn_Gk,5633
-compressed_tensors_nightly-0.3.3.20240519.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-compressed_tensors_nightly-0.3.3.20240519.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors_nightly-0.3.3.20240519.dist-info/RECORD,,
+compressed_tensors_nightly-0.3.3.20240521.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors_nightly-0.3.3.20240521.dist-info/METADATA,sha256=DTxrrkh-4Wr9G5MAOS_2ILUsgrOIT-RDYi2IiVc13xg,5633
+compressed_tensors_nightly-0.3.3.20240521.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+compressed_tensors_nightly-0.3.3.20240521.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors_nightly-0.3.3.20240521.dist-info/RECORD,,