compressed-tensors 0.7.1__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. compressed_tensors/compressors/model_compressors/model_compressor.py +17 -5
  2. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +4 -2
  3. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +2 -0
  4. compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
  5. compressed_tensors/config/base.py +60 -2
  6. compressed_tensors/linear/compressed_linear.py +3 -1
  7. compressed_tensors/quantization/__init__.py +0 -1
  8. compressed_tensors/quantization/lifecycle/__init__.py +0 -2
  9. compressed_tensors/quantization/lifecycle/apply.py +3 -17
  10. compressed_tensors/quantization/lifecycle/forward.py +24 -87
  11. compressed_tensors/quantization/lifecycle/initialize.py +21 -24
  12. compressed_tensors/quantization/quant_args.py +27 -25
  13. compressed_tensors/quantization/quant_config.py +2 -2
  14. compressed_tensors/quantization/quant_scheme.py +17 -24
  15. compressed_tensors/quantization/utils/helpers.py +125 -8
  16. compressed_tensors/registry/registry.py +1 -1
  17. compressed_tensors/utils/helpers.py +33 -1
  18. compressed_tensors/version.py +1 -1
  19. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/METADATA +1 -1
  20. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/RECORD +23 -31
  21. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/WHEEL +1 -1
  22. compressed_tensors/quantization/cache.py +0 -201
  23. compressed_tensors/quantization/lifecycle/calibration.py +0 -70
  24. compressed_tensors/quantization/lifecycle/frozen.py +0 -55
  25. compressed_tensors/quantization/observers/__init__.py +0 -21
  26. compressed_tensors/quantization/observers/base.py +0 -213
  27. compressed_tensors/quantization/observers/helpers.py +0 -149
  28. compressed_tensors/quantization/observers/min_max.py +0 -104
  29. compressed_tensors/quantization/observers/mse.py +0 -162
  30. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/LICENSE +0 -0
  31. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/top_level.txt +0 -0
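Items 22–29 in the listing above show that 0.8.1 drops the built-in observer stack (compressed_tensors/quantization/observers/*), the calibration and frozen lifecycle steps, and quantization/cache.py, while the core quantization config, lifecycle, and compressor modules remain. As a rough, hypothetical sketch of what downstream code may need to account for, the guard below probes for the 0.7.x-only import path (taken from the removed mse.py shown next); the fallback branch is illustrative only and is not part of either release:

# Hypothetical compatibility probe for code written against compressed-tensors 0.7.x.
# The import path exists in 0.7.1 (the removed mse.py below imports Observer from it)
# and is gone in 0.8.1 according to this file listing.
try:
    from compressed_tensors.quantization.observers.base import Observer  # 0.7.x only
    HAS_BUILTIN_OBSERVERS = True
except ImportError:
    # 0.8.1: observers no longer ship with compressed-tensors; an external
    # observer implementation (assumption for illustration) must be supplied.
    Observer = None
    HAS_BUILTIN_OBSERVERS = False

print(f"built-in observers available: {HAS_BUILTIN_OBSERVERS}")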
compressed_tensors/quantization/observers/mse.py
@@ -1,162 +0,0 @@
- # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Any, Optional, Tuple
-
- import torch
- from compressed_tensors.quantization.observers.base import Observer
- from compressed_tensors.quantization.observers.helpers import calculate_qparams
- from compressed_tensors.quantization.quant_args import QuantizationArgs
- from torch import FloatTensor, IntTensor, Tensor
-
-
- __all__ = ["MovingAverageMSEObserver"]
-
-
- @Observer.register("mse")
- class MovingAverageMSEObserver(Observer):
-     """
-     Implements a dynamic quantization observer that sets the scale and
-     zero point based on a moving average of the mse-clipped min and max observed values
-     """
-
-     def __init__(
-         self,
-         quantization_args: QuantizationArgs,
-         averaging_constant: float = 0.01,
-         grid: float = 100.0,
-         maxshrink: float = 0.80,
-         norm: float = 2.4,
-     ):
-         super().__init__(quantization_args=quantization_args)
-
-         self.min_val = {}
-         self.max_val = {}
-         self.averaging_constant = averaging_constant
-         self.grid = grid
-         self.maxshrink = maxshrink
-         self.norm = norm
-
-     def calculate_mse_min_max(
-         self,
-         observed: Tensor,
-         reduce_dims: Optional[Tuple[int]] = None,
-     ):
-         """
-         Computes the mse-clipped min and max values of the observed tensor by
-         optimizing for quantization error
-
-         :param observed: observed tensor to calculate quantization parameters for
-         :param reduce_dims: optional tuple of dimensions to reduce along,
-             returned values will be shaped (1,) along the reduced dimensions
-         :return: tuple of min and max values derived from the observed tensor
-         """
-         from compressed_tensors.quantization.lifecycle import fake_quantize
-
-         if not reduce_dims:
-             absolute_min_val, absolute_max_val = torch.aminmax(observed)
-         else:
-             absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
-             absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
-
-         best = torch.full(absolute_min_val.shape, float("inf"))
-         min_val = torch.ones(absolute_min_val.shape)
-         max_val = torch.zeros(absolute_max_val.shape)
-         for i in range(int(self.maxshrink * self.grid)):
-             p = 1 - i / self.grid
-             shrinked_min_val = p * absolute_min_val
-             shrinked_max_val = p * absolute_max_val
-
-             candidate_scales, candidate_zero_points = calculate_qparams(
-                 shrinked_min_val, shrinked_max_val, self.quantization_args
-             )
-             q = fake_quantize(
-                 observed,
-                 candidate_scales,
-                 candidate_zero_points,
-                 self.quantization_args,
-             )
-
-             q -= observed
-             q.abs_()
-             q.pow_(self.norm)
-             if not reduce_dims:
-                 err = torch.sum(q)
-             else:
-                 err = torch.sum(q, reduce_dims, keepdims=True)
-
-             tmp = err < best
-             if torch.any(tmp):
-                 best[tmp] = err[tmp]
-                 min_val[tmp] = shrinked_min_val[tmp]
-                 max_val[tmp] = shrinked_max_val[tmp]
-         return min_val, max_val
-
-     def calculate_qparams(
-         self,
-         observed: Tensor,
-         reduce_dims: Optional[Tuple[int]] = None,
-         tensor_id: Optional[Any] = None,
-     ) -> Tuple[FloatTensor, IntTensor]:
-         """
-         Updates the mse-clipped min and max values of the observed tensor using
-         a moving average smoothed by the averaging_constant
-
-         :param observed: observed tensor to calculate quantization parameters for
-         :param reduce_dims: optional tuple of dimensions to reduce along,
-             returned scale and zero point will be shaped (1,) along the
-             reduced dimensions
-         :param tensor_id: Optional id if different ranges of observed tensors are
-             passed, useful for sharding tensors by group_size
-         :return: tuple of scale and zero point derived from the observed tensor
-         """
-         min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims)
-
-         running_min_val = self.min_val.get(tensor_id, None)
-         running_max_val = self.max_val.get(tensor_id, None)
-
-         if running_min_val is None or running_max_val is None:
-             updated_min_val = min_val
-             updated_max_val = max_val
-         else:
-             updated_min_val = running_min_val + self.averaging_constant * (
-                 min_val - running_min_val
-             )
-             updated_max_val = running_max_val + self.averaging_constant * (
-                 max_val - running_max_val
-             )
-
-         tensor_id = tensor_id or "default"
-         self.min_val[tensor_id] = updated_min_val
-         self.max_val[tensor_id] = updated_max_val
-
-         return calculate_qparams(
-             updated_min_val, updated_max_val, self.quantization_args
-         )
-
-     def get_qparams_along_dim(
-         self, observed, dim: int, tensor_id: Optional[Any] = None
-     ):
-         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
-         return self.calculate_qparams(
-             observed, reduce_dims=reduce_dims, tensor_id=tensor_id
-         )
-
-     def reset(self):
-         """
-         Reset the state of the observer, including min and maximum values
-         """
-         super().reset()
-         self.min_val = {}
-         self.max_val = {}
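The heart of the removed file is calculate_mse_min_max: a grid search over shrink factors p = 1 - i/grid applied to the observed min/max, keeping whichever clipped range minimizes the norm-power error of fake-quantizing the tensor; calculate_qparams then smooths successive results across calls with the averaging_constant moving average. A minimal self-contained sketch of that search, using plain asymmetric uniform quantization as a stand-in for the package's calculate_qparams/fake_quantize helpers (an assumption made so the snippet runs on its own):

import torch

def mse_clipped_range(x: torch.Tensor, grid: int = 100, maxshrink: float = 0.8,
                      norm: float = 2.4, n_bits: int = 8):
    """Illustrative per-tensor version of the removed MSE min/max search."""
    abs_min, abs_max = torch.aminmax(x)
    best_err = torch.tensor(float("inf"))
    best_min, best_max = abs_min, abs_max
    q_max = 2 ** n_bits - 1
    for i in range(int(maxshrink * grid)):
        p = 1 - i / grid                           # shrink both ends toward zero
        lo, hi = p * abs_min, p * abs_max
        scale = (hi - lo).clamp(min=1e-8) / q_max  # stand-in for calculate_qparams
        zero_point = torch.round(-lo / scale)
        q = torch.clamp(torch.round(x / scale) + zero_point, 0, q_max)
        x_hat = (q - zero_point) * scale           # stand-in for fake_quantize
        err = (x_hat - x).abs().pow(norm).sum()    # same error measure as the observer
        if err < best_err:
            best_err, best_min, best_max = err, lo, hi
    return best_min, best_max

# The observer then smooths results from successive batches with a moving average:
#   running = running + averaging_constant * (new - running)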