compressed_tensors-0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Files changed (37)
  1. compressed_tensors/__init__.py +21 -0
  2. compressed_tensors/base.py +16 -0
  3. compressed_tensors/compressors/__init__.py +25 -0
  4. compressed_tensors/compressors/base.py +79 -0
  5. compressed_tensors/compressors/dense.py +34 -0
  6. compressed_tensors/compressors/helpers.py +161 -0
  7. compressed_tensors/compressors/sparse_bitmask.py +238 -0
  8. compressed_tensors/config/__init__.py +18 -0
  9. compressed_tensors/config/base.py +42 -0
  10. compressed_tensors/config/dense.py +36 -0
  11. compressed_tensors/config/sparse_bitmask.py +36 -0
  12. compressed_tensors/quantization/__init__.py +21 -0
  13. compressed_tensors/quantization/lifecycle/__init__.py +22 -0
  14. compressed_tensors/quantization/lifecycle/apply.py +173 -0
  15. compressed_tensors/quantization/lifecycle/calibration.py +51 -0
  16. compressed_tensors/quantization/lifecycle/forward.py +136 -0
  17. compressed_tensors/quantization/lifecycle/frozen.py +46 -0
  18. compressed_tensors/quantization/lifecycle/initialize.py +96 -0
  19. compressed_tensors/quantization/observers/__init__.py +21 -0
  20. compressed_tensors/quantization/observers/base.py +69 -0
  21. compressed_tensors/quantization/observers/helpers.py +53 -0
  22. compressed_tensors/quantization/observers/memoryless.py +48 -0
  23. compressed_tensors/quantization/observers/min_max.py +65 -0
  24. compressed_tensors/quantization/quant_args.py +85 -0
  25. compressed_tensors/quantization/quant_config.py +171 -0
  26. compressed_tensors/quantization/quant_scheme.py +39 -0
  27. compressed_tensors/quantization/utils/__init__.py +16 -0
  28. compressed_tensors/quantization/utils/helpers.py +115 -0
  29. compressed_tensors/registry/__init__.py +17 -0
  30. compressed_tensors/registry/registry.py +360 -0
  31. compressed_tensors/utils/__init__.py +16 -0
  32. compressed_tensors/utils/helpers.py +151 -0
  33. compressed_tensors/utils/safetensors_load.py +237 -0
  34. compressed_tensors-0.3.0.dist-info/METADATA +22 -0
  35. compressed_tensors-0.3.0.dist-info/RECORD +37 -0
  36. compressed_tensors-0.3.0.dist-info/WHEEL +5 -0
  37. compressed_tensors-0.3.0.dist-info/top_level.txt +1 -0

compressed_tensors/config/sparse_bitmask.py
@@ -0,0 +1,36 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ from compressed_tensors.config import CompressionConfig, CompressionFormat
+
+
+ __all__ = ["BitmaskConfig"]
+
+
+ @CompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
+ class BitmaskConfig(CompressionConfig):
+     """
+     Configuration for storing a sparse model using
+     bitmask compression
+
+     :param global_sparsity: average sparsity of the entire model
+     :param sparsity_structure: structure of the sparsity, such as
+         "unstructured", "2:4", "8:16" etc
+     """
+
+     format: str = CompressionFormat.sparse_bitmask.value
+     global_sparsity: Optional[float] = 0.0
+     sparsity_structure: Optional[str] = "unstructured"
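
For orientation, a minimal usage sketch (not part of the diff): it assumes BitmaskConfig is re-exported from compressed_tensors.config and accepts keyword construction in the pydantic style that the typed defaults above suggest.

from compressed_tensors.config import BitmaskConfig, CompressionFormat

# describe a 2:4 sparse model at ~55% overall sparsity
config = BitmaskConfig(global_sparsity=0.55, sparsity_structure="2:4")
assert config.format == CompressionFormat.sparse_bitmask.value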

compressed_tensors/quantization/__init__.py
@@ -0,0 +1,21 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+ # isort: skip_file
+
+ from .quant_args import *
+ from .quant_config import *
+ from .quant_scheme import *
+ from .lifecycle import *

compressed_tensors/quantization/lifecycle/__init__.py
@@ -0,0 +1,22 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+ # isort: skip_file
+
+ from .calibration import *
+ from .forward import *
+ from .frozen import *
+ from .initialize import *
+ from .apply import *

compressed_tensors/quantization/lifecycle/apply.py
@@ -0,0 +1,173 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ from collections import OrderedDict
+ from typing import Dict, Iterable, Optional
+
+ from compressed_tensors.quantization.lifecycle.calibration import (
+     set_module_for_calibration,
+ )
+ from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
+ from compressed_tensors.quantization.lifecycle.initialize import (
+     initialize_module_for_quantization,
+ )
+ from compressed_tensors.quantization.quant_config import (
+     QuantizationConfig,
+     QuantizationStatus,
+ )
+ from compressed_tensors.quantization.utils import iter_named_leaf_modules
+ from compressed_tensors.utils.safetensors_load import get_safetensors_folder
+ from torch.nn import Module
+
+
+ __all__ = [
+     "load_pretrained_quantization",
+     "apply_quantization_config",
+     "apply_quantization_status",
+ ]
+
+ from compressed_tensors.quantization.utils.helpers import is_module_quantized
+ from compressed_tensors.utils.safetensors_load import get_quantization_state_dict
+
+
+ def load_pretrained_quantization(model: Module, model_name_or_path: str):
+     """
+     Loads the quantization parameters (scale and zero point) from model_name_or_path to
+     a model that has already been initialized with a quantization config
+
+     :param model: model to load pretrained quantization parameters to
+     :param model_name_or_path: Hugging Face stub or local folder containing a quantized
+         model, which is used to load quantization parameters
+     """
+     model_path = get_safetensors_folder(model_name_or_path)
+     state_dict = get_quantization_state_dict(model_path)
+
+     for name, submodule in iter_named_leaf_modules(model):
+         if not is_module_quantized(submodule):
+             continue
+         if submodule.quantization_scheme.weights is not None:
+             base_name = "weight"
+             _load_quant_args_from_state_dict(
+                 base_name=base_name,
+                 module_name=name,
+                 module=submodule,
+                 state_dict=state_dict,
+             )
+         if submodule.quantization_scheme.input_activations is not None:
+             base_name = "input"
+             _load_quant_args_from_state_dict(
+                 base_name=base_name,
+                 module_name=name,
+                 module=submodule,
+                 state_dict=state_dict,
+             )
+         if submodule.quantization_scheme.output_activations is not None:
+             base_name = "output"
+             _load_quant_args_from_state_dict(
+                 base_name=base_name,
+                 module_name=name,
+                 module=submodule,
+                 state_dict=state_dict,
+             )
+
+
+ def apply_quantization_config(model: Module, config: QuantizationConfig):
+     """
+     Initializes the model for quantization in-place based on the given config
+
+     :param model: model to apply quantization config to
+     :param config: quantization config
+     """
+     # build mapping of targets to schemes for easier matching
+     # use ordered dict to preserve target ordering in config
+     target_to_scheme = OrderedDict()
+     for scheme in config.config_groups.values():
+         for target in scheme.targets:
+             target_to_scheme[target] = scheme
+
+     # mark appropriate layers for quantization by setting their quantization schemes
+     for name, submodule in iter_named_leaf_modules(model):
+         if _find_first_name_or_class_match(name, submodule, config.ignore):
+             continue  # layer matches ignore list, continue
+         target = _find_first_name_or_class_match(name, submodule, target_to_scheme)
+         if target is not None:
+             # target matched - add layer and scheme to target list
+             submodule.quantization_scheme = target_to_scheme[target]
+
+     # apply current quantization status across all targeted layers
+     apply_quantization_status(model, config.quantization_status)
+
+
+ def apply_quantization_status(model: Module, status: QuantizationStatus):
+     """
+     Applies in place the quantization lifecycle up to the given status
+
+     :param model: model to apply quantization to
+     :param status: status to update the module to
+     """
+     if status >= QuantizationStatus.INITIALIZED:
+         model.apply(initialize_module_for_quantization)
+     if status >= QuantizationStatus.CALIBRATION:
+         model.apply(set_module_for_calibration)
+     if status >= QuantizationStatus.FROZEN:
+         model.apply(freeze_module_quantization)
+
+
+ def _find_first_name_or_class_match(
+     name: str,
+     module: Module,
+     targets: Iterable[str],
+ ) -> Optional[str]:
+     # first element of targets that matches the given name
+     # if no name matches returns first target that matches the class name
+     # returns None otherwise
+     return _find_first_match(name, targets) or _find_first_match(
+         module.__class__.__name__, targets
+     )
+
+
+ def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
+     # returns first element of target that matches value either
+     # exactly or as a regex after 're:'
+     for target in targets:
+         if target.startswith("re:"):
+             pattern = target[3:]
+             if re.match(pattern, value):
+                 return target
+         elif target == value:
+             return target
+     return None
+
+
+ def _load_quant_args_from_state_dict(
+     base_name: str, module_name: str, module: Module, state_dict: Dict
+ ):
+     """
+     Loads scale and zero point from a state_dict into the specified module
+
+     :param base_name: quantization target, one of: weights, input_activations or
+         output_activations
+     :param module_name: pytorch module name to look up in state_dict
+     :param module: pytorch module associated with module_name
+     :param state_dict: state_dict to search for matching quantization parameters
+     """
+     scale_name = f"{base_name}_scale"
+     zp_name = f"{base_name}_zero_point"
+     device = next(module.parameters()).device
+
+     scale = getattr(module, scale_name)
+     zp = getattr(module, zp_name)
+     scale.data = state_dict[f"{module_name}.{scale_name}"].to(device)
+     zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
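
To make the matching flow above concrete, a hedged usage sketch (not from the package docs): targets are matched against module names or class names, with a "re:" prefix enabling regex matching, and ignore entries are checked first. Fields not shown here are assumed to have usable defaults on QuantizationConfig, QuantizationScheme, and QuantizationArgs.

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)

config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear", "re:.*attention$"],  # class name, or regex via "re:"
            weights=QuantizationArgs(num_bits=8),
        )
    },
    ignore=["lm_head"],  # layers matching ignore are skipped
)
apply_quantization_config(model, config)  # model: any torch.nn.Module defined elsewhere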

compressed_tensors/quantization/lifecycle/calibration.py
@@ -0,0 +1,51 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import logging
+
+ from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from torch.nn import Module
+
+
+ __all__ = [
+     "set_module_for_calibration",
+ ]
+
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ def set_module_for_calibration(module: Module):
+     """
+     marks a layer as ready for calibration which activates observers
+     to update scales and zero points on each forward pass
+
+     apply to full model with `model.apply(set_module_for_calibration)`
+
+     :param module: module to set for calibration
+     """
+     if not getattr(module, "quantization_scheme", None):
+         # no quantization scheme nothing to do
+         return
+     status = getattr(module, "quantization_status", None)
+     if not status or status != QuantizationStatus.INITIALIZED:
+         _LOGGER.warning(
+             f"Attempting to set module with status {status} to calibration mode, "
+             f"but status is not {QuantizationStatus.INITIALIZED} - you may "
+             "be calibrating an uninitialized module, which may fail, or attempting "
+             "to re-calibrate a frozen module"
+         )
+
+     module.quantization_status = QuantizationStatus.CALIBRATION

compressed_tensors/quantization/lifecycle/forward.py
@@ -0,0 +1,136 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from functools import wraps
+
+ import torch
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
+ from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from torch.nn import Module
+
+
+ __all__ = ["wrap_module_forward_quantized"]
+
+
+ @torch.no_grad()
+ def quantize(
+     x: torch.Tensor,
+     scale: torch.Tensor,
+     zero_point: torch.Tensor,
+     q_min: torch.Tensor,
+     q_max: torch.Tensor,
+ ) -> torch.Tensor:
+     return torch.clamp(
+         torch.round(
+             x / scale + zero_point,
+         ),
+         q_min,
+         q_max,
+     )
+
+
+ @torch.no_grad()
+ def dequantize(
+     x_q: torch.Tensor,
+     scale: torch.Tensor,
+     zero_point: torch.Tensor,
+ ) -> torch.Tensor:
+     return (x_q - zero_point) * scale
+
+
+ @torch.no_grad()
+ def fake_quantize(
+     x: torch.Tensor,
+     scale: torch.Tensor,
+     zero_point: torch.Tensor,
+     args: QuantizationArgs,
+ ) -> torch.Tensor:
+     bit_range = 2**args.num_bits
+     max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
+     min_q = torch.tensor(-bit_range / 2, device=x.device)
+     Q = torch.zeros_like(x)
+     Q = quantize(x, scale, zero_point, min_q, max_q)
+     return dequantize(Q, scale, zero_point)
+
+
+ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
+     # expects a module already initialized and injected with the parameters in
+     # initialize_module_for_quantization
+     forward_func_orig = module.forward.__func__
+
+     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
+     def wrapped_forward(self, *args, **kwargs):
+         input_ = args[0]
+
+         if scheme.input_activations is not None:
+             # calibrate and (fake) quantize input activations when applicable
+             input_ = _maybe_calibrate_or_quantize(
+                 module, input_, "input", scheme.input_activations
+             )
+
+         if scheme.weights is not None:
+             # calibrate and (fake) quantize weights when applicable
+             unquantized_weight = self.weight.data.clone()
+             self.weight.data = _maybe_calibrate_or_quantize(
+                 module, self.weight, "weight", scheme.weights
+             )
+
+         # perform wrapped forward call
+         output = forward_func_orig.__get__(module, module.__class__)(
+             input_, *args[1:], **kwargs
+         )
+
+         if scheme.output_activations is not None:
+             # calibrate and (fake) quantize output activations when applicable
+             output = _maybe_calibrate_or_quantize(
+                 module, output, "output", scheme.output_activations
+             )
+
+         # restore back to unquantized_value
+         if scheme.weights is not None:
+             self.weight.data = unquantized_weight
+
+         return output
+
+     # bind wrapped forward to module class so reference to `self` is correct
+     bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
+     # set forward to wrapped forward
+     setattr(module, "forward", bound_wrapped_forward)
+
+
+ def _maybe_calibrate_or_quantize(
+     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
+ ) -> torch.Tensor:
+     # only run quantized for the included stages
+     if module.quantization_status not in {
+         QuantizationStatus.CALIBRATION,
+         QuantizationStatus.FROZEN,
+     }:
+         return value
+
+     device = next(module.parameters()).device
+     scale = getattr(module, f"{base_name}_scale")
+     zero_point = getattr(module, f"{base_name}_zero_point")
+
+     if module.quantization_status == QuantizationStatus.CALIBRATION:
+         # get observer and get new quant params from observation
+         observer = getattr(module, f"{base_name}_observer")
+         updated_scale, updated_zero_point = observer(value)
+
+         # update scale and zero point
+         scale.data = updated_scale.to(device)
+         zero_point.data = updated_zero_point.to(device)
+
+     return fake_quantize(value, scale, zero_point, args)
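
A quick worked example of the quantize/dequantize round trip implemented above, using illustrative values (num_bits=8 gives q_min=-128 and q_max=127):

import torch

scale, zero_point = torch.tensor(0.01), torch.tensor(0)
x = torch.tensor([0.742, -1.291, 3.000])
x_q = torch.clamp(torch.round(x / scale + zero_point), -128, 127)  # [74., -128., 127.]
x_dq = (x_q - zero_point) * scale                                  # [0.74, -1.28, 1.27]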

compressed_tensors/quantization/lifecycle/frozen.py
@@ -0,0 +1,46 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from torch.nn import Module
+
+
+ __all__ = [
+     "freeze_module_quantization",
+ ]
+
+
+ def freeze_module_quantization(module: Module):
+     """
+     deletes observers so static quantization is completed.
+
+     apply to full model with `model.apply(freeze_module_quantization)`
+
+     :param module: module to freeze quantization for
+     """
+     if not getattr(module, "quantization_scheme", None):
+         # no quantization scheme nothing to do
+         return
+
+     # delete observers from module
+     observer_names = []
+     for submodule_name, _ in module.named_modules():
+         if "." not in submodule_name and submodule_name.endswith("_observer"):
+             # delete any observers that belong directly to this module
+             observer_names.append(submodule_name)
+     for observer_name in observer_names:
+         delattr(module, observer_name)
+
+     module.quantization_status = QuantizationStatus.FROZEN
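
Taken together with calibration.py and initialize.py, the intended order follows the docstrings above; apply_quantization_status drives the same sequence, but a manual sketch looks like this (calibration_data is a placeholder iterable of input batches, and each targeted module must already carry a quantization_scheme):

model.apply(initialize_module_for_quantization)  # attach scales, zero points, observers
model.apply(set_module_for_calibration)          # observers update qparams on each forward
for batch in calibration_data:                   # placeholder calibration loop
    model(batch)
model.apply(freeze_module_quantization)          # delete observers, lock in qparams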

compressed_tensors/quantization/lifecycle/initialize.py
@@ -0,0 +1,96 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import logging
+ from typing import Optional
+
+ import torch
+ from compressed_tensors.quantization.lifecycle.forward import (
+     wrap_module_forward_quantized,
+ )
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
+ from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from torch.nn import Module, Parameter
+
+
+ __all__ = [
+     "initialize_module_for_quantization",
+ ]
+
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ def initialize_module_for_quantization(
+     module: Module,
+     scheme: Optional[QuantizationScheme] = None,
+ ):
+     """
+     attaches appropriate scales, zero points, and observers to a layer
+     given its target quantization scheme
+
+     apply to full model with `model.apply(initialize_module_for_quantization)`
+
+     :param module: module to initialize for quantization
+     :param scheme: scheme to use for quantization. if None is provided,
+         will attempt to use the scheme stored in the module under `quantization_scheme`;
+         if that is also not set, the layer will be skipped
+     """
+     scheme = scheme or getattr(module, "quantization_scheme", None)
+     if scheme is None:
+         # no scheme passed and layer not targeted for quantization - skip
+         return
+
+     if scheme.input_activations is not None:
+         _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
+     if scheme.weights is not None:
+         if hasattr(module, "weight"):
+             _initialize_scale_zero_point_observer(module, "weight", scheme.weights)
+         else:
+             _LOGGER.warning(
+                 f"module type {type(module)} targeted for weight quantization but "
+                 "has no attribute weight, skipping weight quantization "
+                 f"for {type(module)}"
+             )
+     if scheme.output_activations is not None:
+         _initialize_scale_zero_point_observer(
+             module, "output", scheme.output_activations
+         )
+
+     module.quantization_scheme = scheme
+     module.quantization_status = QuantizationStatus.INITIALIZED
+
+     # wrap forward call of module to perform quantized actions based on calltime status
+     wrap_module_forward_quantized(module, scheme)
+
+
+ def _initialize_scale_zero_point_observer(
+     module: Module, base_name: str, quantization_args: QuantizationArgs
+ ):
+     device = next(module.parameters()).device
+
+     # initializes empty scale and zero point parameters for the module
+     init_scale = Parameter(torch.empty(0, device=device), requires_grad=False)
+     module.register_parameter(f"{base_name}_scale", init_scale)
+
+     init_zero_point = Parameter(
+         torch.empty(0, device=device, dtype=int), requires_grad=False
+     )
+     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+
+     # initialize observer module and attach as submodule
+     observer = quantization_args.get_observer()
+     module.register_module(f"{base_name}_observer", observer)
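
As a sanity check of what initialization attaches to a single layer, a sketch (assuming QuantizationScheme/QuantizationArgs fields not listed here have defaults):

import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.lifecycle import initialize_module_for_quantization

layer = torch.nn.Linear(16, 16)
scheme = QuantizationScheme(targets=["Linear"], weights=QuantizationArgs(num_bits=8))
initialize_module_for_quantization(layer, scheme)

# parameter and submodule names follow the register_parameter/register_module calls above
assert hasattr(layer, "weight_scale") and hasattr(layer, "weight_zero_point")
assert hasattr(layer, "weight_observer")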

compressed_tensors/quantization/observers/__init__.py
@@ -0,0 +1,21 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+ # isort: skip_file
+
+ from .helpers import *
+ from .base import *
+ from .memoryless import *
+ from .min_max import *

compressed_tensors/quantization/observers/base.py
@@ -0,0 +1,69 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Tuple
+
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
+ from compressed_tensors.registry.registry import RegistryMixin
+ from torch import FloatTensor, IntTensor, Tensor
+ from torch.nn import Module
+
+
+ __all__ = ["Observer"]
+
+
+ class Observer(Module, RegistryMixin):
+     """
+     Base Observer class to be subclassed for specific implementation.
+     Subclasses should override `calculate_qparams` to return a scale, zero_point
+     pair
+     """
+
+     def __init__(self, quantization_args: QuantizationArgs):
+         self.quantization_args: QuantizationArgs = quantization_args
+         super().__init__()
+         self._scale = None
+         self._zero_point = None
+
+     def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+         """
+         maps directly to get_qparams
+         :param observed: optional observed tensor to calculate quantization parameters
+             from
+         :return: tuple of scale and zero point based on last observed value
+         """
+         return self.get_qparams(observed=observed)
+
+     def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+         """
+         :param observed: observed tensor to calculate quantization parameters for
+         :return: tuple of scale and zero point derived from the observed tensor
+         """
+         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
+
+     def get_qparams(
+         self, observed: Optional[Tensor] = None
+     ) -> Tuple[FloatTensor, IntTensor]:
+         """
+         Convenience function to wrap overwritten calculate_qparams
+         adds support to make observed tensor optional and support for tracking latest
+         calculated scale and zero point
+         :param observed: optional observed tensor to calculate quantization parameters
+             from
+         :return: tuple of scale and zero point based on last observed value
+         """
+         if observed is not None:
+             # re-calculate scale and zero point, update the stored value
+             self._scale, self._zero_point = self.calculate_qparams(observed)
+         return self._scale, self._zero_point
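
A minimal subclass sketch (illustrative only, not the package's MinMaxObserver); it assumes Observer.register works like the CompressionConfig.register decorator seen earlier and that QuantizationArgs can be built from num_bits alone:

import torch
from compressed_tensors.quantization.observers.base import Observer
from compressed_tensors.quantization.quant_args import QuantizationArgs


@Observer.register(name="toy-absmax")
class ToyAbsMaxObserver(Observer):
    def calculate_qparams(self, observed: torch.Tensor):
        # symmetric scale from the absolute maximum of the observed tensor
        q_max = 2 ** (self.quantization_args.num_bits - 1) - 1
        scale = observed.abs().max() / q_max
        zero_point = torch.zeros_like(scale, dtype=torch.int64)
        return scale, zero_point


observer = ToyAbsMaxObserver(QuantizationArgs(num_bits=8))
scale, zero_point = observer(torch.randn(64, 64))  # forward() dispatches to get_qparams()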