mct-nightly 2.4.0.20250630.629__py3-none-any.whl → 2.4.0.20250702.605__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/METADATA +16 -16
  2. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/RECORD +75 -72
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -1
  5. model_compression_toolkit/core/common/framework_info.py +5 -32
  6. model_compression_toolkit/core/common/fusion/graph_fuser.py +12 -9
  7. model_compression_toolkit/core/common/graph/base_graph.py +20 -37
  8. model_compression_toolkit/core/common/graph/base_node.py +13 -106
  9. model_compression_toolkit/core/common/graph/functional_node.py +1 -1
  10. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +12 -10
  11. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +14 -9
  12. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +9 -15
  13. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +2 -3
  14. model_compression_toolkit/core/common/network_editors/__init__.py +8 -1
  15. model_compression_toolkit/core/common/network_editors/actions.py +4 -96
  16. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  17. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +116 -56
  18. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
  19. model_compression_toolkit/core/common/quantization/node_quantization_config.py +55 -179
  20. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +21 -1
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +8 -5
  22. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -70
  23. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +10 -12
  24. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +54 -30
  25. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  26. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +93 -398
  27. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +2 -5
  28. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -4
  29. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -6
  30. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +12 -6
  31. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +1 -1
  32. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -2
  33. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +33 -33
  34. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +2 -4
  35. model_compression_toolkit/core/graph_prep_runner.py +31 -20
  36. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +5 -2
  37. model_compression_toolkit/core/keras/default_framework_info.py +0 -11
  38. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +9 -6
  39. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +3 -1
  40. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -1
  41. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +2 -1
  42. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
  43. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +47 -0
  44. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +3 -2
  45. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +5 -2
  46. model_compression_toolkit/core/pytorch/default_framework_info.py +0 -12
  47. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  48. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +2 -0
  49. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +1 -1
  50. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +2 -1
  51. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +1 -1
  52. model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
  53. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +45 -0
  54. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +3 -2
  55. model_compression_toolkit/core/runner.py +1 -1
  56. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +7 -3
  57. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  58. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +12 -3
  59. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -9
  60. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -5
  61. model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
  62. model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
  63. model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
  64. model_compression_toolkit/quantization_preparation/__init__.py +14 -0
  65. model_compression_toolkit/quantization_preparation/load_fqc.py +223 -0
  66. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  67. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -78
  68. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/WHEEL +0 -0
  69. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/licenses/LICENSE.md +0 -0
  70. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/top_level.txt +0 -0
  71. /model_compression_toolkit/core/keras/{quantizer → quantization}/__init__.py +0 -0
  72. /model_compression_toolkit/core/keras/{quantizer → quantization}/fake_quant_builder.py +0 -0
  73. /model_compression_toolkit/core/keras/{quantizer → quantization}/lut_fake_quant.py +0 -0
  74. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/__init__.py +0 -0
  75. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/fake_quant_builder.py +0 -0
  76. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/common/quantization/node_quantization_config.py
@@ -12,24 +12,16 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
-
-
- from typing import Callable, Any, List, Tuple, Union, Dict, TYPE_CHECKING
+ from typing import Any, List, Dict, TYPE_CHECKING
  from enum import Enum, auto
- import numpy as np

  from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
- from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn
  from model_compression_toolkit.logger import Logger
- from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
- get_activation_quantization_params_fn, get_weights_quantization_params_fn

- from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
- QuantizationErrorMethod
- from model_compression_toolkit.target_platform_capabilities.constants import POS_ATTR
+ from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
+ from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import \
- AttributeQuantizationConfig, \
- OpQuantizationConfig
+ AttributeQuantizationConfig, OpQuantizationConfig

  if TYPE_CHECKING:
  from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
@@ -86,29 +78,14 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
  """
  Attributes for configuring the quantization of the activations of a node.
  """
- def __init__(self,
- qc: QuantizationConfig,
- op_cfg: OpQuantizationConfig,
- activation_quantization_fn: Callable,
- activation_quantization_params_fn: Callable
- ):
+ def __init__(self, op_cfg: OpQuantizationConfig):
  """

  Args:
- qc: QuantizationConfig to create the node's config from.
  op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
- activation_quantization_fn: Function to use when quantizing the node's activations.
- activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations.
  """
-
- self.activation_quantization_fn = activation_quantization_fn
- self.activation_quantization_params_fn = activation_quantization_params_fn
- self.activation_quantization_params = {}
  self.activation_quantization_method = op_cfg.activation_quantization_method
- self.activation_error_method = qc.activation_error_method
  self.activation_n_bits = op_cfg.activation_n_bits
- self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
- self.activation_bias_correction_term = None
  if op_cfg.enable_activation_quantization and op_cfg.quantization_preserving:
  raise ValueError("An OpQuantizationConfig can't have both enable_activation_quantization and quantization_preserving enabled.")
  if op_cfg.enable_activation_quantization:
@@ -118,6 +95,29 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
  else:
  self.quant_mode = ActivationQuantizationMode.NO_QUANT
  self.signedness = op_cfg.signedness
+
+ self.activation_quantization_params = {}
+ # TODO irena: computed by compute_activation_bias_correction. shouldnt really be here
+ self.activation_bias_correction_term = None
+
+ # TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
+ self.activation_error_method = None
+ self.relu_bound_to_power_of_2 = None
+ self.activation_channel_equalization = None
+ self.input_scaling = None
+ self.min_threshold = None
+ self.l_p_value = None
+ self.shift_negative_activation_correction = None
+ self.z_threshold = None
+ self.shift_negative_ratio = None
+ self.shift_negative_threshold_recalculation = None
+ self.concat_threshold_update = None
+
+ def set_qc(self, qc: QuantizationConfig):
+ """ TODO irena: temporary keep all the attributes as before not to break all code at once.
+ Eventually all of them should be removed from here. """
+ self.activation_error_method = qc.activation_error_method
+ self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
  self.activation_channel_equalization = qc.activation_channel_equalization
  self.input_scaling = qc.input_scaling
  self.min_threshold = qc.min_threshold
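The constructor refactor above splits node-config creation into two steps: NodeActivationQuantizationConfig is now built from the FQC OpQuantizationConfig alone, and the user-facing QuantizationConfig is applied afterwards through set_qc. A minimal standalone sketch of that pattern (simplified stand-ins, not the actual MCT classes):

from dataclasses import dataclass

# Simplified stand-ins for OpQuantizationConfig / QuantizationConfig (not the MCT classes).
@dataclass
class OpCfg:
    activation_quantization_method: str = "POWER_OF_TWO"
    activation_n_bits: int = 8

@dataclass
class QCfg:
    activation_error_method: str = "MSE"
    relu_bound_to_power_of_2: bool = False

class ActivationCfg:
    def __init__(self, op_cfg: OpCfg):
        # Step 1: only op-level (FQC) settings are fixed at construction time.
        self.activation_quantization_method = op_cfg.activation_quantization_method
        self.activation_n_bits = op_cfg.activation_n_bits
        self.activation_quantization_params = {}
        # qc-derived fields stay unset until set_qc() is called.
        self.activation_error_method = None
        self.relu_bound_to_power_of_2 = None

    def set_qc(self, qc: QCfg):
        # Step 2: user-level QuantizationConfig settings are applied separately.
        self.activation_error_method = qc.activation_error_method
        self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2

cfg = ActivationCfg(OpCfg())
cfg.set_qc(QCfg())
assert cfg.activation_error_method == "MSE"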
@@ -139,65 +139,6 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
  def fln_quantization(self):
  return self.quant_mode == ActivationQuantizationMode.FLN_QUANT

- def quantize_node_output(self,
- tensors: Any) -> Any:
- """
-
- Args:
- tensors: framework tensor/s
-
- Returns:
- Framework tensor/s after applying fake quantization.
-
- """
- fake_quant = self.activation_quantization_fn(self.activation_n_bits,
- self.activation_quantization_params)
-
- if fake_quant is None:
- Logger.critical(
- "Layer is intended to be quantized, but the fake_quant function is None.") # pragma: no cover
-
- return fake_quant(tensors)
-
- @property
- def activation_error_method(self) -> QuantizationErrorMethod:
- """
- activation_error_method getter.
- """
- return self._activation_error_method
-
- @activation_error_method.setter
- def activation_error_method(self, value: QuantizationErrorMethod):
- """
- activation_error_method setter.
-
- Args:
- value: New activation_error_method to set to the node activation configuration.
-
- """
- self._activation_error_method = value
- self.activation_quantization_params_fn = get_activation_quantization_params_fn(activation_quantization_method=self.activation_quantization_method)
-
- def set_activation_quantization_fn(self, activation_quantization_fn: Callable):
- """
- Sets activation quantization function for the node.
-
- Args:
- activation_quantization_fn: Function for quantazing the activations.
-
- """
- self.activation_quantization_fn = activation_quantization_fn
-
- def set_activation_quantization_params_fn(self, activation_quantization_params_fn:Callable):
- """
- Sets activation params function for the node.
-
- Args:
- activation_quantization_params_fn: Function for calculating activation params.
-
- """
- self.activation_quantization_params_fn = activation_quantization_params_fn
-
  def set_activation_quantization_param(self,
  activation_params: dict):
  """
@@ -224,9 +165,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
  if not isinstance(other, NodeActivationQuantizationConfig):
  return False # pragma: no cover

- return self.activation_quantization_fn == other.activation_quantization_fn and \
- self.activation_quantization_params_fn == other.activation_quantization_params_fn and \
- self.activation_error_method == other.activation_error_method and \
+ return self.activation_error_method == other.activation_error_method and \
  self.activation_quantization_method == other.activation_quantization_method and \
  self.activation_n_bits == other.activation_n_bits and \
  self.quant_mode == other.quant_mode and \
@@ -240,9 +179,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
  self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation

  def __hash__(self):
- return hash((self.activation_quantization_fn,
- self.activation_quantization_params_fn,
- self.activation_error_method,
+ return hash((self.activation_error_method,
  self.activation_quantization_method,
  self.activation_n_bits,
  self.quant_mode,
@@ -261,65 +198,29 @@ class WeightsAttrQuantizationConfig:
  Configuration for quantizing a weights attribute of a node.
  """
  def __init__(self,
- qc: QuantizationConfig,
  weights_attr_cfg: AttributeQuantizationConfig,
  weights_channels_axis: ChannelAxisMapping = None):
  """

  Args:
- qc: QuantizationConfig to create the node's config from.
  weights_attr_cfg: AttributeQuantizationConfig with parameters to use when creating the node's attribute quantization config.
  weights_channels_axis: Axis to quantize a node's attribute when quantizing per-channel (if not quantizing per-channel than expecting None).
  """
- self.weights_quantization_fn = get_weights_quantization_fn(weights_attr_cfg.weights_quantization_method)
- self.weights_quantization_params_fn = get_weights_quantization_params_fn(weights_attr_cfg.weights_quantization_method)
  self.weights_channels_axis = weights_channels_axis
- self.weights_quantization_params = {}
  self.weights_quantization_method = weights_attr_cfg.weights_quantization_method
- self.weights_error_method = qc.weights_error_method
  self.weights_n_bits = weights_attr_cfg.weights_n_bits
  self.weights_per_channel_threshold = weights_attr_cfg.weights_per_channel_threshold
  self.enable_weights_quantization = weights_attr_cfg.enable_weights_quantization
- self.l_p_value = qc.l_p_value
-
- @property
- def weights_error_method(self) -> QuantizationErrorMethod:
- """
- weights_error_method getter.
- """
- return self._weights_error_method
-
- @weights_error_method.setter
- def weights_error_method(self, value: QuantizationErrorMethod):
- """
- weights_error_method setter.
-
- Args:
- value: New weights_error_method to set to the node weights configuration.
-
- """
- self._weights_error_method = value
- self.weights_quantization_params_fn = get_weights_quantization_params_fn(weights_quantization_method=self.weights_quantization_method)
-
- def set_weights_quantization_fn(self, weights_quantization_fn: Callable):
- """
- Sets weights quantization function for the node.
-
- Args:
- weights_quantization_fn: Function for quantazing the weights.
+ self.weights_quantization_params = {}

- """
- self.weights_quantization_fn = weights_quantization_fn
+ # TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
+ self.weights_error_method = None
+ self.l_p_value = None

- def set_weights_quantization_params_fn(self, weights_quantization_params_fn: Callable):
- """
- Sets weights params function for the node.
-
- Args:
- weights_quantization_params_fn: Function for calculating the weights params.
-
- """
- self.weights_quantization_params_fn = weights_quantization_params_fn
+ def set_qc(self, qc: QuantizationConfig):
+ # TODO irena: temporary keep the fields to not break everything at once.
+ self.weights_error_method = qc.weights_error_method
+ self.l_p_value = qc.l_p_value

  def set_weights_quantization_param(self,
  weights_params: dict):
@@ -334,31 +235,6 @@ class WeightsAttrQuantizationConfig:
  for param_name, param_value in weights_params.items():
  self.weights_quantization_params[param_name] = param_value

- def calculate_and_set_weights_params(self, tensor_data: np.ndarray, min_threshold: float):
- """
- Args:
- tensor_data: Tensor content as Numpy array.
- min_threshold: A minimal threshold to set as quantization parameter.
-
- Returns:
- Recalculated weights quantization params from the kernel and channel axis.
-
- """
- assert self.enable_weights_quantization
- assert not (self.weights_per_channel_threshold and self.weights_channels_axis is None), \
- "Trying to calculate threshold per channel, channel axis in None."
- if self.weights_quantization_params_fn is not None:
- self.set_weights_quantization_param(
- self.weights_quantization_params_fn(tensor_data,
- p=self.l_p_value,
- n_bits=self.weights_n_bits,
- per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
- channel_axis=self.weights_channels_axis.output, # output channel axis
- min_threshold=min_threshold)[0] # Take only first output, the q-params, as axis is already chosen.
- )
- else:
- self.set_weights_quantization_param({})
-
  def __eq__(self, other: Any) -> bool:
  """
  Compares the object to another object to find if they are equal.
@@ -372,20 +248,16 @@ class WeightsAttrQuantizationConfig:
  if not isinstance(other, WeightsAttrQuantizationConfig):
  return False # pragma: no cover

- return self.weights_quantization_fn == other.weights_quantization_fn and \
- self.weights_quantization_params_fn == other.weights_quantization_params_fn and \
- self.weights_channels_axis == other.weights_channels_axis and \
- self.weights_error_method == other.weights_error_method and \
+ return self.weights_channels_axis == other.weights_channels_axis and \
  self.weights_quantization_method == other.weights_quantization_method and \
  self.weights_n_bits == other.weights_n_bits and \
  self.weights_per_channel_threshold == other.weights_per_channel_threshold and \
  self.enable_weights_quantization == other.enable_weights_quantization and \
+ self.weights_error_method == other.weights_error_method and \
  self.l_p_value == other.l_p_value

  def __hash__(self):
- return hash((self.weights_quantization_fn,
- self.weights_quantization_params_fn,
- self.weights_channels_axis,
+ return hash((self.weights_channels_axis,
  self.weights_error_method,
  self.weights_quantization_method,
  self.weights_n_bits,
@@ -399,23 +271,19 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
  Holding a mapping between the node's weights attributes and their quantization configurations,
  in addition to quantization parameters that are global for all attributes of the represented node.
  """
- def __init__(self, qc: QuantizationConfig,
+ def __init__(self,
  op_cfg: OpQuantizationConfig,
  weights_channels_axis: ChannelAxisMapping,
  node_attrs_list: List[str]):
  """

  Args:
- qc: QuantizationConfig to create the node's config from.
  op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
  weights_channels_axis: Axis to quantize a node's weights attribute when quantizing per-channel.
  node_attrs_list: A list of the node's weights attributes names.

  """
- self.min_threshold = qc.min_threshold
  self.simd_size = op_cfg.simd_size
- self.weights_second_moment_correction = qc.weights_second_moment_correction
- self.weights_bias_correction = qc.weights_bias_correction

  # Initialize a quantization configuration for each of the node's attributes
  self.attributes_config_mapping = {}
@@ -427,7 +295,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
  # POS_ATTR string. If none are found, it indicates that no specific quantization config is defined for
  # positional weights, so the default config will be used instead.
  attrs_included_in_name = {k: v for k, v in op_cfg.attr_weights_configs_mapping.items() if
- POS_ATTR in k}
+ POSITIONAL_ATTR in k}

  if len(attrs_included_in_name) > 1: # pragma: no cover
  raise ValueError(f"Found multiple attribute in FQC OpConfig that are contained "
@@ -443,8 +311,7 @@
  attr_cfg = list(attrs_included_in_name.values())[0]

  # Register this attribute under the positional attributes config mapping.
- self.pos_attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(qc=qc,
- weights_attr_cfg=attr_cfg,
+ self.pos_attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(weights_attr_cfg=attr_cfg,
  weights_channels_axis=
  weights_channels_axis)
  else:
@@ -461,9 +328,18 @@
  else:
  attr_cfg = list(attrs_included_in_name.values())[0]

- self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(qc=qc,
- weights_attr_cfg=attr_cfg,
+ self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(weights_attr_cfg=attr_cfg,
  weights_channels_axis=weights_channels_axis)
+ # TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
+ self.min_threshold = None
+ self.weights_second_moment_correction = None
+ self.weights_bias_correction = None
+
+ def set_qc(self, qc: QuantizationConfig):
+ # TODO irena: temporary keep the fields to not break everything at once.
+ self.min_threshold = qc.min_threshold
+ self.weights_second_moment_correction = qc.weights_second_moment_correction
+ self.weights_bias_correction = qc.weights_bias_correction

  def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationConfig:
  """
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py
@@ -14,15 +14,35 @@
  # ==============================================================================

  from collections.abc import Callable
- from functools import partial

  from mct_quantizers import QuantizationMethod
+
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.core.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
  from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import power_of_two_quantizer, \
  symmetric_quantizer, uniform_quantizer


+ def get_activation_quantization_fn(activation_quantization_cfg: NodeActivationQuantizationConfig,
+ get_activation_quantization_fn_factory: Callable) -> Callable:
+ """
+ Get activation quantizer based on activation quantization configuration.
+
+ Args:
+ activation_quantization_cfg: activation quantization configuration.
+ get_activation_quantization_fn_factory: activation quantization functions factory.
+
+ Returns:
+ Activation quantizer that accepts a tensor and returns a quantized tensor.
+ """
+ quantizer_factory = get_activation_quantization_fn_factory(
+ activation_quantization_cfg.activation_quantization_method)
+ quantizer = quantizer_factory(activation_quantization_cfg.activation_n_bits,
+ activation_quantization_cfg.activation_quantization_params)
+ return quantizer
+
+
  def get_weights_quantization_fn(weights_quantization_method: QuantizationMethod) -> Callable:
  """
  Generate a function for weight quantization.
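The new get_activation_quantization_fn above takes the framework-specific factory as an argument instead of storing a quantization function on the node config. A standalone sketch of how such a factory could be wired in (toy builder and factory; the real factories are the new activation_quantization_fn_factory modules under core/keras/quantization and core/pytorch/quantization):

def power_of_two_builder(n_bits, params):
    # A builder returns the actual quantizer callable.
    def quantize(tensor):
        return f"fake-quant({tensor}, n_bits={n_bits}, params={params})"
    return quantize

def example_factory(method):
    # Assumed factory shape: quantization method -> quantizer builder.
    return {"POWER_OF_TWO": power_of_two_builder}[method]

class Cfg:
    # Minimal stand-in for NodeActivationQuantizationConfig.
    activation_quantization_method = "POWER_OF_TWO"
    activation_n_bits = 8
    activation_quantization_params = {"threshold": 1.0}

def get_activation_quantization_fn(cfg, factory):
    # Mirrors the new helper: factory(method) -> builder(n_bits, params) -> quantizer.
    return factory(cfg.activation_quantization_method)(cfg.activation_n_bits,
                                                       cfg.activation_quantization_params)

print(get_activation_quantization_fn(Cfg(), example_factory)("x"))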
model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py
@@ -12,9 +12,12 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import power_of_two_no_clipping_selection_min_max, \
- power_of_two_selection_histogram, power_of_two_selection_tensor
- from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import lut_kmeans_tensor
- from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import symmetric_no_clipping_selection_min_max
- from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import uniform_no_clipping_selection_min_max
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import (
+ power_of_two_no_clipping_selection_min_max, power_of_two_selection_histogram, power_of_two_selection_tensor)
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import (
+ lut_kmeans_tensor, lut_kmeans_histogram)
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import (
+ symmetric_no_clipping_selection_min_max, symmetric_selection_histogram, symmetric_selection_tensor)
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import (
+ uniform_no_clipping_selection_min_max, uniform_selection_histogram, uniform_selection_tensor)
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.outlier_filter import z_score_filter
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py
@@ -13,17 +13,59 @@
  # limitations under the License.
  # ==============================================================================
  import numpy as np
- from typing import Dict, Union, Optional, Tuple
+ from typing import Dict, Union, Optional, Tuple, Callable

  from mct_quantizers import QuantizationMethod
- from model_compression_toolkit.core import QuantizationErrorMethod
+
+ import model_compression_toolkit.core.common.quantization.quantization_params_generation as qpg
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
  from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
- from model_compression_toolkit.core.common.quantization import quantization_params_generation
  from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
  from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
+ from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationErrorMethod
+
+
+ def compute_activation_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
+ node_prior_info: NodePriorInfo,
+ out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
+ """
+ Compute the activations params for a given node in a graph according to a params function.
+
+ Args:
+ activation_quant_cfg: node's activation quantization configuration.
+ node_prior_info: Prior info collected for the node that is being quantized.
+ out_stats_container: Tensor containing output statistics of the node.
+
+ Returns:
+ The computed activation quantization params.
+ """
+ activation_quantization_params_fn = _get_activation_quantization_params_fn(
+ activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded())
+
+ # Extract and filter histogram data from the statistics container.
+ bins_values, bins_counts = _get_histogram_data(activation_quant_cfg, out_stats_container)
+
+ # Retrieve the minimum and maximum values from the statistics container.
+ min_value, max_value = out_stats_container.get_min_max_values()
+
+ # Determine if the activations should be considered signed.
+ signed = _determine_signedness(activation_quant_cfg, node_prior_info, min_value, bins_values, bins_counts)

- def get_histogram_data(
+ # Compute and return the activation quantization parameters.
+ return activation_quantization_params_fn(
+ bins_values,
+ bins_counts,
+ activation_quant_cfg.l_p_value,
+ activation_quant_cfg.activation_n_bits,
+ min_value,
+ max_value,
+ min_threshold=activation_quant_cfg.min_threshold,
+ quant_error_method=activation_quant_cfg.activation_error_method,
+ is_signed=signed
+ )
+
+
+ def _get_histogram_data(
  activation_quant_cfg: NodeActivationQuantizationConfig,
  out_stats_container: BaseStatsCollector
  ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
@@ -38,7 +80,6 @@ def get_histogram_data(
  A tuple containing the filtered bins_values and bins_counts.
  """
  bins_values, bins_counts = None, None
-
  # If the statistics container collected the histogram, we start by filtering outliers using z threshold
  # filtering, and then computing the threshold based on the filtered histogram.
  if out_stats_container.require_collection():
@@ -46,14 +87,15 @@
  bins_values, bins_counts = out_stats_container.weighted_hc.get_histogram()
  else:
  bins_values, bins_counts = out_stats_container.hc.get_histogram()
- bins_counts = quantization_params_generation.z_score_filter(
+ bins_counts = qpg.z_score_filter(
  activation_quant_cfg.z_threshold,
  bins_values,
  bins_counts
  )
  return bins_values, bins_counts

- def determine_signedness(
+
+ def _determine_signedness(
  activation_quant_cfg: NodeActivationQuantizationConfig,
  nodes_prior_info: NodePriorInfo,
  min_value: float,
@@ -83,73 +125,37 @@ def determine_signedness(
  return np.any(bins_values[:-1][bins_counts > 0] < 0)


- def update_activation_quantization_params_fn(
- activation_quant_cfg: NodeActivationQuantizationConfig,
- nodes_prior_info: NodePriorInfo):
- """
- Update the activation quantization parameters function based on the quantization method
- and whether the node's output is bounded.
+ _activation_quant_params_fns = {
+ QuantizationMethod.POWER_OF_TWO: qpg.power_of_two_selection_histogram,
+ QuantizationMethod.SYMMETRIC: qpg.symmetric_selection_histogram,
+ QuantizationMethod.UNIFORM: qpg.uniform_selection_histogram,
+ QuantizationMethod.LUT_POT_QUANTIZER: qpg.lut_kmeans_histogram
+ }
+ _activation_no_clipping_quant_params_fns = {
+ QuantizationMethod.POWER_OF_TWO: qpg.power_of_two_no_clipping_selection_min_max,
+ QuantizationMethod.SYMMETRIC: qpg.symmetric_no_clipping_selection_min_max,
+ QuantizationMethod.UNIFORM: qpg.uniform_no_clipping_selection_min_max,
+ QuantizationMethod.LUT_POT_QUANTIZER: qpg.lut_kmeans_histogram
+ }

- Args:
- activation_quant_cfg: Node's activation quantization configuration.
- nodes_prior_info: Prior info collected for the node that is being quantized.
- """
- if nodes_prior_info.is_output_bounded():
- if activation_quant_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
- activation_quant_cfg.set_activation_quantization_params_fn(
- quantization_params_generation.power_of_two_no_clipping_selection_min_max
- )
- elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.SYMMETRIC:
- activation_quant_cfg.set_activation_quantization_params_fn(
- quantization_params_generation.symmetric_no_clipping_selection_min_max
- )
- elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.UNIFORM:
- activation_quant_cfg.set_activation_quantization_params_fn(
- quantization_params_generation.uniform_no_clipping_selection_min_max
- )
-
-
- def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
- nodes_prior_info: NodePriorInfo,
- out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
+
+ def _get_activation_quantization_params_fn(activation_quantization_method: QuantizationMethod,
+ no_clipping: bool) -> Callable:
  """
- Compute the activations params for a given node in a graph according to a params function.
+ Generate a function for finding activation quantization parameters.

  Args:
- activation_quant_cfg: node's activation quantization configuration.
- nodes_prior_info: Prior info collected for the node that is being quantized.
- out_stats_container: Tensor containing output statistics of the node.
+ activation_quantization_method: Which quantization method to use for activations.
+ no_clipping: Whether to use the no-clipping version of the quantizer (if available).

  Returns:
- The computed activation quantization params.
+ A function to find the quantization parameters.
  """
- # Update quantization parameters function based on output bounds and quantization method.
- update_activation_quantization_params_fn(activation_quant_cfg, nodes_prior_info)
-
- # Extract and filter histogram data from the statistics container.
- bins_values, bins_counts = get_histogram_data(activation_quant_cfg, out_stats_container)
-
- # Retrieve the minimum and maximum values from the statistics container.
- min_value, max_value = out_stats_container.get_min_max_values()
-
- # Determine if the activations should be considered signed.
- signed = determine_signedness(
- activation_quant_cfg,
- nodes_prior_info,
- min_value,
- bins_values,
- bins_counts
- )
-
- # Compute and return the activation quantization parameters.
- return activation_quant_cfg.activation_quantization_params_fn(
- bins_values,
- bins_counts,
- activation_quant_cfg.l_p_value,
- activation_quant_cfg.activation_n_bits,
- min_value,
- max_value,
- min_threshold=activation_quant_cfg.min_threshold,
- quant_error_method=activation_quant_cfg.activation_error_method,
- is_signed=signed
- )
+ if no_clipping:
+ params_fn = _activation_no_clipping_quant_params_fns.get(activation_quantization_method)
+ else:
+ params_fn = _activation_quant_params_fns.get(activation_quantization_method)
+ if params_fn is None:
+ raise ValueError(f"No parameter function found for the specified quantization method: "
+ "{activation_quantization_method}") # pragma: no cover
+ return params_fn
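The per-method if/elif chain is replaced above by two lookup tables keyed by QuantizationMethod, with the no-clipping table used when the node's output is known to be bounded. A standalone sketch of the same dispatch idea (toy parameter functions, not MCT's):

def pot_selection_histogram(*args, **kwargs):
    # Toy stand-in for a histogram-based power-of-two threshold selection.
    return {"threshold": 1.0}

def pot_no_clipping_min_max(*args, **kwargs):
    # Toy stand-in for the no-clipping min/max selection used for bounded outputs.
    return {"threshold": 2.0}

_params_fns = {"POWER_OF_TWO": pot_selection_histogram}
_no_clipping_params_fns = {"POWER_OF_TWO": pot_no_clipping_min_max}

def get_params_fn(method, no_clipping):
    # Bounded-output nodes (e.g. outputs of ReLU6) take the no-clipping selection path.
    table = _no_clipping_params_fns if no_clipping else _params_fns
    fn = table.get(method)
    if fn is None:
        raise ValueError(f"No parameter function found for quantization method: {method}")
    return fn

print(get_params_fn("POWER_OF_TWO", no_clipping=True)())  # {'threshold': 2.0}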
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py
@@ -25,9 +25,9 @@ from model_compression_toolkit.core.common.framework_implementation import Frame
  from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
  HessianScoresGranularity
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
- import get_activations_qparams
+ import compute_activation_qparams
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
- get_weights_qparams
+ compute_weights_qparams
  from model_compression_toolkit.logger import Logger


@@ -119,21 +119,19 @@ def calculate_quantization_params(graph: Graph,
  mod_attr_cfg = copy.deepcopy(attr_cfg)
  mod_attr_cfg.weights_error_method = QuantizationErrorMethod.MSE

- weights_params, output_channels_axis = get_weights_qparams(n.get_weights_by_keys(attr),
- candidate_qc.weights_quantization_cfg,
- mod_attr_cfg,
- output_channels_axis,
- node=n,
- hessian_info_service=hessian_info_service,
- num_hessian_samples=num_hessian_samples)
+ min_threshold = candidate_qc.weights_quantization_cfg.min_threshold
+ weights_params, output_channels_axis = compute_weights_qparams(n.get_weights_by_keys(attr),
+ mod_attr_cfg, output_channels_axis,
+ min_threshold=min_threshold, node=n,
+ hessian_info_service=hessian_info_service,
+ num_hessian_samples=num_hessian_samples)
  attr_cfg.weights_channels_axis = ChannelAxisMapping(output_channels_axis, attr_cfg.weights_channels_axis.input)
  attr_cfg.set_weights_quantization_param(weights_params)

  if n.is_activation_quantization_enabled():
  # If node's activations should be quantized as well, we compute its activation quantization parameters
- activation_params = get_activations_qparams(
- activation_quant_cfg=candidate_qc.activation_quantization_cfg,
- nodes_prior_info=n.prior_info,
+ activation_params = compute_activation_qparams(
+ activation_quant_cfg=candidate_qc.activation_quantization_cfg, node_prior_info=n.prior_info,
  out_stats_container=graph.get_out_stats_collector(n))
  # Create a NodeQuantizationConfig containing all quantization params and attach it to the node
  candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params)
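In the hunk above, min_threshold is read from the node-level weights config by the caller and passed explicitly to compute_weights_qparams. A standalone sketch of that call shape (toy implementation, only to illustrate the keyword change):

def compute_weights_qparams(weights, attr_cfg, output_channels_axis, min_threshold, **kwargs):
    # Toy body: a symmetric threshold clipped from below by the explicit min_threshold.
    threshold = max(min_threshold, max(abs(w) for w in weights))
    return {"threshold": threshold}, output_channels_axis

weights_params, axis = compute_weights_qparams([0.2, -1.5, 0.7],
                                               attr_cfg=None,
                                               output_channels_axis=0,
                                               min_threshold=2 ** -16)
print(weights_params)  # {'threshold': 1.5}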