compressed-tensors-nightly 0.7.1.20241030__py3-none-any.whl → 0.7.1.20241101__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (20)
  1. compressed_tensors/quantization/__init__.py +0 -1
  2. compressed_tensors/quantization/lifecycle/__init__.py +0 -2
  3. compressed_tensors/quantization/lifecycle/apply.py +1 -16
  4. compressed_tensors/quantization/lifecycle/forward.py +13 -107
  5. compressed_tensors/quantization/lifecycle/initialize.py +18 -21
  6. compressed_tensors/quantization/quant_args.py +1 -8
  7. compressed_tensors/quantization/utils/helpers.py +127 -8
  8. {compressed_tensors_nightly-0.7.1.20241030.dist-info → compressed_tensors_nightly-0.7.1.20241101.dist-info}/METADATA +1 -1
  9. {compressed_tensors_nightly-0.7.1.20241030.dist-info → compressed_tensors_nightly-0.7.1.20241101.dist-info}/RECORD +12 -20
  10. compressed_tensors/quantization/cache.py +0 -200
  11. compressed_tensors/quantization/lifecycle/calibration.py +0 -80
  12. compressed_tensors/quantization/lifecycle/frozen.py +0 -50
  13. compressed_tensors/quantization/observers/__init__.py +0 -21
  14. compressed_tensors/quantization/observers/base.py +0 -213
  15. compressed_tensors/quantization/observers/helpers.py +0 -149
  16. compressed_tensors/quantization/observers/min_max.py +0 -104
  17. compressed_tensors/quantization/observers/mse.py +0 -164
  18. {compressed_tensors_nightly-0.7.1.20241030.dist-info → compressed_tensors_nightly-0.7.1.20241101.dist-info}/LICENSE +0 -0
  19. {compressed_tensors_nightly-0.7.1.20241030.dist-info → compressed_tensors_nightly-0.7.1.20241101.dist-info}/WHEEL +0 -0
  20. {compressed_tensors_nightly-0.7.1.20241030.dist-info → compressed_tensors_nightly-0.7.1.20241101.dist-info}/top_level.txt +0 -0
--- a/compressed_tensors/quantization/observers/helpers.py
+++ /dev/null
@@ -1,149 +0,0 @@
1
- # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing,
10
- # software distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from collections import Counter
16
- from typing import Tuple
17
-
18
- import torch
19
- from compressed_tensors.quantization.quant_args import (
20
- FP8_DTYPE,
21
- QuantizationArgs,
22
- QuantizationStrategy,
23
- QuantizationType,
24
- )
25
- from torch import FloatTensor, IntTensor, Tensor
26
-
27
-
28
- __all__ = [
29
- "calculate_qparams",
30
- "get_observer_token_count",
31
- "calculate_range",
32
- "compute_dynamic_scales_and_zp",
33
- ]
34
-
35
-
36
- def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
37
- """
38
- Returns the computed scales and zero points for dynamic activation
39
- qunatization.
40
-
41
- :param value: tensor to calculate quantization parameters for
42
- :param args: quantization args
43
- :param reduce_dims: optional tuple of dimensions to reduce along,
44
- returned scale and zero point will be shaped (1,) along the
45
- reduced dimensions
46
- :return: tuple of scale and zero point derived from the observed tensor
47
- """
48
- if args.strategy == QuantizationStrategy.TOKEN:
49
- dim = {1, 2}
50
- reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
51
- elif args.strategy == QuantizationStrategy.TENSOR:
52
- reduce_dims = None
53
- else:
54
- raise ValueError(
55
- f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ",
56
- "must be used for dynamic quantization",
57
- )
58
-
59
- if not reduce_dims:
60
- min_val, max_val = torch.aminmax(value)
61
- else:
62
- min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
63
- max_val = torch.amax(value, dim=reduce_dims, keepdims=True)
64
-
65
- return calculate_qparams(min_val, max_val, args)
66
-
67
-
68
- def get_observer_token_count(module: torch.nn.Module) -> Counter:
69
- """
70
- Parse the module and return the number of tokens observed by
71
- each module's observer.
72
-
73
- :param module: module to parse
74
- :return: counter with the number of tokens observed by each observer
75
- """
76
- token_counts = Counter()
77
- for name, module in module.named_modules():
78
- if name.endswith(".input_observer"):
79
- token_counts[
80
- name.replace(".input_observer", "")
81
- ] = module._num_observed_tokens
82
- return token_counts
83
-
84
-
85
- def calculate_qparams(
86
- min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
87
- ) -> Tuple[FloatTensor, IntTensor]:
88
- """
89
- :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
90
- from
91
- :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
92
- from
93
- :param quantization_args: settings to quantization
94
- :return: tuple of the calculated scale(s) and zero point(s)
95
- """
96
- min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
97
- max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
98
- device = min_vals.device
99
-
100
- bit_min, bit_max = calculate_range(quantization_args, device)
101
- bit_range = bit_max - bit_min
102
- zp_dtype = quantization_args.pytorch_dtype()
103
-
104
- if quantization_args.symmetric:
105
- max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
106
- scales = max_val_pos / (float(bit_range) / 2)
107
- scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
108
- zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
109
- else:
110
- scales = (max_vals - min_vals) / float(bit_range)
111
- scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
112
- zero_points = bit_min - (min_vals / scales)
113
- zero_points = torch.clamp(zero_points, bit_min, bit_max)
114
-
115
- # match zero-points to quantized type
116
- zero_points = zero_points.to(zp_dtype)
117
-
118
- if scales.ndim == 0:
119
- scales = scales.reshape(1)
120
- zero_points = zero_points.reshape(1)
121
-
122
- return scales, zero_points
123
-
124
-
125
- def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
126
- """
127
- Calculated the effective quantization range for the given Quantization Args
128
-
129
- :param quantization_args: quantization args to get range of
130
- :param device: device to store the range to
131
- :return: tuple endpoints for the given quantization range
132
- """
133
- if quantization_args.type == QuantizationType.INT:
134
- bit_range = 2**quantization_args.num_bits
135
- q_max = torch.tensor(bit_range / 2 - 1, device=device)
136
- q_min = torch.tensor(-bit_range / 2, device=device)
137
- elif quantization_args.type == QuantizationType.FLOAT:
138
- if quantization_args.num_bits != 8:
139
- raise ValueError(
140
- "Floating point quantization is only supported for 8 bits,"
141
- f"got {quantization_args.num_bits}"
142
- )
143
- fp_range_info = torch.finfo(FP8_DTYPE)
144
- q_max = torch.tensor(fp_range_info.max, device=device)
145
- q_min = torch.tensor(fp_range_info.min, device=device)
146
- else:
147
- raise ValueError(f"Invalid quantization type {quantization_args.type}")
148
-
149
- return q_min, q_max
--- a/compressed_tensors/quantization/observers/min_max.py
+++ /dev/null
@@ -1,104 +0,0 @@
1
- # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing,
10
- # software distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from typing import Any, Optional, Tuple
16
-
17
- import torch
18
- from compressed_tensors.quantization.observers.base import Observer
19
- from compressed_tensors.quantization.observers.helpers import calculate_qparams
20
- from compressed_tensors.quantization.quant_args import QuantizationArgs
21
- from torch import FloatTensor, IntTensor, Tensor
22
-
23
-
24
- __all__ = ["MovingAverageMinMaxObserver"]
25
-
26
-
27
- @Observer.register("minmax")
28
- class MovingAverageMinMaxObserver(Observer):
29
- """
30
- Implements a dynamic quantization observer that sets the scale and
31
- zero point based on a moving average of the overall min and max observed values
32
- """
33
-
34
- def __init__(
35
- self, quantization_args: QuantizationArgs, averaging_constant: float = 0.01
36
- ):
37
- super().__init__(quantization_args=quantization_args)
38
-
39
- self.min_val = {}
40
- self.max_val = {}
41
- self.averaging_constant = averaging_constant
42
-
43
- def calculate_qparams(
44
- self,
45
- observed: Tensor,
46
- reduce_dims: Optional[Tuple[int]] = None,
47
- tensor_id: Optional[Any] = None,
48
- ) -> Tuple[FloatTensor, IntTensor]:
49
- """
50
- Updates the observed min and max using a moving average smoothed by the
51
- averaging_constant
52
-
53
- :param observed: observed tensor to calculate quantization parameters for
54
- :param reduce_dims: optional tuple of dimensions to reduce along,
55
- returned scale and zero point will be shaped (1,) along the
56
- reduced dimensions
57
- :param tensor_id: Optional id if different ranges of observed tensors are
58
- passed, useful for sharding tensors by group_size
59
- :return: tuple of scale and zero point derived from the observed tensor
60
- """
61
- tensor_id = tensor_id or "default"
62
-
63
- if not reduce_dims:
64
- min_val, max_val = torch.aminmax(observed)
65
- else:
66
- min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
67
- max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
68
-
69
- running_min_val = self.min_val.get(tensor_id, None)
70
- running_max_val = self.max_val.get(tensor_id, None)
71
-
72
- if running_min_val is None or running_max_val is None:
73
- updated_min_val = min_val
74
- updated_max_val = max_val
75
- else:
76
- updated_min_val = running_min_val + self.averaging_constant * (
77
- min_val - running_min_val
78
- )
79
- updated_max_val = running_max_val + self.averaging_constant * (
80
- max_val - running_max_val
81
- )
82
-
83
- self.min_val[tensor_id] = updated_min_val
84
- self.max_val[tensor_id] = updated_max_val
85
-
86
- return calculate_qparams(
87
- updated_min_val, updated_max_val, self.quantization_args
88
- )
89
-
90
- def get_qparams_along_dim(
91
- self, observed, dim: int, tensor_id: Optional[Any] = None
92
- ):
93
- reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
94
- return self.calculate_qparams(
95
- observed, reduce_dims=reduce_dims, tensor_id=tensor_id
96
- )
97
-
98
- def reset(self):
99
- """
100
- Reset the state of the observer, including min and maximum values
101
- """
102
- super().reset()
103
- self.min_val = {}
104
- self.max_val = {}
--- a/compressed_tensors/quantization/observers/mse.py
+++ /dev/null
@@ -1,164 +0,0 @@
1
- # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing,
10
- # software distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from typing import Any, Optional, Tuple
16
-
17
- import torch
18
- from compressed_tensors.quantization.observers.base import Observer
19
- from compressed_tensors.quantization.observers.helpers import calculate_qparams
20
- from compressed_tensors.quantization.quant_args import QuantizationArgs
21
- from torch import FloatTensor, IntTensor, Tensor
22
-
23
-
24
- __all__ = ["MovingAverageMSEObserver"]
25
-
26
-
27
- @Observer.register("mse")
28
- class MovingAverageMSEObserver(Observer):
29
- """
30
- Implements a dynamic quantization observer that sets the scale and
31
- zero point based on a moving average of the mse-clipped min and max observed values
32
- """
33
-
34
- def __init__(
35
- self,
36
- quantization_args: QuantizationArgs,
37
- averaging_constant: float = 0.01,
38
- grid: float = 100.0,
39
- maxshrink: float = 0.80,
40
- norm: float = 2.4,
41
- ):
42
- super().__init__(quantization_args=quantization_args)
43
-
44
- self.min_val = {}
45
- self.max_val = {}
46
- self.averaging_constant = averaging_constant
47
- self.grid = grid
48
- self.maxshrink = maxshrink
49
- self.norm = norm
50
-
51
- def calculate_mse_min_max(
52
- self,
53
- observed: Tensor,
54
- reduce_dims: Optional[Tuple[int]] = None,
55
- ):
56
- """
57
- Computes the mse-clipped min and max values of the observed tensor by
58
- optimizing for quantization error
59
-
60
- :param observed: observed tensor to calculate quantization parameters for
61
- :param reduce_dims: optional tuple of dimensions to reduce along,
62
- returned values will be shaped (1,) along the reduced dimensions
63
- :return: tuple of min and max values derived from the observed tensor
64
- """
65
- from compressed_tensors.quantization.lifecycle import fake_quantize
66
-
67
- if not reduce_dims:
68
- absolute_min_val, absolute_max_val = torch.aminmax(observed)
69
- else:
70
- absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
71
- absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
72
-
73
- best = torch.full_like(
74
- absolute_min_val, torch.finfo(absolute_min_val.dtype).max
75
- )
76
- min_val = torch.ones_like(absolute_min_val)
77
- max_val = torch.zeros_like(absolute_max_val)
78
- for i in range(int(self.maxshrink * self.grid)):
79
- p = 1 - i / self.grid
80
- shrinked_min_val = p * absolute_min_val
81
- shrinked_max_val = p * absolute_max_val
82
-
83
- candidate_scales, candidate_zero_points = calculate_qparams(
84
- shrinked_min_val, shrinked_max_val, self.quantization_args
85
- )
86
- q = fake_quantize(
87
- observed,
88
- candidate_scales,
89
- candidate_zero_points,
90
- self.quantization_args,
91
- )
92
-
93
- q -= observed
94
- q.abs_()
95
- q.pow_(self.norm)
96
- if not reduce_dims:
97
- err = torch.sum(q)
98
- else:
99
- err = torch.sum(q, reduce_dims, keepdims=True)
100
-
101
- tmp = err < best
102
- if torch.any(tmp):
103
- best[tmp] = err[tmp]
104
- min_val[tmp] = shrinked_min_val[tmp]
105
- max_val[tmp] = shrinked_max_val[tmp]
106
- return min_val, max_val
107
-
108
- def calculate_qparams(
109
- self,
110
- observed: Tensor,
111
- reduce_dims: Optional[Tuple[int]] = None,
112
- tensor_id: Optional[Any] = None,
113
- ) -> Tuple[FloatTensor, IntTensor]:
114
- """
115
- Updates the mse-clipped min and max values of the observed tensor using
116
- a moving average smoothed by the averaging_constant
117
-
118
- :param observed: observed tensor to calculate quantization parameters for
119
- :param reduce_dims: optional tuple of dimensions to reduce along,
120
- returned scale and zero point will be shaped (1,) along the
121
- reduced dimensions
122
- :param tensor_id: Optional id if different ranges of observed tensors are
123
- passed, useful for sharding tensors by group_size
124
- :return: tuple of scale and zero point derived from the observed tensor
125
- """
126
- min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims)
127
-
128
- running_min_val = self.min_val.get(tensor_id, None)
129
- running_max_val = self.max_val.get(tensor_id, None)
130
-
131
- if running_min_val is None or running_max_val is None:
132
- updated_min_val = min_val
133
- updated_max_val = max_val
134
- else:
135
- updated_min_val = running_min_val + self.averaging_constant * (
136
- min_val - running_min_val
137
- )
138
- updated_max_val = running_max_val + self.averaging_constant * (
139
- max_val - running_max_val
140
- )
141
-
142
- tensor_id = tensor_id or "default"
143
- self.min_val[tensor_id] = updated_min_val
144
- self.max_val[tensor_id] = updated_max_val
145
-
146
- return calculate_qparams(
147
- updated_min_val, updated_max_val, self.quantization_args
148
- )
149
-
150
- def get_qparams_along_dim(
151
- self, observed, dim: int, tensor_id: Optional[Any] = None
152
- ):
153
- reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
154
- return self.calculate_qparams(
155
- observed, reduce_dims=reduce_dims, tensor_id=tensor_id
156
- )
157
-
158
- def reset(self):
159
- """
160
- Reset the state of the observer, including min and maximum values
161
- """
162
- super().reset()
163
- self.min_val = {}
164
- self.max_val = {}