mct-nightly 2.4.0.20250629.706__py3-none-any.whl → 2.4.0.20250701.185106__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/METADATA +16 -16
  2. {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/RECORD +75 -72
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -1
  5. model_compression_toolkit/core/common/framework_info.py +5 -32
  6. model_compression_toolkit/core/common/fusion/graph_fuser.py +12 -9
  7. model_compression_toolkit/core/common/graph/base_graph.py +20 -37
  8. model_compression_toolkit/core/common/graph/base_node.py +13 -106
  9. model_compression_toolkit/core/common/graph/functional_node.py +1 -1
  10. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +12 -10
  11. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +14 -9
  12. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +9 -15
  13. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +2 -3
  14. model_compression_toolkit/core/common/network_editors/__init__.py +8 -1
  15. model_compression_toolkit/core/common/network_editors/actions.py +4 -96
  16. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  17. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +116 -56
  18. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
  19. model_compression_toolkit/core/common/quantization/node_quantization_config.py +55 -179
  20. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +21 -1
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +8 -5
  22. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -70
  23. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +10 -12
  24. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +54 -30
  25. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  26. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +93 -398
  27. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +2 -5
  28. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -4
  29. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -6
  30. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +12 -6
  31. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +1 -1
  32. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -2
  33. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +33 -33
  34. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +2 -4
  35. model_compression_toolkit/core/graph_prep_runner.py +31 -20
  36. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +5 -2
  37. model_compression_toolkit/core/keras/default_framework_info.py +0 -11
  38. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +9 -6
  39. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +3 -1
  40. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -1
  41. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +2 -1
  42. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
  43. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +47 -0
  44. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +3 -2
  45. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +5 -2
  46. model_compression_toolkit/core/pytorch/default_framework_info.py +0 -12
  47. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  48. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +2 -0
  49. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +1 -1
  50. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +2 -1
  51. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +1 -1
  52. model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
  53. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +45 -0
  54. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +3 -2
  55. model_compression_toolkit/core/runner.py +1 -1
  56. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +7 -3
  57. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  58. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +12 -3
  59. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -9
  60. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -5
  61. model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
  62. model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
  63. model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
  64. model_compression_toolkit/quantization_preparation/__init__.py +14 -0
  65. model_compression_toolkit/quantization_preparation/load_fqc.py +223 -0
  66. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  67. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -78
  68. {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/WHEEL +0 -0
  69. {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/licenses/LICENSE.md +0 -0
  70. {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/top_level.txt +0 -0
  71. /model_compression_toolkit/core/keras/{quantizer → quantization}/__init__.py +0 -0
  72. /model_compression_toolkit/core/keras/{quantizer → quantization}/fake_quant_builder.py +0 -0
  73. /model_compression_toolkit/core/keras/{quantizer → quantization}/lut_fake_quant.py +0 -0
  74. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/__init__.py +0 -0
  75. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/fake_quant_builder.py +0 -0
  76. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/lut_fake_quant.py +0 -0
@@ -20,14 +20,8 @@ from typing import Callable
 from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.logger import Logger
-
-
-from model_compression_toolkit.core.common.framework_info import get_fw_info
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
-from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
-    get_activation_quantization_params_fn, get_weights_quantization_params_fn
-from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
-    get_weights_quantization_fn
+

 _EditRule = namedtuple('EditRule', 'filter action')

@@ -174,47 +168,6 @@ class ChangeFinalActivationQuantConfigAttr(BaseAction):
             node.final_activation_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value)


-class ChangeQuantizationParamFunction(BaseAction):
-    """
-    Class ChangeQuantizationParamFunction to change a node's weights/activations quantization params function.
-    """
-
-    def __init__(self,
-                 attr_name: str = None,
-                 activation_quantization_params_fn: Callable = None,
-                 weights_quantization_params_fn: Callable = None):
-        """
-        Init a ChangeQuantizationParamFunction object.
-
-        Args:
-            attr_name: The weights attribute's name to set the weights quantization params function for (if setting weights params).
-            activation_quantization_params_fn: a params function for a node's activations.
-            weights_quantization_params_fn: a params function for a node's weights.
-        """
-        self.activation_quantization_params_fn = activation_quantization_params_fn
-        self.weights_quantization_params_fn = weights_quantization_params_fn
-        self.attr_name = attr_name
-
-    def apply(self, node: BaseNode, graph):
-        """
-        Change the node's weights/activations quantization params function.
-
-        Args:
-            node: Node object to change its quantization params function.
-            graph: Graph to apply the action on.
-
-        Returns:
-            The node after its quantization params function has been modified.
-        """
-        for nqc in node.candidates_quantization_cfg:
-            if self.activation_quantization_params_fn is not None:
-                nqc.activation_quantization_cfg.set_activation_quantization_params_fn(
-                    self.activation_quantization_params_fn)
-            if self.weights_quantization_params_fn is not None:
-                (nqc.weights_quantization_cfg.get_attr_config(self.attr_name)
-                 .set_weights_quantization_params_fn(self.weights_quantization_params_fn))
-
-
 class ChangeFinalActivationQuantizationMethod(BaseAction):
     """
     Class ChangeFinalActivationQuantizationMethod to change a node's weights/activations quantizer function.
@@ -243,16 +196,6 @@ class ChangeFinalActivationQuantizationMethod(BaseAction):
         """

         if self.activation_quantization_method is not None and node.final_activation_quantization_cfg is not None:
-
-            activation_quantization_params_fn = get_activation_quantization_params_fn(
-                self.activation_quantization_method)
-
-            node.final_activation_quantization_cfg.set_activation_quantization_params_fn(
-                activation_quantization_params_fn)
-
-            activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(self.activation_quantization_method)
-
-            node.final_activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
             node.final_activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method

@@ -281,23 +224,12 @@ class ChangeCandidatesActivationQuantizationMethod(BaseAction):
         """
         if self.activation_quantization_method is not None:
             for qc in node.candidates_quantization_cfg:
-                activation_quantization_params_fn = get_activation_quantization_params_fn(
-                    self.activation_quantization_method)
-
-                qc.activation_quantization_cfg.set_activation_quantization_params_fn(activation_quantization_params_fn)
-                activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(
-                    self.activation_quantization_method)
-
-                if activation_quantization_fn is None:
-                    Logger.critical('Unknown activation quantization method specified.')  # pragma: no cover
-
-                qc.activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
                 qc.activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method


 class ChangeFinalWeightsQuantizationMethod(BaseAction):
     """
-    Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer function.
+    Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer method.
     """

     def __init__(self, attr_name: str, weights_quantization_method=None):
@@ -323,21 +255,8 @@ class ChangeFinalWeightsQuantizationMethod(BaseAction):
         """

         if self.weights_quantization_method is not None and node.final_weights_quantization_cfg is not None:
-
-            weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
-
-            (node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
-             .set_weights_quantization_params_fn(weights_quantization_params_fn))
-
-            weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
-
-            if weights_quantization_fn is None:
-                Logger.critical('Unknown weights quantization method specified.')  # pragma: no cover
-
-            (node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
-             .set_weights_quantization_fn(weights_quantization_fn))
-            node.final_weights_quantization_cfg.get_attr_config(self.attr_name).weights_quantization_method = \
-                self.weights_quantization_method
+            attr_config = node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
+            attr_config.weights_quantization_method = self.weights_quantization_method


 class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
@@ -370,18 +289,7 @@ class ChangeCandidatesWeightsQuantizationMethod(BaseAction):

         if self.weights_quantization_method is not None:
             for qc in node.candidates_quantization_cfg:
-
-                weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
-
                 attr_qc = qc.weights_quantization_cfg.get_attr_config(self.attr_name)
-                attr_qc.set_weights_quantization_params_fn(weights_quantization_params_fn)
-
-                weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
-
-                if weights_quantization_fn is None:
-                    Logger.critical('Unknown weights quantization method specified.')  # pragma: no cover
-
-                attr_qc.set_weights_quantization_fn(weights_quantization_fn)
                 attr_qc.weights_quantization_method = self.weights_quantization_method

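Taken together, the actions.py hunks above reduce every method-changing action to a plain assignment of the QuantizationMethod enum; the matching quantizer and params functions are no longer resolved and cached on the config by the action itself. A minimal sketch of wiring such an action into an edit rule under the new behavior; the Conv2D filter and the 'kernel' attribute name are illustrative, and the import paths mirror the module layout in the file list above rather than a guaranteed public API:

```python
# Hedged sketch: a network-editor rule using the simplified action. Assumes a
# Keras model; EditRule/NodeTypeFilter follow MCT's network-editor convention.
import tensorflow as tf
from mct_quantizers import QuantizationMethod
from model_compression_toolkit.core.common.network_editors.actions import (
    EditRule, ChangeCandidatesWeightsQuantizationMethod)
from model_compression_toolkit.core.common.network_editors.node_filters import NodeTypeFilter

rule = EditRule(
    filter=NodeTypeFilter(tf.keras.layers.Conv2D),
    # After this diff the action only assigns weights_quantization_method on the
    # attribute config; params/quantizer functions are derived from it later.
    action=ChangeCandidatesWeightsQuantizationMethod(
        attr_name='kernel',
        weights_quantization_method=QuantizationMethod.POWER_OF_TWO))
```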
@@ -19,8 +19,8 @@ from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.matchers.node_matcher import BaseNodeMatcher
 from model_compression_toolkit.logger import Logger

-from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
-from model_compression_toolkit.target_platform_capabilities.constants import POS_ATTR
+from model_compression_toolkit.core.common.graph.base_node import WeightAttrT, BaseNode
+from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR


 @dataclass
@@ -95,7 +95,7 @@ class BitWidthConfig:
         for attr, bit_width, filter in zip (attrs, bit_widths, filters):
             self.manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection(filter, bit_width, attr)]

-    def get_nodes_to_manipulate_activation_bit_widths(self, graph: Graph) -> Dict:
+    def get_nodes_activation_bit_widths(self, graph: Graph) -> Dict[BaseNode, int]:
         """
         Retrieve nodes from the graph that need their bit-widths for activation changed according to the manual bit-width selections.

@@ -108,7 +108,7 @@ class BitWidthConfig:
         activation_nodes_to_change_bit_width = self._construct_node_to_new_activation_bit_mapping(graph)
         return activation_nodes_to_change_bit_width

-    def get_nodes_to_manipulate_weights_bit_widths(self, graph: Graph) -> Dict:
+    def get_nodes_weights_bit_widths(self, graph: Graph) -> Dict[BaseNode, Dict[str, int]]:
         """
         Retrieve nodes from the graph that need their bit-widths for weights changed according to the manual bit-width selections.

@@ -166,7 +166,7 @@ class BitWidthConfig:
         attrs = BitWidthConfig._expand_to_list_core(filters, attrs)
         return attrs, bit_widths, filters

-    def _construct_node_to_new_activation_bit_mapping(self, graph) -> Dict:
+    def _construct_node_to_new_activation_bit_mapping(self, graph) -> Dict[BaseNode, int]:
         """
         Retrieve nodes from the graph that need their activation bit-widths changed according to the manual bit-width selections.

@@ -192,7 +192,7 @@ class BitWidthConfig:
                 unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
         return unit_nodes_to_change_bit_width

-    def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict:
+    def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict[BaseNode, Dict[str, int]]:
         """
         Retrieve nodes from the graph that need their weights bit-widths changed according to the manual bit-width selections.

@@ -212,7 +212,7 @@ class BitWidthConfig:
                                 f"to change their bit width to {manual_bit_width_selection.bit_width}.")

             for n in filtered_nodes:
-                attr_to_change_bit_width = []
+                attr_to_change_bit_width = {}

                 attrs_str = n.get_node_weights_attributes()
                 if len(attrs_str) == 0:
@@ -225,8 +225,8 @@ class BitWidthConfig:
                         attr.append(attr_str)
                     # this is a positional attribute, so it needs to be handled separately.
                     # Search manual_bit_width_selection's attribute that contain the POS_ATTR string.
-                    elif isinstance(attr_str, int) and POS_ATTR in manual_bit_width_selection.attr:
-                        attr.append(POS_ATTR)
+                    elif isinstance(attr_str, int) and POSITIONAL_ATTR in manual_bit_width_selection.attr:
+                        attr.append(POSITIONAL_ATTR)
                 if len(attr) == 0:
                     Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.')

@@ -239,7 +239,7 @@ class BitWidthConfig:
                         f"Node {n} has an existing manual bit width configuration of {manual_bit_width_selection.attr}."
                         f"A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")

-                attr_to_change_bit_width.append([manual_bit_width_selection.bit_width, manual_bit_width_selection.attr])
+                attr_to_change_bit_width[manual_bit_width_selection.attr] = manual_bit_width_selection.bit_width
                 unit_nodes_to_change_bit_width.update({n: attr_to_change_bit_width})

         return unit_nodes_to_change_bit_width
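The bit_width_config.py hunks rename the two getters and tighten their return types: the weights mapping now carries a per-node dict keyed by attribute name instead of a list of [bit_width, attr] pairs, so later selections for the same attribute override earlier ones rather than accumulating. A hedged sketch of consuming the new shapes, assuming an existing BitWidthConfig instance bw_cfg and an MCT Graph are in scope:

```python
# Hedged sketch; bw_cfg (BitWidthConfig) and graph (Graph) are assumed in scope.
act_map = bw_cfg.get_nodes_activation_bit_widths(graph)    # Dict[BaseNode, int]
for node, n_bits in act_map.items():
    print(f'{node.name}: activation -> {n_bits} bits')

weights_map = bw_cfg.get_nodes_weights_bit_widths(graph)   # Dict[BaseNode, Dict[str, int]]
for node, attr_widths in weights_map.items():
    for attr, n_bits in attr_widths.items():               # e.g. {'kernel': 4}
        print(f'{node.name}: {attr} -> {n_bits} bits')
```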
@@ -12,73 +12,133 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Callable, List, Tuple
+import copy
+from dataclasses import dataclass, InitVar
+from typing import Callable, List, Optional

-from model_compression_toolkit.core import QuantizationConfig
-from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
 from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
-    NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig, \
-    OpQuantizationConfig
-from model_compression_toolkit.logger import Logger
+    NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig, ActivationQuantizationMode


-##########################################
-# Every node holds a quantization configuration
-# for its weights quantization, and a different quantization
-# configuration for its activation quantization configuration.
-##########################################
-
+@dataclass(eq=True)
 class CandidateNodeQuantizationConfig(BaseNodeQuantizationConfig):
     """
-    Class for representing candidate node configuration, which includes weights and activation configuration combined.
+    Candidate quantization configuration for a node.
     """
+    activation_quantization_cfg: NodeActivationQuantizationConfig
+    # TODO irena: None is passed in several places, need to check if it's handled properly or it's only passed in cases
+    # that do not affect anything (my guess is it's the second).
+    # I think in general it makes more sense to set it to None when there are no weights, and maybe when all weights
+    # are unquantized, and handle it properly everywhere.
+    weights_quantization_cfg: Optional[NodeWeightsQuantizationConfig]
+
+
+# TODO irena: currently all code still looks at candidates_quantization_cfg as previously, so this is just an initial
+# implementation. For now base config is completely separated from candidates (base config must be equal to one of the
+# candidates, but we create a separate copy), and updating in place is allowed. Also we require quantization mode to
+# be identical between all configs.
+@dataclass
+class NodeQuantizationConfig:
+    # quantization config for single precision
+    base_quantization_cfg: CandidateNodeQuantizationConfig
+    # quantization candidate configs for mixed precision
+    candidates_quantization_cfg: List[CandidateNodeQuantizationConfig]
+
+    validate: InitVar[bool] = True
+
+    def update_all(self, update_fn: Callable[[CandidateNodeQuantizationConfig], None], remove_duplicates: bool = True):
+        """
+        Apply update function on the base config and all candidates configs.
+
+        Args:
+            update_fn: function to apply.
+            remove_duplicates: remove duplicate candidates.
+        """
+        if self.base_quantization_cfg:
+            update_fn(self.base_quantization_cfg)
+        for cfg in self.candidates_quantization_cfg:
+            update_fn(cfg)
+        if remove_duplicates:
+            self.remove_duplicates()

-    def __init__(self,
-                 qc: QuantizationConfig = None,
-                 op_cfg: OpQuantizationConfig = None,
-                 activation_quantization_cfg: NodeActivationQuantizationConfig = None,
-                 activation_quantization_fn: Callable = None,
-                 activation_quantization_params_fn: Callable = None,
-                 weights_quantization_cfg: NodeWeightsQuantizationConfig = None,
-                 weights_channels_axis: ChannelAxisMapping = None,
-                 node_attrs_list: List[str] = None):
+    def update_activation_quantization_mode(self, mode: ActivationQuantizationMode):
         """
+        Update activation quantization mode for the base config and all candidates configs.

         Args:
-            qc: QuantizationConfig to create the node's config from.
-            op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
-            activation_quantization_cfg: An option to pass a NodeActivationQuantizationConfig to create a new config from.
-            activation_quantization_fn: Function to use when quantizing the node's activations.
-            activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations.
-            weights_quantization_cfg: An option to pass a NodeWeightsQuantizationConfig to create a new config from.
-            weights_channels_axis: Axis to quantize a node's weights attribute when quantizing per-channel.
-            node_attrs_list: A list of the node's weights attributes names.
+            mode: quantization mode.
+        """
+        def fn(c):
+            c.activation_quantization_cfg.quant_mode = mode
+
+        self.update_all(fn)
+
+    def disable_weights_quantization(self):
+        """
+        Disable all weights quantization for the base config and all candidates configs.
+        """
+        self.update_all(lambda c: c.weights_quantization_cfg.disable_all_weights_quantization())
+
+    def get_activation_quant_mode(self) -> ActivationQuantizationMode:
+        """
+        Retrieve activation quantization mode.
+
+        Returns:
+            Activation quantization mode.
+
+        Raises:
+            ValueError if not all candidates contain the same mode.
+        """
+        self._validate_consistent_activation_quant_mode()
+        return self.base_quantization_cfg.activation_quantization_cfg.quant_mode
+
+    def remove_duplicates(self):
+        """
+        Remove duplicate candidates. First candidate among duplicates is kept, and the order is preserved.
         """
+        uniq_qcs = []
+        for qc in self.candidates_quantization_cfg:
+            if qc not in uniq_qcs:
+                uniq_qcs.append(qc)
+        self.candidates_quantization_cfg = uniq_qcs

-        if activation_quantization_cfg is not None:
-            self.activation_quantization_cfg = activation_quantization_cfg
-        else:
-            if any(v is None for v in (qc, op_cfg, activation_quantization_fn, activation_quantization_params_fn)):  # pragma: no cover
-                Logger.critical(
-                    "Missing required arguments to initialize a node activation quantization configuration. "
-                    "Ensure QuantizationConfig, OpQuantizationConfig, activation quantization function, "
-                    "and parameters function are provided.")
-            self.activation_quantization_cfg = (
-                NodeActivationQuantizationConfig(qc=qc,
-                                                 op_cfg=op_cfg,
-                                                 activation_quantization_fn=activation_quantization_fn,
-                                                 activation_quantization_params_fn=activation_quantization_params_fn))
-
-        if weights_quantization_cfg is not None:
-            self.weights_quantization_cfg = weights_quantization_cfg
-        elif all(v is not None for v in (qc, op_cfg, node_attrs_list)):
-            self.weights_quantization_cfg = NodeWeightsQuantizationConfig(qc=qc,
-                                                                          op_cfg=op_cfg,
-                                                                          weights_channels_axis=weights_channels_axis,
-                                                                          node_attrs_list=node_attrs_list)
-        else:
-            self.weights_quantization_cfg = None
-            Logger.debug("Setting weights quantization config as None during CandidateNodeQuantizationConfig creation."
-                         "Notice, this should happen only for FLN nodes.")
+    def __post_init__(self, validate=True):
+        if validate:
+            if not any(self.base_quantization_cfg == qc for qc in self.candidates_quantization_cfg):
+                raise ValueError('Candidates should contain the base config.')
+            self._validate_consistent_activation_quant_mode()
+            self._validate_consistent_weights_quant_mode()
+        # TODO irena
+        # for now make sure they are separate objects so that one doesnt inadvertently modify the other
+        if any(self.base_quantization_cfg is qc for qc in self.candidates_quantization_cfg):
+            self.base_quantization_cfg = copy.deepcopy(self.base_quantization_cfg)

+    def _validate_consistent_activation_quant_mode(self):
+        """
+        Validate that base config and all candidates configs contain identical activation quantization mode.
+
+        Raises:
+            ValueError if activation quantization mode is not consistent.
+        """
+        activation_quant_mode = self.base_quantization_cfg.activation_quantization_cfg.quant_mode
+        if any(qc.activation_quantization_cfg.quant_mode != activation_quant_mode
+               for qc in self.candidates_quantization_cfg):
+            raise ValueError('Quantization candidates with different quantization modes are not currently supported.')
+
+    def _validate_consistent_weights_quant_mode(self):
+        """
+        Validate that base config and all candidates configs contain identical weights quantization mode per attribute,
+        i.e. quantization for each attribute should either be enabled in all configs, or disabled in all configs.
+
+        Raises:
+            ValueError if weights quantization is not consistent.
+        """
+        def get_weights_mode(qc):
+            # in graph fuser weights_quantization_cfg is set to None
+            if qc.weights_quantization_cfg is None:
+                return None
+            return {attr: attr_cfg.enable_weights_quantization for attr, attr_cfg
+                    in qc.weights_quantization_cfg.get_all_weight_attrs_configs().items()}
+        if any(get_weights_mode(self.base_quantization_cfg) != get_weights_mode(qc)
+               for qc in self.candidates_quantization_cfg):
+            raise ValueError('Quantization candidates with different quantization modes are not currently supported.')
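The rewritten candidate_node_quantization_config.py introduces NodeQuantizationConfig, which pairs a single-precision base config with the mixed-precision candidate list and validates their consistency in __post_init__. A minimal usage sketch built from what the hunk defines; cand_a and cand_b stand for pre-built CandidateNodeQuantizationConfig objects with matching quantization modes, and activation_n_bits is assumed to be a field of the activation config:

```python
# Hedged sketch; cand_a/cand_b are pre-built CandidateNodeQuantizationConfig
# objects whose quantization modes agree, as __post_init__ enforces.
node_cfg = NodeQuantizationConfig(
    base_quantization_cfg=cand_a,                    # single-precision config
    candidates_quantization_cfg=[cand_a, cand_b])    # mixed-precision candidates
# __post_init__ deep-copies base_quantization_cfg, so cand_a is not aliased.

# Batch-update the base config and every candidate; duplicates are then pruned
# because remove_duplicates defaults to True (dataclass eq=True enables the
# membership test inside remove_duplicates).
node_cfg.update_all(lambda c: setattr(c.activation_quantization_cfg,
                                      'activation_n_bits', 8))

# Consistency-checked read of the shared activation quantization mode.
mode = node_cfg.get_activation_quant_mode()
```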
@@ -33,7 +33,7 @@ def filter_nodes_candidates(graph: Graph):
     """
     nodes = list(graph.nodes)
     for n in nodes:
-        n.candidates_quantization_cfg = filter_node_candidates(node=n)
+        n.quantization_cfg.candidates_quantization_cfg = filter_node_candidates(node=n)

     return graph
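With the container in place, filter_nodes_candidates writes the filtered list through node.quantization_cfg instead of a bare attribute on the node. A hedged sketch of the corresponding read path after filtering, assuming each node already carries a NodeQuantizationConfig under .quantization_cfg:

```python
# Hedged sketch; graph is an MCT Graph whose nodes have been through
# quantization preparation, so n.quantization_cfg is populated.
graph = filter_nodes_candidates(graph)
for n in graph.nodes:
    # The old n.candidates_quantization_cfg access now goes via the container.
    for qc in n.quantization_cfg.candidates_quantization_cfg:
        print(n.name, qc.activation_quantization_cfg.quant_mode)
```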