compressed-tensors 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. compressed_tensors/__init__.py +21 -0
  2. compressed_tensors/base.py +16 -0
  3. compressed_tensors/compressors/__init__.py +25 -0
  4. compressed_tensors/compressors/base.py +79 -0
  5. compressed_tensors/compressors/dense.py +34 -0
  6. compressed_tensors/compressors/helpers.py +161 -0
  7. compressed_tensors/compressors/sparse_bitmask.py +238 -0
  8. compressed_tensors/config/__init__.py +18 -0
  9. compressed_tensors/config/base.py +42 -0
  10. compressed_tensors/config/dense.py +36 -0
  11. compressed_tensors/config/sparse_bitmask.py +36 -0
  12. compressed_tensors/quantization/__init__.py +21 -0
  13. compressed_tensors/quantization/lifecycle/__init__.py +22 -0
  14. compressed_tensors/quantization/lifecycle/apply.py +173 -0
  15. compressed_tensors/quantization/lifecycle/calibration.py +51 -0
  16. compressed_tensors/quantization/lifecycle/forward.py +136 -0
  17. compressed_tensors/quantization/lifecycle/frozen.py +46 -0
  18. compressed_tensors/quantization/lifecycle/initialize.py +96 -0
  19. compressed_tensors/quantization/observers/__init__.py +21 -0
  20. compressed_tensors/quantization/observers/base.py +69 -0
  21. compressed_tensors/quantization/observers/helpers.py +53 -0
  22. compressed_tensors/quantization/observers/memoryless.py +48 -0
  23. compressed_tensors/quantization/observers/min_max.py +65 -0
  24. compressed_tensors/quantization/quant_args.py +85 -0
  25. compressed_tensors/quantization/quant_config.py +171 -0
  26. compressed_tensors/quantization/quant_scheme.py +39 -0
  27. compressed_tensors/quantization/utils/__init__.py +16 -0
  28. compressed_tensors/quantization/utils/helpers.py +115 -0
  29. compressed_tensors/registry/__init__.py +17 -0
  30. compressed_tensors/registry/registry.py +360 -0
  31. compressed_tensors/utils/__init__.py +16 -0
  32. compressed_tensors/utils/helpers.py +151 -0
  33. compressed_tensors/utils/safetensors_load.py +237 -0
  34. compressed_tensors-0.3.0.dist-info/METADATA +22 -0
  35. compressed_tensors-0.3.0.dist-info/RECORD +37 -0
  36. compressed_tensors-0.3.0.dist-info/WHEEL +5 -0
  37. compressed_tensors-0.3.0.dist-info/top_level.txt +1 -0
compressed_tensors/quantization/observers/helpers.py
@@ -0,0 +1,53 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Tuple
+
+ import torch
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
+ from torch import FloatTensor, IntTensor, Tensor
+
+
+ __all__ = ["calculate_qparams"]
+
+
+ def calculate_qparams(
+     min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
+ ) -> Tuple[FloatTensor, IntTensor]:
+     """
+     :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
+         from
+     :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
+         from
+     :param quantization_args: quantization settings to use
+     :return: tuple of the calculated scale(s) and zero point(s)
+     """
+     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+
+     bit_range = 2**quantization_args.num_bits - 1
+     bit_min = -(bit_range + 1) / 2
+     bit_max = bit_min + bit_range
+     if quantization_args.symmetric:
+         zero_points = torch.tensor(0).to(torch.int8)
+         max_val_pos = torch.max(-min_vals, max_vals)
+         scales = max_val_pos / (float(bit_range) / 2)
+         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+     else:
+         scales = (max_vals - min_vals) / float(bit_range)
+         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+         zero_points = bit_min - torch.round(min_vals / scales)
+         zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)
+
+     return scales, zero_points
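A minimal usage sketch of calculate_qparams (illustrative; not part of the packaged source): with the default 8-bit arguments, bit_range is 255, so the symmetric branch divides the largest absolute observed value by 127.5 and pins the zero point at 0.

    import torch
    from compressed_tensors.quantization.observers.helpers import calculate_qparams
    from compressed_tensors.quantization.quant_args import QuantizationArgs

    args = QuantizationArgs(num_bits=8, symmetric=True)
    scale, zero_point = calculate_qparams(
        min_vals=torch.tensor([-1.0]),
        max_vals=torch.tensor([2.0]),
        quantization_args=args,
    )
    # scale = max(|-1.0|, 2.0) / 127.5 ≈ 0.0157; zero_point = 0 (int8)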
compressed_tensors/quantization/observers/memoryless.py
@@ -0,0 +1,48 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Tuple
+
+ import torch
+ from compressed_tensors.quantization.observers.base import Observer
+ from compressed_tensors.quantization.observers.helpers import calculate_qparams
+ from torch import FloatTensor, IntTensor, Tensor
+
+
+ __all__ = ["MemorylessObserver"]
+
+
+ @Observer.register("memoryless")
+ class MemorylessObserver(Observer):
+     """
+     Implements a dynamic quantization observer that sets the scale and
+     zero point based on the latest observed value without tracking state
+     """
+
+     def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+         """
+         Calculates the scale and zero point from the min and max values of observed
+
+         :param observed: observed tensor to calculate quantization parameters for
+         :return: tuple of scale and zero point derived from the observed tensor
+         """
+         # TODO: Add support for full range of quantization Args, only supports 8bit
+         #  per tensor
+         min_val, max_val = torch.aminmax(observed)
+
+         # ensure zero is in the range
+         min_val = torch.min(min_val, torch.zeros_like(min_val))
+         max_val = torch.max(max_val, torch.zeros_like(max_val))
+
+         return calculate_qparams(min_val, max_val, self.quantization_args)
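A short sketch of direct observer use (assuming, as MovingAverageMinMaxObserver below suggests, that the Observer base constructor accepts quantization_args):

    import torch
    from compressed_tensors.quantization.observers.memoryless import MemorylessObserver
    from compressed_tensors.quantization.quant_args import QuantizationArgs

    observer = MemorylessObserver(quantization_args=QuantizationArgs())
    scale, zero_point = observer.calculate_qparams(torch.randn(64, 64))
    # each call recomputes qparams from the current tensor alone; no state is kept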
compressed_tensors/quantization/observers/min_max.py
@@ -0,0 +1,65 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Tuple
+
+ import torch
+ from compressed_tensors.quantization.observers.base import Observer
+ from compressed_tensors.quantization.observers.helpers import calculate_qparams
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
+ from torch import FloatTensor, IntTensor, Tensor
+
+
+ __all__ = ["MovingAverageMinMaxObserver"]
+
+
+ @Observer.register("minmax")
+ class MovingAverageMinMaxObserver(Observer):
+     """
+     Implements a dynamic quantization observer that sets the scale and
+     zero point based on a moving average of the overall min and max observed values
+     """
+
+     def __init__(
+         self, quantization_args: QuantizationArgs, averaging_constant: float = 0.01
+     ):
+         super().__init__(quantization_args=quantization_args)
+
+         self.min_val = float("inf")
+         self.max_val = -float("inf")
+         self.averaging_constant = averaging_constant
+
+     def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+         """
+         Updates the observed min and max using a moving average smoothed by the
+         averaging_constant
+
+         :param observed: observed tensor to calculate quantization parameters for
+         :return: tuple of scale and zero point derived from the observed tensor
+         """
+
+         min_val, max_val = torch.aminmax(observed)
+
+         if self.min_val == float("inf") and self.max_val == float("-inf"):
+             self.min_val = min_val
+             self.max_val = max_val
+         else:
+             self.min_val = self.min_val + self.averaging_constant * (
+                 min_val - self.min_val
+             )
+             self.max_val = self.max_val + self.averaging_constant * (
+                 max_val - self.max_val
+             )
+
+         return calculate_qparams(self.min_val, self.max_val, self.quantization_args)
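A sketch of the moving-average behavior (illustrative): the first call seeds min_val/max_val, and each later call shifts them a fraction averaging_constant of the way toward the new batch extremes.

    import torch
    from compressed_tensors.quantization.observers.min_max import (
        MovingAverageMinMaxObserver,
    )
    from compressed_tensors.quantization.quant_args import QuantizationArgs

    observer = MovingAverageMinMaxObserver(QuantizationArgs(), averaging_constant=0.01)
    observer.calculate_qparams(torch.randn(128))  # seeds the running min/max
    scale, zero_point = observer.calculate_qparams(torch.randn(128))
    # running update: min_val += 0.01 * (batch_min - min_val), likewise for max_val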
compressed_tensors/quantization/quant_args.py
@@ -0,0 +1,85 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from enum import Enum
+ from typing import Any, Dict, Optional
+
+ from pydantic import BaseModel, Field
+
+
+ __all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
+
+
+ class QuantizationType(str, Enum):
+     """
+     Enum storing quantization type options
+     """
+
+     INT = "int"
+     FLOAT = "float"
+
+
+ class QuantizationStrategy(str, Enum):
+     """
+     Enum storing quantization strategy options
+     """
+
+     TENSOR = "tensor"
+     CHANNEL = "channel"
+     GROUP = "group"
+     BLOCK = "block"
+
+
+ class QuantizationArgs(BaseModel):
+     """
+     User facing arguments used to define a quantization config for weights or
+     activations
+
+     :param num_bits: quantization bit depth
+     :param type: dtype to quantize to, either int or float
+     :param symmetric: whether or not quantization scale is symmetric about zero-point
+     :param strategy: string id determining the scope of scale/zero-point to apply
+     :param group_size: group length to use for the group strategy
+     :param block_structure: 2d block structure to use for the block strategy, must be
+         of the format "2x4", "8x16", etc.
+     """
+
+     num_bits: int = 8
+     type: QuantizationType = QuantizationType.INT
+     symmetric: bool = True
+     strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
+     group_size: Optional[int] = None
+     block_structure: Optional[str] = None
+     observer: str = Field(
+         default="minmax",
+         description=(
+             "The class to use to compute the quantization params: "
+             "scale and zero-point"
+         ),
+     )
+     observer_kwargs: Dict[str, Any] = Field(
+         default_factory=dict,
+         description=(
+             "optional dict of kwargs to be passed directly to the Observer "
+             "constructor, excluding quantization range or symmetry"
+         ),
+     )
+
+     def get_observer(self):
+         """
+         :return: Observer built based on these QuantizationArgs
+         """
+         from compressed_tensors.quantization.observers.base import Observer
+
+         return Observer.load_from_registry(self.observer, quantization_args=self)
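For example, a 4-bit grouped weight configuration could be declared as follows (field values are illustrative):

    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        QuantizationStrategy,
        QuantizationType,
    )

    args = QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        symmetric=True,
        strategy=QuantizationStrategy.GROUP,
        group_size=128,  # one scale/zero point per group of 128 weights
    )
    observer = args.get_observer()  # resolves the default "minmax" registry entry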
compressed_tensors/quantization/quant_config.py
@@ -0,0 +1,171 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from enum import Enum
+ from typing import Dict, List, Optional
+
+ from compressed_tensors.base import QUANTIZATION_CONFIG_NAME
+ from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from compressed_tensors.quantization.utils import (
+     calculate_compression_ratio,
+     is_module_quantized,
+     iter_named_leaf_modules,
+     module_type,
+ )
+ from pydantic import BaseModel, Field
+ from torch.nn import Module
+ from transformers import AutoConfig
+
+
+ __all__ = [
+     "QuantizationStatus",
+     "QuantizationConfig",
+     "LIFECYCLE_ORDER",
+ ]
+
+
+ class QuantizationStatus(str, Enum):
+     """
+     Enum storing the different states a quantized layer can be in
+
+     Initialized: scale, zero points and observers have been attached to the layer but
+         are set to dummy values (not yet calibrated)
+     Calibration: scale and zero points have been calibrated through OBCQ or a similar
+         algorithm, observers are still attached
+     Frozen: scale and zero points are finalized, observers have been deleted, weights
+         are still in their original precision
+     Compressed: weights have been converted to their target type or compressed to
+         their closest approximation
+     """
+
+     INITIALIZED = "initialized"
+     CALIBRATION = "calibration"
+     FROZEN = "frozen"
+     COMPRESSED = "compressed"
+
+     @classmethod
+     def lifecycle_order(cls) -> List["QuantizationStatus"]:
+         """
+         :return: list of statuses in correct quantization lifecycle order
+         """
+         return LIFECYCLE_ORDER
+
+     def __ge__(self, other):
+         if not isinstance(other, self.__class__):
+             raise NotImplementedError
+         return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)
+
+
+ LIFECYCLE_ORDER = [
+     QuantizationStatus.INITIALIZED,
+     QuantizationStatus.CALIBRATION,
+     QuantizationStatus.FROZEN,
+     QuantizationStatus.COMPRESSED,
+ ]
+
+
+ class QuantizationConfig(BaseModel):
+     """
+     Full configuration specifying how a model is quantized. Each quantized layer is
+     mapped to a QuantizationScheme in config_groups.
+
+     :param config_groups: dict of QuantizationSchemes specifying the quantization
+         settings for each quantized layer
+     :param quant_method: a constant used to differentiate SparseML quantization from
+         other quantization configs
+     :param format: specifies how the quantized model is stored on disk
+     :param quantization_status: specifies the current status of all quantized layers.
+         It is assumed all layers are in the same state.
+     :param global_compression_ratio: optional informational config to report the model
+         compression ratio achieved by the quantization config
+     :param ignore: optional list of layers to ignore from config_groups. Layers in
+         this list are not quantized even if they match up with a target in config_groups
+     """
+
+     config_groups: Dict[str, QuantizationScheme]
+     quant_method: str = "sparseml"
+     format: str = "fakequant"
+     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
+     global_compression_ratio: Optional[float] = None
+     ignore: Optional[List[str]] = Field(default_factory=list)
+
+     @staticmethod
+     def from_model_config(model_name_or_path) -> Optional["QuantizationConfig"]:
+         """
+         Given a path to a model config, extract a quantization config if it exists
+
+         :param model_name_or_path: path to model config on disk or HF hub
+         :return: instantiated QuantizationConfig if config contains a quant config, else None
+         """
+         config = AutoConfig.from_pretrained(model_name_or_path)
+         quantization_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
+         if quantization_config is None:
+             return None
+
+         return QuantizationConfig.parse_obj(quantization_config)
+
+     @staticmethod
+     def from_pretrained(model: Module) -> "QuantizationConfig":
+         """
+         Converts a model into its associated QuantizationConfig based on the
+         QuantizationScheme attached to each quantized module
+
+         :param model: model to calculate quantization scheme of
+         :return: filled out QuantizationConfig for the input model
+         """
+         quant_scheme_to_layers = []
+         quantization_status = None
+         ignore = {}
+         quantization_type_names = set()
+         for name, submodule in iter_named_leaf_modules(model):
+             layer_type = module_type(submodule)
+             if not is_module_quantized(submodule):
+                 if layer_type not in ignore:
+                     ignore[layer_type] = []
+                 ignore[layer_type].append(name)
+             else:
+                 quantization_status = submodule.quantization_status
+                 scheme = submodule.quantization_scheme
+                 quantization_type_names.add(layer_type)
+
+                 match_found = False
+                 for existing_scheme in quant_scheme_to_layers:
+                     if scheme == existing_scheme:
+                         match_found = True
+                         break
+                 if not match_found:
+                     quant_scheme_to_layers.append(scheme)
+
+         # clean up the ignore list; we can leave out layer types if none of the
+         # instances are quantized
+         consolidated_ignore = []
+         for layer_type, ignore_names in ignore.items():
+             if layer_type in quantization_type_names:
+                 # specific layers of a quantized type are ignored
+                 consolidated_ignore += ignore_names
+             # else we leave it off the ignore list; it doesn't fall under any of the
+             # existing quantization schemes so it won't be quantized
+
+         config_groups = {}
+         for idx, scheme in enumerate(quant_scheme_to_layers):
+             group_name = "group_" + str(idx)
+             config_groups[group_name] = scheme
+
+         compression_ratio = calculate_compression_ratio(model)
+         return QuantizationConfig(
+             config_groups=config_groups,
+             quantization_status=quantization_status,
+             global_compression_ratio=compression_ratio,
+             ignore=consolidated_ignore,
+         )
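With lifecycle_order returning LIFECYCLE_ORDER, status comparisons follow lifecycle position rather than string order; a quick sketch:

    from compressed_tensors.quantization.quant_config import (
        LIFECYCLE_ORDER,
        QuantizationStatus,
    )

    assert QuantizationStatus.lifecycle_order() == LIFECYCLE_ORDER
    assert QuantizationStatus.FROZEN >= QuantizationStatus.CALIBRATION  # index 2 >= 1
    assert not (QuantizationStatus.INITIALIZED >= QuantizationStatus.COMPRESSED)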
compressed_tensors/quantization/quant_scheme.py
@@ -0,0 +1,39 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Optional
+
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
+ from pydantic import BaseModel
+
+
+ __all__ = ["QuantizationScheme"]
+
+
+ class QuantizationScheme(BaseModel):
+     """
+     Set of QuantizationArgs defining how the weights, inputs and outputs of a target
+     list of modules should be quantized
+
+     :param targets: list of modules to apply the QuantizationArgs to; can be layer
+         names, layer types or a regular expression
+     :param weights: quantization config for layer weights
+     :param input_activations: quantization config for layer inputs
+     :param output_activations: quantization config for layer outputs
+     """
+
+     targets: List[str]
+     weights: Optional[QuantizationArgs] = None
+     input_activations: Optional[QuantizationArgs] = None
+     output_activations: Optional[QuantizationArgs] = None
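A sketch of a scheme that quantizes the weights and inputs of every Linear layer while leaving outputs untouched (target names are illustrative):

    from compressed_tensors.quantization.quant_args import QuantizationArgs
    from compressed_tensors.quantization.quant_scheme import QuantizationScheme

    scheme = QuantizationScheme(
        targets=["Linear"],  # match by layer type
        weights=QuantizationArgs(num_bits=8, symmetric=True),
        input_activations=QuantizationArgs(num_bits=8, symmetric=False),
    )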
compressed_tensors/quantization/utils/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+ from .helpers import *
compressed_tensors/quantization/utils/helpers.py
@@ -0,0 +1,115 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Iterator, Tuple
+
+ import torch
+ from torch.nn import Module
+ from tqdm import tqdm
+
+
+ __all__ = [
+     "is_module_quantized",
+     "is_model_quantized",
+     "iter_named_leaf_modules",
+     "module_type",
+     "calculate_compression_ratio",
+ ]
+
+
+ def is_module_quantized(module: Module) -> bool:
+     """
+     Check if a module is quantized, based on the existence of a non-empty quantization
+     scheme
+
+     :param module: pytorch module to check
+     :return: True if module is quantized, False otherwise
+     """
+     if not hasattr(module, "quantization_scheme"):
+         return False
+
+     if module.quantization_scheme.weights is not None:
+         return True
+
+     if module.quantization_scheme.input_activations is not None:
+         return True
+
+     if module.quantization_scheme.output_activations is not None:
+         return True
+
+     return False
+
+
+ def is_model_quantized(model: Module) -> bool:
+     """
+     Check if any modules in a model are quantized, based on the existence of a non-empty
+     quantization scheme in at least one module
+
+     :param model: pytorch model
+     :return: True if model is quantized, False otherwise
+     """
+
+     for _, submodule in iter_named_leaf_modules(model):
+         if is_module_quantized(submodule):
+             return True
+
+     return False
+
+
+ def module_type(module: Module) -> str:
+     """
+     Gets a string representation of a module type
+
+     :param module: pytorch module to get type of
+     :return: module type as a string
+     """
+     return type(module).__name__
+
+
+ def iter_named_leaf_modules(model: Module) -> Iterator[Tuple[str, Module]]:
+     # yields modules that do not have any submodules
+     # TODO: potentially expand to add list of allowed submodules such as observers
+     for name, submodule in model.named_modules():
+         if len(list(submodule.children())) == 0:
+             yield name, submodule
+
+
+ def calculate_compression_ratio(model: Module) -> float:
+     """
+     Calculates the quantization compression ratio of a pytorch model, based on the
+     number of bits needed to represent the total weights in compressed form. Does not
+     take into account activation quantizations.
+
+     :param model: pytorch module to calculate compression ratio for
+     :return: compression ratio of the whole model
+     """
+     total_compressed = 0.0
+     total_uncompressed = 0.0
+     for name, submodule in tqdm(
+         iter_named_leaf_modules(model),
+         desc="Calculating quantization compression ratio",
+     ):
+         for parameter in submodule.parameters():  # count only this module's params
+             try:
+                 uncompressed_bits = torch.finfo(parameter.dtype).bits
+             except TypeError:
+                 uncompressed_bits = torch.iinfo(parameter.dtype).bits
+             compressed_bits = uncompressed_bits
+             if is_module_quantized(submodule):
+                 compressed_bits = submodule.quantization_scheme.weights.num_bits
+             num_weights = parameter.numel()
+             total_compressed += compressed_bits * num_weights
+             total_uncompressed += uncompressed_bits * num_weights
+
+     return total_uncompressed / total_compressed
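As a sanity-check sketch: a model with no quantization_scheme attached reports a ratio of 1.0, since compressed and uncompressed bit counts match; fp32 weights under an 8-bit weight scheme would instead report roughly 4x.

    import torch
    from compressed_tensors.quantization.utils import calculate_compression_ratio

    model = torch.nn.Sequential(torch.nn.Linear(16, 16))
    assert calculate_compression_ratio(model) == 1.0  # 32-bit weights, no scheme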
compressed_tensors/registry/__init__.py
@@ -0,0 +1,17 @@
+ # flake8: noqa
+
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .registry import *