mct-nightly 2.4.0.20250630.629__py3-none-any.whl → 2.4.0.20250702.605__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/METADATA +16 -16
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/RECORD +75 -72
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -1
- model_compression_toolkit/core/common/framework_info.py +5 -32
- model_compression_toolkit/core/common/fusion/graph_fuser.py +12 -9
- model_compression_toolkit/core/common/graph/base_graph.py +20 -37
- model_compression_toolkit/core/common/graph/base_node.py +13 -106
- model_compression_toolkit/core/common/graph/functional_node.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +12 -10
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +14 -9
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +9 -15
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +2 -3
- model_compression_toolkit/core/common/network_editors/__init__.py +8 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -96
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +116 -56
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +55 -179
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +21 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +8 -5
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -70
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +10 -12
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +54 -30
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +93 -398
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +2 -5
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -6
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +12 -6
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -2
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +33 -33
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +2 -4
- model_compression_toolkit/core/graph_prep_runner.py +31 -20
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +5 -2
- model_compression_toolkit/core/keras/default_framework_info.py +0 -11
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +9 -6
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +3 -1
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +2 -1
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +47 -0
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +5 -2
- model_compression_toolkit/core/pytorch/default_framework_info.py +0 -12
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +2 -0
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +2 -1
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +45 -0
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +3 -2
- model_compression_toolkit/core/runner.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +7 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +12 -3
- model_compression_toolkit/pruning/keras/pruning_facade.py +5 -9
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -5
- model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
- model_compression_toolkit/quantization_preparation/__init__.py +14 -0
- model_compression_toolkit/quantization_preparation/load_fqc.py +223 -0
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -78
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/lut_fake_quant.py +0 -0
|
@@ -12,24 +12,16 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
from typing import Callable, Any, List, Tuple, Union, Dict, TYPE_CHECKING
|
|
15
|
+
from typing import Any, List, Dict, TYPE_CHECKING
|
|
18
16
|
from enum import Enum, auto
|
|
19
|
-
import numpy as np
|
|
20
17
|
|
|
21
18
|
from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
|
|
22
|
-
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn
|
|
23
19
|
from model_compression_toolkit.logger import Logger
|
|
24
|
-
from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
|
|
25
|
-
get_activation_quantization_params_fn, get_weights_quantization_params_fn
|
|
26
20
|
|
|
27
|
-
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
|
|
28
|
-
|
|
29
|
-
from model_compression_toolkit.target_platform_capabilities.constants import POS_ATTR
|
|
21
|
+
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
|
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
|
|
30
23
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import \
|
|
31
|
-
AttributeQuantizationConfig,
|
|
32
|
-
OpQuantizationConfig
|
|
24
|
+
AttributeQuantizationConfig, OpQuantizationConfig
|
|
33
25
|
|
|
34
26
|
if TYPE_CHECKING:
|
|
35
27
|
from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
|
|
@@ -86,29 +78,14 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
86
78
|
"""
|
|
87
79
|
Attributes for configuring the quantization of the activations of a node.
|
|
88
80
|
"""
|
|
89
|
-
def __init__(self,
|
|
90
|
-
qc: QuantizationConfig,
|
|
91
|
-
op_cfg: OpQuantizationConfig,
|
|
92
|
-
activation_quantization_fn: Callable,
|
|
93
|
-
activation_quantization_params_fn: Callable
|
|
94
|
-
):
|
|
81
|
+
def __init__(self, op_cfg: OpQuantizationConfig):
|
|
95
82
|
"""
|
|
96
83
|
|
|
97
84
|
Args:
|
|
98
|
-
qc: QuantizationConfig to create the node's config from.
|
|
99
85
|
op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
|
|
100
|
-
activation_quantization_fn: Function to use when quantizing the node's activations.
|
|
101
|
-
activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations.
|
|
102
86
|
"""
|
|
103
|
-
|
|
104
|
-
self.activation_quantization_fn = activation_quantization_fn
|
|
105
|
-
self.activation_quantization_params_fn = activation_quantization_params_fn
|
|
106
|
-
self.activation_quantization_params = {}
|
|
107
87
|
self.activation_quantization_method = op_cfg.activation_quantization_method
|
|
108
|
-
self.activation_error_method = qc.activation_error_method
|
|
109
88
|
self.activation_n_bits = op_cfg.activation_n_bits
|
|
110
|
-
self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
|
|
111
|
-
self.activation_bias_correction_term = None
|
|
112
89
|
if op_cfg.enable_activation_quantization and op_cfg.quantization_preserving:
|
|
113
90
|
raise ValueError("An OpQuantizationConfig can't have both enable_activation_quantization and quantization_preserving enabled.")
|
|
114
91
|
if op_cfg.enable_activation_quantization:
|
|
@@ -118,6 +95,29 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
118
95
|
else:
|
|
119
96
|
self.quant_mode = ActivationQuantizationMode.NO_QUANT
|
|
120
97
|
self.signedness = op_cfg.signedness
|
|
98
|
+
|
|
99
|
+
self.activation_quantization_params = {}
|
|
100
|
+
# TODO irena: computed by compute_activation_bias_correction. shouldnt really be here
|
|
101
|
+
self.activation_bias_correction_term = None
|
|
102
|
+
|
|
103
|
+
# TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
|
|
104
|
+
self.activation_error_method = None
|
|
105
|
+
self.relu_bound_to_power_of_2 = None
|
|
106
|
+
self.activation_channel_equalization = None
|
|
107
|
+
self.input_scaling = None
|
|
108
|
+
self.min_threshold = None
|
|
109
|
+
self.l_p_value = None
|
|
110
|
+
self.shift_negative_activation_correction = None
|
|
111
|
+
self.z_threshold = None
|
|
112
|
+
self.shift_negative_ratio = None
|
|
113
|
+
self.shift_negative_threshold_recalculation = None
|
|
114
|
+
self.concat_threshold_update = None
|
|
115
|
+
|
|
116
|
+
def set_qc(self, qc: QuantizationConfig):
|
|
117
|
+
""" TODO irena: temporary keep all the attributes as before not to break all code at once.
|
|
118
|
+
Eventually all of them should be removed from here. """
|
|
119
|
+
self.activation_error_method = qc.activation_error_method
|
|
120
|
+
self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
|
|
121
121
|
self.activation_channel_equalization = qc.activation_channel_equalization
|
|
122
122
|
self.input_scaling = qc.input_scaling
|
|
123
123
|
self.min_threshold = qc.min_threshold
|
|
@@ -139,65 +139,6 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
139
139
|
def fln_quantization(self):
|
|
140
140
|
return self.quant_mode == ActivationQuantizationMode.FLN_QUANT
|
|
141
141
|
|
|
142
|
-
def quantize_node_output(self,
|
|
143
|
-
tensors: Any) -> Any:
|
|
144
|
-
"""
|
|
145
|
-
|
|
146
|
-
Args:
|
|
147
|
-
tensors: framework tensor/s
|
|
148
|
-
|
|
149
|
-
Returns:
|
|
150
|
-
Framework tensor/s after applying fake quantization.
|
|
151
|
-
|
|
152
|
-
"""
|
|
153
|
-
fake_quant = self.activation_quantization_fn(self.activation_n_bits,
|
|
154
|
-
self.activation_quantization_params)
|
|
155
|
-
|
|
156
|
-
if fake_quant is None:
|
|
157
|
-
Logger.critical(
|
|
158
|
-
"Layer is intended to be quantized, but the fake_quant function is None.") # pragma: no cover
|
|
159
|
-
|
|
160
|
-
return fake_quant(tensors)
|
|
161
|
-
|
|
162
|
-
@property
|
|
163
|
-
def activation_error_method(self) -> QuantizationErrorMethod:
|
|
164
|
-
"""
|
|
165
|
-
activation_error_method getter.
|
|
166
|
-
"""
|
|
167
|
-
return self._activation_error_method
|
|
168
|
-
|
|
169
|
-
@activation_error_method.setter
|
|
170
|
-
def activation_error_method(self, value: QuantizationErrorMethod):
|
|
171
|
-
"""
|
|
172
|
-
activation_error_method setter.
|
|
173
|
-
|
|
174
|
-
Args:
|
|
175
|
-
value: New activation_error_method to set to the node activation configuration.
|
|
176
|
-
|
|
177
|
-
"""
|
|
178
|
-
self._activation_error_method = value
|
|
179
|
-
self.activation_quantization_params_fn = get_activation_quantization_params_fn(activation_quantization_method=self.activation_quantization_method)
|
|
180
|
-
|
|
181
|
-
def set_activation_quantization_fn(self, activation_quantization_fn: Callable):
|
|
182
|
-
"""
|
|
183
|
-
Sets activation quantization function for the node.
|
|
184
|
-
|
|
185
|
-
Args:
|
|
186
|
-
activation_quantization_fn: Function for quantazing the activations.
|
|
187
|
-
|
|
188
|
-
"""
|
|
189
|
-
self.activation_quantization_fn = activation_quantization_fn
|
|
190
|
-
|
|
191
|
-
def set_activation_quantization_params_fn(self, activation_quantization_params_fn:Callable):
|
|
192
|
-
"""
|
|
193
|
-
Sets activation params function for the node.
|
|
194
|
-
|
|
195
|
-
Args:
|
|
196
|
-
activation_quantization_params_fn: Function for calculating activation params.
|
|
197
|
-
|
|
198
|
-
"""
|
|
199
|
-
self.activation_quantization_params_fn = activation_quantization_params_fn
|
|
200
|
-
|
|
201
142
|
def set_activation_quantization_param(self,
|
|
202
143
|
activation_params: dict):
|
|
203
144
|
"""
|
|
@@ -224,9 +165,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
224
165
|
if not isinstance(other, NodeActivationQuantizationConfig):
|
|
225
166
|
return False # pragma: no cover
|
|
226
167
|
|
|
227
|
-
return self.
|
|
228
|
-
self.activation_quantization_params_fn == other.activation_quantization_params_fn and \
|
|
229
|
-
self.activation_error_method == other.activation_error_method and \
|
|
168
|
+
return self.activation_error_method == other.activation_error_method and \
|
|
230
169
|
self.activation_quantization_method == other.activation_quantization_method and \
|
|
231
170
|
self.activation_n_bits == other.activation_n_bits and \
|
|
232
171
|
self.quant_mode == other.quant_mode and \
|
|
@@ -240,9 +179,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
240
179
|
self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
|
|
241
180
|
|
|
242
181
|
def __hash__(self):
|
|
243
|
-
return hash((self.
|
|
244
|
-
self.activation_quantization_params_fn,
|
|
245
|
-
self.activation_error_method,
|
|
182
|
+
return hash((self.activation_error_method,
|
|
246
183
|
self.activation_quantization_method,
|
|
247
184
|
self.activation_n_bits,
|
|
248
185
|
self.quant_mode,
|
|
@@ -261,65 +198,29 @@ class WeightsAttrQuantizationConfig:
|
|
|
261
198
|
Configuration for quantizing a weights attribute of a node.
|
|
262
199
|
"""
|
|
263
200
|
def __init__(self,
|
|
264
|
-
qc: QuantizationConfig,
|
|
265
201
|
weights_attr_cfg: AttributeQuantizationConfig,
|
|
266
202
|
weights_channels_axis: ChannelAxisMapping = None):
|
|
267
203
|
"""
|
|
268
204
|
|
|
269
205
|
Args:
|
|
270
|
-
qc: QuantizationConfig to create the node's config from.
|
|
271
206
|
weights_attr_cfg: AttributeQuantizationConfig with parameters to use when creating the node's attribute quantization config.
|
|
272
207
|
weights_channels_axis: Axis to quantize a node's attribute when quantizing per-channel (if not quantizing per-channel than expecting None).
|
|
273
208
|
"""
|
|
274
|
-
self.weights_quantization_fn = get_weights_quantization_fn(weights_attr_cfg.weights_quantization_method)
|
|
275
|
-
self.weights_quantization_params_fn = get_weights_quantization_params_fn(weights_attr_cfg.weights_quantization_method)
|
|
276
209
|
self.weights_channels_axis = weights_channels_axis
|
|
277
|
-
self.weights_quantization_params = {}
|
|
278
210
|
self.weights_quantization_method = weights_attr_cfg.weights_quantization_method
|
|
279
|
-
self.weights_error_method = qc.weights_error_method
|
|
280
211
|
self.weights_n_bits = weights_attr_cfg.weights_n_bits
|
|
281
212
|
self.weights_per_channel_threshold = weights_attr_cfg.weights_per_channel_threshold
|
|
282
213
|
self.enable_weights_quantization = weights_attr_cfg.enable_weights_quantization
|
|
283
|
-
self.
|
|
284
|
-
|
|
285
|
-
@property
|
|
286
|
-
def weights_error_method(self) -> QuantizationErrorMethod:
|
|
287
|
-
"""
|
|
288
|
-
weights_error_method getter.
|
|
289
|
-
"""
|
|
290
|
-
return self._weights_error_method
|
|
291
|
-
|
|
292
|
-
@weights_error_method.setter
|
|
293
|
-
def weights_error_method(self, value: QuantizationErrorMethod):
|
|
294
|
-
"""
|
|
295
|
-
weights_error_method setter.
|
|
296
|
-
|
|
297
|
-
Args:
|
|
298
|
-
value: New weights_error_method to set to the node weights configuration.
|
|
299
|
-
|
|
300
|
-
"""
|
|
301
|
-
self._weights_error_method = value
|
|
302
|
-
self.weights_quantization_params_fn = get_weights_quantization_params_fn(weights_quantization_method=self.weights_quantization_method)
|
|
303
|
-
|
|
304
|
-
def set_weights_quantization_fn(self, weights_quantization_fn: Callable):
|
|
305
|
-
"""
|
|
306
|
-
Sets weights quantization function for the node.
|
|
307
|
-
|
|
308
|
-
Args:
|
|
309
|
-
weights_quantization_fn: Function for quantazing the weights.
|
|
214
|
+
self.weights_quantization_params = {}
|
|
310
215
|
|
|
311
|
-
|
|
312
|
-
self.
|
|
216
|
+
# TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
|
|
217
|
+
self.weights_error_method = None
|
|
218
|
+
self.l_p_value = None
|
|
313
219
|
|
|
314
|
-
def
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
Args:
|
|
319
|
-
weights_quantization_params_fn: Function for calculating the weights params.
|
|
320
|
-
|
|
321
|
-
"""
|
|
322
|
-
self.weights_quantization_params_fn = weights_quantization_params_fn
|
|
220
|
+
def set_qc(self, qc: QuantizationConfig):
|
|
221
|
+
# TODO irena: temporary keep the fields to not break everything at once.
|
|
222
|
+
self.weights_error_method = qc.weights_error_method
|
|
223
|
+
self.l_p_value = qc.l_p_value
|
|
323
224
|
|
|
324
225
|
def set_weights_quantization_param(self,
|
|
325
226
|
weights_params: dict):
|
|
@@ -334,31 +235,6 @@ class WeightsAttrQuantizationConfig:
|
|
|
334
235
|
for param_name, param_value in weights_params.items():
|
|
335
236
|
self.weights_quantization_params[param_name] = param_value
|
|
336
237
|
|
|
337
|
-
def calculate_and_set_weights_params(self, tensor_data: np.ndarray, min_threshold: float):
|
|
338
|
-
"""
|
|
339
|
-
Args:
|
|
340
|
-
tensor_data: Tensor content as Numpy array.
|
|
341
|
-
min_threshold: A minimal threshold to set as quantization parameter.
|
|
342
|
-
|
|
343
|
-
Returns:
|
|
344
|
-
Recalculated weights quantization params from the kernel and channel axis.
|
|
345
|
-
|
|
346
|
-
"""
|
|
347
|
-
assert self.enable_weights_quantization
|
|
348
|
-
assert not (self.weights_per_channel_threshold and self.weights_channels_axis is None), \
|
|
349
|
-
"Trying to calculate threshold per channel, channel axis in None."
|
|
350
|
-
if self.weights_quantization_params_fn is not None:
|
|
351
|
-
self.set_weights_quantization_param(
|
|
352
|
-
self.weights_quantization_params_fn(tensor_data,
|
|
353
|
-
p=self.l_p_value,
|
|
354
|
-
n_bits=self.weights_n_bits,
|
|
355
|
-
per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
|
|
356
|
-
channel_axis=self.weights_channels_axis.output, # output channel axis
|
|
357
|
-
min_threshold=min_threshold)[0] # Take only first output, the q-params, as axis is already chosen.
|
|
358
|
-
)
|
|
359
|
-
else:
|
|
360
|
-
self.set_weights_quantization_param({})
|
|
361
|
-
|
|
362
238
|
def __eq__(self, other: Any) -> bool:
|
|
363
239
|
"""
|
|
364
240
|
Compares the object to another object to find if they are equal.
|
|
@@ -372,20 +248,16 @@ class WeightsAttrQuantizationConfig:
|
|
|
372
248
|
if not isinstance(other, WeightsAttrQuantizationConfig):
|
|
373
249
|
return False # pragma: no cover
|
|
374
250
|
|
|
375
|
-
return self.
|
|
376
|
-
self.weights_quantization_params_fn == other.weights_quantization_params_fn and \
|
|
377
|
-
self.weights_channels_axis == other.weights_channels_axis and \
|
|
378
|
-
self.weights_error_method == other.weights_error_method and \
|
|
251
|
+
return self.weights_channels_axis == other.weights_channels_axis and \
|
|
379
252
|
self.weights_quantization_method == other.weights_quantization_method and \
|
|
380
253
|
self.weights_n_bits == other.weights_n_bits and \
|
|
381
254
|
self.weights_per_channel_threshold == other.weights_per_channel_threshold and \
|
|
382
255
|
self.enable_weights_quantization == other.enable_weights_quantization and \
|
|
256
|
+
self.weights_error_method == other.weights_error_method and \
|
|
383
257
|
self.l_p_value == other.l_p_value
|
|
384
258
|
|
|
385
259
|
def __hash__(self):
|
|
386
|
-
return hash((self.
|
|
387
|
-
self.weights_quantization_params_fn,
|
|
388
|
-
self.weights_channels_axis,
|
|
260
|
+
return hash((self.weights_channels_axis,
|
|
389
261
|
self.weights_error_method,
|
|
390
262
|
self.weights_quantization_method,
|
|
391
263
|
self.weights_n_bits,
|
|
@@ -399,23 +271,19 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
399
271
|
Holding a mapping between the node's weights attributes and their quantization configurations,
|
|
400
272
|
in addition to quantization parameters that are global for all attributes of the represented node.
|
|
401
273
|
"""
|
|
402
|
-
def __init__(self,
|
|
274
|
+
def __init__(self,
|
|
403
275
|
op_cfg: OpQuantizationConfig,
|
|
404
276
|
weights_channels_axis: ChannelAxisMapping,
|
|
405
277
|
node_attrs_list: List[str]):
|
|
406
278
|
"""
|
|
407
279
|
|
|
408
280
|
Args:
|
|
409
|
-
qc: QuantizationConfig to create the node's config from.
|
|
410
281
|
op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
|
|
411
282
|
weights_channels_axis: Axis to quantize a node's weights attribute when quantizing per-channel.
|
|
412
283
|
node_attrs_list: A list of the node's weights attributes names.
|
|
413
284
|
|
|
414
285
|
"""
|
|
415
|
-
self.min_threshold = qc.min_threshold
|
|
416
286
|
self.simd_size = op_cfg.simd_size
|
|
417
|
-
self.weights_second_moment_correction = qc.weights_second_moment_correction
|
|
418
|
-
self.weights_bias_correction = qc.weights_bias_correction
|
|
419
287
|
|
|
420
288
|
# Initialize a quantization configuration for each of the node's attributes
|
|
421
289
|
self.attributes_config_mapping = {}
|
|
@@ -427,7 +295,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
427
295
|
# POS_ATTR string. If none are found, it indicates that no specific quantization config is defined for
|
|
428
296
|
# positional weights, so the default config will be used instead.
|
|
429
297
|
attrs_included_in_name = {k: v for k, v in op_cfg.attr_weights_configs_mapping.items() if
|
|
430
|
-
|
|
298
|
+
POSITIONAL_ATTR in k}
|
|
431
299
|
|
|
432
300
|
if len(attrs_included_in_name) > 1: # pragma: no cover
|
|
433
301
|
raise ValueError(f"Found multiple attribute in FQC OpConfig that are contained "
|
|
@@ -443,8 +311,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
443
311
|
attr_cfg = list(attrs_included_in_name.values())[0]
|
|
444
312
|
|
|
445
313
|
# Register this attribute under the positional attributes config mapping.
|
|
446
|
-
self.pos_attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(
|
|
447
|
-
weights_attr_cfg=attr_cfg,
|
|
314
|
+
self.pos_attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(weights_attr_cfg=attr_cfg,
|
|
448
315
|
weights_channels_axis=
|
|
449
316
|
weights_channels_axis)
|
|
450
317
|
else:
|
|
@@ -461,9 +328,18 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
461
328
|
else:
|
|
462
329
|
attr_cfg = list(attrs_included_in_name.values())[0]
|
|
463
330
|
|
|
464
|
-
self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(
|
|
465
|
-
weights_attr_cfg=attr_cfg,
|
|
331
|
+
self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(weights_attr_cfg=attr_cfg,
|
|
466
332
|
weights_channels_axis=weights_channels_axis)
|
|
333
|
+
# TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
|
|
334
|
+
self.min_threshold = None
|
|
335
|
+
self.weights_second_moment_correction = None
|
|
336
|
+
self.weights_bias_correction = None
|
|
337
|
+
|
|
338
|
+
def set_qc(self, qc: QuantizationConfig):
|
|
339
|
+
# TODO irena: temporary keep the fields to not break everything at once.
|
|
340
|
+
self.min_threshold = qc.min_threshold
|
|
341
|
+
self.weights_second_moment_correction = qc.weights_second_moment_correction
|
|
342
|
+
self.weights_bias_correction = qc.weights_bias_correction
|
|
467
343
|
|
|
468
344
|
def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationConfig:
|
|
469
345
|
"""
|
|
@@ -14,15 +14,35 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
from collections.abc import Callable
|
|
17
|
-
from functools import partial
|
|
18
17
|
|
|
19
18
|
from mct_quantizers import QuantizationMethod
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
|
|
20
21
|
from model_compression_toolkit.logger import Logger
|
|
21
22
|
from model_compression_toolkit.core.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
|
|
22
23
|
from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import power_of_two_quantizer, \
|
|
23
24
|
symmetric_quantizer, uniform_quantizer
|
|
24
25
|
|
|
25
26
|
|
|
27
|
+
def get_activation_quantization_fn(activation_quantization_cfg: NodeActivationQuantizationConfig,
|
|
28
|
+
get_activation_quantization_fn_factory: Callable) -> Callable:
|
|
29
|
+
"""
|
|
30
|
+
Get activation quantizer based on activation quantization configuration.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
activation_quantization_cfg: activation quantization configuration.
|
|
34
|
+
get_activation_quantization_fn_factory: activation quantization functions factory.
|
|
35
|
+
|
|
36
|
+
Returns:
|
|
37
|
+
Activation quantizer that accepts a tensor and returns a quantized tensor.
|
|
38
|
+
"""
|
|
39
|
+
quantizer_factory = get_activation_quantization_fn_factory(
|
|
40
|
+
activation_quantization_cfg.activation_quantization_method)
|
|
41
|
+
quantizer = quantizer_factory(activation_quantization_cfg.activation_n_bits,
|
|
42
|
+
activation_quantization_cfg.activation_quantization_params)
|
|
43
|
+
return quantizer
|
|
44
|
+
|
|
45
|
+
|
|
26
46
|
def get_weights_quantization_fn(weights_quantization_method: QuantizationMethod) -> Callable:
|
|
27
47
|
"""
|
|
28
48
|
Generate a function for weight quantization.
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py
CHANGED
|
@@ -12,9 +12,12 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import
|
|
16
|
-
power_of_two_selection_histogram, power_of_two_selection_tensor
|
|
17
|
-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import
|
|
18
|
-
|
|
19
|
-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.
|
|
15
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import (
|
|
16
|
+
power_of_two_no_clipping_selection_min_max, power_of_two_selection_histogram, power_of_two_selection_tensor)
|
|
17
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import (
|
|
18
|
+
lut_kmeans_tensor, lut_kmeans_histogram)
|
|
19
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import (
|
|
20
|
+
symmetric_no_clipping_selection_min_max, symmetric_selection_histogram, symmetric_selection_tensor)
|
|
21
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import (
|
|
22
|
+
uniform_no_clipping_selection_min_max, uniform_selection_histogram, uniform_selection_tensor)
|
|
20
23
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.outlier_filter import z_score_filter
|
|
@@ -13,17 +13,59 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
import numpy as np
|
|
16
|
-
from typing import Dict, Union, Optional, Tuple
|
|
16
|
+
from typing import Dict, Union, Optional, Tuple, Callable
|
|
17
17
|
|
|
18
18
|
from mct_quantizers import QuantizationMethod
|
|
19
|
-
|
|
19
|
+
|
|
20
|
+
import model_compression_toolkit.core.common.quantization.quantization_params_generation as qpg
|
|
20
21
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
|
|
21
22
|
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
|
|
22
|
-
from model_compression_toolkit.core.common.quantization import quantization_params_generation
|
|
23
23
|
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
|
|
24
24
|
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
|
|
25
|
+
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationErrorMethod
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def compute_activation_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
|
|
29
|
+
node_prior_info: NodePriorInfo,
|
|
30
|
+
out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
|
|
31
|
+
"""
|
|
32
|
+
Compute the activations params for a given node in a graph according to a params function.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
activation_quant_cfg: node's activation quantization configuration.
|
|
36
|
+
node_prior_info: Prior info collected for the node that is being quantized.
|
|
37
|
+
out_stats_container: Tensor containing output statistics of the node.
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
The computed activation quantization params.
|
|
41
|
+
"""
|
|
42
|
+
activation_quantization_params_fn = _get_activation_quantization_params_fn(
|
|
43
|
+
activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded())
|
|
44
|
+
|
|
45
|
+
# Extract and filter histogram data from the statistics container.
|
|
46
|
+
bins_values, bins_counts = _get_histogram_data(activation_quant_cfg, out_stats_container)
|
|
47
|
+
|
|
48
|
+
# Retrieve the minimum and maximum values from the statistics container.
|
|
49
|
+
min_value, max_value = out_stats_container.get_min_max_values()
|
|
50
|
+
|
|
51
|
+
# Determine if the activations should be considered signed.
|
|
52
|
+
signed = _determine_signedness(activation_quant_cfg, node_prior_info, min_value, bins_values, bins_counts)
|
|
25
53
|
|
|
26
|
-
|
|
54
|
+
# Compute and return the activation quantization parameters.
|
|
55
|
+
return activation_quantization_params_fn(
|
|
56
|
+
bins_values,
|
|
57
|
+
bins_counts,
|
|
58
|
+
activation_quant_cfg.l_p_value,
|
|
59
|
+
activation_quant_cfg.activation_n_bits,
|
|
60
|
+
min_value,
|
|
61
|
+
max_value,
|
|
62
|
+
min_threshold=activation_quant_cfg.min_threshold,
|
|
63
|
+
quant_error_method=activation_quant_cfg.activation_error_method,
|
|
64
|
+
is_signed=signed
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _get_histogram_data(
|
|
27
69
|
activation_quant_cfg: NodeActivationQuantizationConfig,
|
|
28
70
|
out_stats_container: BaseStatsCollector
|
|
29
71
|
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
|
|
@@ -38,7 +80,6 @@ def get_histogram_data(
|
|
|
38
80
|
A tuple containing the filtered bins_values and bins_counts.
|
|
39
81
|
"""
|
|
40
82
|
bins_values, bins_counts = None, None
|
|
41
|
-
|
|
42
83
|
# If the statistics container collected the histogram, we start by filtering outliers using z threshold
|
|
43
84
|
# filtering, and then computing the threshold based on the filtered histogram.
|
|
44
85
|
if out_stats_container.require_collection():
|
|
@@ -46,14 +87,15 @@ def get_histogram_data(
|
|
|
46
87
|
bins_values, bins_counts = out_stats_container.weighted_hc.get_histogram()
|
|
47
88
|
else:
|
|
48
89
|
bins_values, bins_counts = out_stats_container.hc.get_histogram()
|
|
49
|
-
bins_counts =
|
|
90
|
+
bins_counts = qpg.z_score_filter(
|
|
50
91
|
activation_quant_cfg.z_threshold,
|
|
51
92
|
bins_values,
|
|
52
93
|
bins_counts
|
|
53
94
|
)
|
|
54
95
|
return bins_values, bins_counts
|
|
55
96
|
|
|
56
|
-
|
|
97
|
+
|
|
98
|
+
def _determine_signedness(
|
|
57
99
|
activation_quant_cfg: NodeActivationQuantizationConfig,
|
|
58
100
|
nodes_prior_info: NodePriorInfo,
|
|
59
101
|
min_value: float,
|
|
@@ -83,73 +125,37 @@ def determine_signedness(
|
|
|
83
125
|
return np.any(bins_values[:-1][bins_counts > 0] < 0)
|
|
84
126
|
|
|
85
127
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
128
|
+
_activation_quant_params_fns = {
|
|
129
|
+
QuantizationMethod.POWER_OF_TWO: qpg.power_of_two_selection_histogram,
|
|
130
|
+
QuantizationMethod.SYMMETRIC: qpg.symmetric_selection_histogram,
|
|
131
|
+
QuantizationMethod.UNIFORM: qpg.uniform_selection_histogram,
|
|
132
|
+
QuantizationMethod.LUT_POT_QUANTIZER: qpg.lut_kmeans_histogram
|
|
133
|
+
}
|
|
134
|
+
_activation_no_clipping_quant_params_fns = {
|
|
135
|
+
QuantizationMethod.POWER_OF_TWO: qpg.power_of_two_no_clipping_selection_min_max,
|
|
136
|
+
QuantizationMethod.SYMMETRIC: qpg.symmetric_no_clipping_selection_min_max,
|
|
137
|
+
QuantizationMethod.UNIFORM: qpg.uniform_no_clipping_selection_min_max,
|
|
138
|
+
QuantizationMethod.LUT_POT_QUANTIZER: qpg.lut_kmeans_histogram
|
|
139
|
+
}
|
|
92
140
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
"""
|
|
97
|
-
if nodes_prior_info.is_output_bounded():
|
|
98
|
-
if activation_quant_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
|
|
99
|
-
activation_quant_cfg.set_activation_quantization_params_fn(
|
|
100
|
-
quantization_params_generation.power_of_two_no_clipping_selection_min_max
|
|
101
|
-
)
|
|
102
|
-
elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.SYMMETRIC:
|
|
103
|
-
activation_quant_cfg.set_activation_quantization_params_fn(
|
|
104
|
-
quantization_params_generation.symmetric_no_clipping_selection_min_max
|
|
105
|
-
)
|
|
106
|
-
elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.UNIFORM:
|
|
107
|
-
activation_quant_cfg.set_activation_quantization_params_fn(
|
|
108
|
-
quantization_params_generation.uniform_no_clipping_selection_min_max
|
|
109
|
-
)
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
|
|
113
|
-
nodes_prior_info: NodePriorInfo,
|
|
114
|
-
out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
|
|
141
|
+
|
|
142
|
+
def _get_activation_quantization_params_fn(activation_quantization_method: QuantizationMethod,
|
|
143
|
+
no_clipping: bool) -> Callable:
|
|
115
144
|
"""
|
|
116
|
-
|
|
145
|
+
Generate a function for finding activation quantization parameters.
|
|
117
146
|
|
|
118
147
|
Args:
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
out_stats_container: Tensor containing output statistics of the node.
|
|
148
|
+
activation_quantization_method: Which quantization method to use for activations.
|
|
149
|
+
no_clipping: Whether to use the no-clipping version of the quantizer (if available).
|
|
122
150
|
|
|
123
151
|
Returns:
|
|
124
|
-
|
|
152
|
+
A function to find the quantization parameters.
|
|
125
153
|
"""
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
# Determine if the activations should be considered signed.
|
|
136
|
-
signed = determine_signedness(
|
|
137
|
-
activation_quant_cfg,
|
|
138
|
-
nodes_prior_info,
|
|
139
|
-
min_value,
|
|
140
|
-
bins_values,
|
|
141
|
-
bins_counts
|
|
142
|
-
)
|
|
143
|
-
|
|
144
|
-
# Compute and return the activation quantization parameters.
|
|
145
|
-
return activation_quant_cfg.activation_quantization_params_fn(
|
|
146
|
-
bins_values,
|
|
147
|
-
bins_counts,
|
|
148
|
-
activation_quant_cfg.l_p_value,
|
|
149
|
-
activation_quant_cfg.activation_n_bits,
|
|
150
|
-
min_value,
|
|
151
|
-
max_value,
|
|
152
|
-
min_threshold=activation_quant_cfg.min_threshold,
|
|
153
|
-
quant_error_method=activation_quant_cfg.activation_error_method,
|
|
154
|
-
is_signed=signed
|
|
155
|
-
)
|
|
154
|
+
if no_clipping:
|
|
155
|
+
params_fn = _activation_no_clipping_quant_params_fns.get(activation_quantization_method)
|
|
156
|
+
else:
|
|
157
|
+
params_fn = _activation_quant_params_fns.get(activation_quantization_method)
|
|
158
|
+
if params_fn is None:
|
|
159
|
+
raise ValueError(f"No parameter function found for the specified quantization method: "
|
|
160
|
+
"{activation_quantization_method}") # pragma: no cover
|
|
161
|
+
return params_fn
|
|
@@ -25,9 +25,9 @@ from model_compression_toolkit.core.common.framework_implementation import Frame
|
|
|
25
25
|
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
|
|
26
26
|
HessianScoresGranularity
|
|
27
27
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
|
|
28
|
-
import
|
|
28
|
+
import compute_activation_qparams
|
|
29
29
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
|
|
30
|
-
|
|
30
|
+
compute_weights_qparams
|
|
31
31
|
from model_compression_toolkit.logger import Logger
|
|
32
32
|
|
|
33
33
|
|
|
@@ -119,21 +119,19 @@ def calculate_quantization_params(graph: Graph,
|
|
|
119
119
|
mod_attr_cfg = copy.deepcopy(attr_cfg)
|
|
120
120
|
mod_attr_cfg.weights_error_method = QuantizationErrorMethod.MSE
|
|
121
121
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
num_hessian_samples=num_hessian_samples)
|
|
122
|
+
min_threshold = candidate_qc.weights_quantization_cfg.min_threshold
|
|
123
|
+
weights_params, output_channels_axis = compute_weights_qparams(n.get_weights_by_keys(attr),
|
|
124
|
+
mod_attr_cfg, output_channels_axis,
|
|
125
|
+
min_threshold=min_threshold, node=n,
|
|
126
|
+
hessian_info_service=hessian_info_service,
|
|
127
|
+
num_hessian_samples=num_hessian_samples)
|
|
129
128
|
attr_cfg.weights_channels_axis = ChannelAxisMapping(output_channels_axis, attr_cfg.weights_channels_axis.input)
|
|
130
129
|
attr_cfg.set_weights_quantization_param(weights_params)
|
|
131
130
|
|
|
132
131
|
if n.is_activation_quantization_enabled():
|
|
133
132
|
# If node's activations should be quantized as well, we compute its activation quantization parameters
|
|
134
|
-
activation_params =
|
|
135
|
-
activation_quant_cfg=candidate_qc.activation_quantization_cfg,
|
|
136
|
-
nodes_prior_info=n.prior_info,
|
|
133
|
+
activation_params = compute_activation_qparams(
|
|
134
|
+
activation_quant_cfg=candidate_qc.activation_quantization_cfg, node_prior_info=n.prior_info,
|
|
137
135
|
out_stats_container=graph.get_out_stats_collector(n))
|
|
138
136
|
# Create a NodeQuantizationConfig containing all quantization params and attach it to the node
|
|
139
137
|
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params)
|