mct-nightly 2.4.0.20250924.535__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/common/graph/base_node.py

@@ -14,37 +14,31 @@
 # ==============================================================================
 
 import copy
-from typing import Dict, Any, Tuple, List, Type, Union, NamedTuple
+from typing import Dict, Any, Tuple, List, Type, Union
 
 import numpy as np
 
-from model_compression_toolkit.core.common.framework_info import get_fw_info, ChannelAxisMapping
 from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \
     ACTIVATION_N_BITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
-from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import NodeQuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
     ActivationQuantizationMode
 from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
+    OpQuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
+    FrameworkQuantizationCapabilities
 
 
 WeightAttrT = Union[str, int]
 
 
-class NodeFrameworkInfo(NamedTuple):
-    """
-    Node's specific framework information.
-    """
-    channel_axis: ChannelAxisMapping
-    out_channel_axis: int
-    minmax: Tuple[float, float]
-    kernel_attr: str
-
-
 class BaseNode:
     """
     Class to represent a node in a graph that represents the model.
     """
+
     def __init__(self,
                  name: str,
                  framework_attr: Dict[str, Any],
@@ -90,84 +84,28 @@ class BaseNode:
         self.inputs_as_list = inputs_as_list
         self.final_weights_quantization_cfg = None
         self.final_activation_quantization_cfg = None
-        self.quantization_cfg: NodeQuantizationConfig = None
+        self.candidates_quantization_cfg = None
         self.prior_info = None
         self.has_activation = has_activation
         self.is_custom = is_custom
-        self.node_fw_info = self._get_fw_node_attrs(layer_class, framework_attr)
-
-    def _get_fw_node_attrs(self, node_type, framework_attr):
-        fw_info = get_fw_info()
-        return None if fw_info is None else NodeFrameworkInfo(
-            fw_info.get_kernel_channels(node_type),
-            fw_info.get_out_channel_axis(node_type),
-            fw_info.get_layer_min_max(node_type, framework_attr),
-            fw_info.get_kernel_op_attribute(node_type),
-        )
-
-    def _assert_fw_info_exists(self):
-        """
-        Verify NodeFrameworkInfo was initialized.
-        """
-        assert self.node_fw_info is not None, f"NodeFrameworkInfo not initialized for node {self.name}"  # pragma: no cover
-
-    @property
-    def channel_axis(self) -> ChannelAxisMapping:
-        """
-        Extract channels axis from node's NodeFrameworkInfo.
-
-        Returns:
-            Channels axis named tuple.
-        """
-        self._assert_fw_info_exists()
-        return self.node_fw_info.channel_axis
-
-    @property
-    def out_channel_axis(self) -> int:
-        """
-        Extract output channel axis from node's NodeFrameworkInfo.
-
-        Returns:
-            Output channel axis.
-        """
-        self._assert_fw_info_exists()
-        return self.node_fw_info.out_channel_axis
 
     @property
-    def minmax(self) -> Tuple[float, float]:
+    def type(self):
         """
-        Extract expected min-max activation values from node's NodeFrameworkInfo.
-
+        A function to get the node's layer_class op for convenient comparison
         Returns:
-            A tuple of min-max values.
-        """
-        self._assert_fw_info_exists()
-        return self.node_fw_info.minmax
-
-    @property
-    def kernel_attr(self) -> str:
+            the node's layer_class
         """
-        Extract kernel name from node's NodeFrameworkInfo.
+        return self.layer_class
 
-        Returns:
-            Kernel name.
+    def get_has_activation(self):
         """
-        self._assert_fw_info_exists()
-        return self.node_fw_info.kernel_attr
+        Returns has_activation attribute.
 
-    @property
-    def candidates_quantization_cfg(self):
-        assert self.quantization_cfg
-        return self.quantization_cfg.candidates_quantization_cfg
+        Returns: Whether the node has activation to quantize.
 
-    @property
-    def type(self):
-        """
-        A function to get the node's layer_class op for convenient comparison
-        Returns:
-            the node's layer_class
         """
-        return self.layer_class
+        return self.has_activation
 
     @property
     def has_positional_weights(self):
@@ -195,31 +133,19 @@ class BaseNode:
         Returns: Whether node activation quantization is enabled or not.
         """
         return self._is_single_quant_mode(ActivationQuantizationMode.QUANT)
-
-    def is_fln_no_quantization(self) -> bool:
+
+    def is_fln_quantization(self) -> bool:
         """
-        Returns: Whether node is FLN no quantization.
+        Returns: Whether the node's activation quantization is FLN
         """
-        return self._is_single_quant_mode(ActivationQuantizationMode.FLN_NO_QUANT)
-
+        return self._is_single_quant_mode(ActivationQuantizationMode.FLN_QUANT)
+
     def is_quantization_preserving(self) -> bool:
         """
        Returns: Whether node activation quantization information is preserved from its inputs.
         """
         return self._is_single_quant_mode(ActivationQuantizationMode.PRESERVE_QUANT)
 
-    def is_no_quantization(self) -> bool:
-        """
-        Returns: Whether node is no quantization.
-        """
-        return self._is_single_quant_mode(ActivationQuantizationMode.NO_QUANT)
-
-    def is_fln_quantization(self) -> bool:
-        """
-        Returns: Whether the node's activation quantization is FLN
-        """
-        return self._is_single_quant_mode(ActivationQuantizationMode.FLN_QUANT)
-
     def is_weights_quantization_enabled(self, attr_name: str) -> bool:
         """
         Checks whether a node's weights attribute quantization is enabled.
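Editor's note: the predicates in the hunk above all dispatch on ActivationQuantizationMode. For orientation, a minimal reconstruction of that enum from the member names visible in this diff (the real definition lives in model_compression_toolkit/core/common/quantization/node_quantization_config.py; the member values and comments here are assumptions):

    from enum import Enum, auto

    class ActivationQuantizationMode(Enum):
        QUANT = auto()           # activation quantization enabled (is_activation_quantization_enabled)
        FLN_QUANT = auto()       # checked by is_fln_quantization() on both sides of the diff
        FLN_NO_QUANT = auto()    # its predicate is_fln_no_quantization() exists only on the 2.4.0 side
        NO_QUANT = auto()        # its predicate is_no_quantization() exists only on the 2.4.0 side
        PRESERVE_QUANT = auto()  # checked by is_quantization_preserving()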
@@ -372,11 +298,14 @@ class BaseNode:
 
         return input_tensors
 
-    def get_num_parameters(self) -> Tuple[int,int]:
+    def get_num_parameters(self, fw_info) -> Tuple[int,int]:
         """
         Compute the number of parameters the node holds.
         It returns a tuple: Number of quantized parameters, number of float parameters.
 
+        Args:
+            fw_info: Framework info to decide which attributes should be quantized.
+
         Returns:
             A tuple of (Number of quantized parameters, number of float parameters).
 
@@ -385,10 +314,11 @@ class BaseNode:
 
         q_node_num_params = 0
 
-        if self.kernel_attr is not None:
-            w = self.get_weights_by_keys(self.kernel_attr)
-            if w is not None:
-                q_node_num_params += w.flatten().shape[0]
+        for attr in fw_info.get_kernel_op_attributes(self.type):
+            if attr is not None:
+                w = self.get_weights_by_keys(attr)
+                if w is not None:
+                    q_node_num_params += w.flatten().shape[0]
 
         f_node_num_params = total_node_params - q_node_num_params
 
@@ -396,19 +326,22 @@ class BaseNode:
         assert int(f_node_num_params) == f_node_num_params
         return int(q_node_num_params), int(f_node_num_params)
 
-    def get_memory_bytes(self) -> float:
+    def get_memory_bytes(self, fw_info) -> float:
         """
         Compute the number of bytes the node's memory requires.
 
+        Args:
+            fw_info: Framework info to decide which attributes should be quantized.
+
         Returns: Number of bytes the node's memory requires.
 
         """
         # TODO: this method is used for tensorboard only. If we want to enable logging of other attributes memory
         # then it needs to be modified. But, it might be better to remove this method from the BaseNode completely.
-        kernel_attr = self.kernel_attr
+        kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
         if kernel_attr is None:
             return 0
-        q_params, f_params = self.get_num_parameters()
+        q_params, f_params = self.get_num_parameters(fw_info)
         if self.final_weights_quantization_cfg is None:  # float coefficients
             memory = (f_params+q_params) * FP32_BYTES_PER_PARAMETER
         else:
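Editor's note: to make the get_num_parameters/get_memory_bytes change above concrete, here is a small self-contained sketch of the 2.4.2-side counting loop. StubFwInfo, the 'Conv2D' type tag, and the weight shapes are illustrative assumptions, not MCT's API:

    import numpy as np

    class StubFwInfo:
        # Stand-in for MCT's framework info: maps an op type to its kernel
        # attribute names (a list, possibly [None] for kernel-less ops).
        def get_kernel_op_attributes(self, node_type):
            return ['kernel'] if node_type == 'Conv2D' else [None]

    weights = {'kernel': np.zeros((3, 3, 8, 16)), 'bias': np.zeros(16)}
    fw_info = StubFwInfo()

    total = sum(w.size for w in weights.values())             # 1168 params in total
    q_params = 0
    for attr in fw_info.get_kernel_op_attributes('Conv2D'):
        if attr is not None and attr in weights:
            q_params += weights[attr].flatten().shape[0]      # kernel params count as quantized
    print(q_params, total - q_params)                         # -> 1152 16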
@@ -418,12 +351,15 @@ class BaseNode:
 
         return memory
 
-    def get_unified_weights_candidates_dict(self) -> Dict[str, Any]:
+    def get_unified_weights_candidates_dict(self, fw_info) -> Dict[str, Any]:
         """
         In Mixed-Precision, a node's kernel can have multiple candidates for weights quantization configuration.
         In order to display a single view of a node (for example, for logging in TensorBoard) we need a way
         to create a single dictionary from all candidates.
-        This method is aimed to build such a unified dictionary for a node.
+        This method is aimed to build such an unified dictionary for a node.
+
+        Args:
+            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
 
         Returns: A dictionary containing information from node's weight quantization configuration candidates.
 
@@ -433,7 +369,7 @@ class BaseNode:
         # We assume that only the kernel attribute have more than one candidate, since we only allow to
         # quantize the kernel using mixed precision
         # TODO: need to modify if we want to present a unified config for other attributes
-        kernel_attr = self.kernel_attr
+        kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
         if kernel_attr is None:
             # This node doesn't have a kernel attribute
             return {}
@@ -501,13 +437,20 @@ class BaseNode:
         candidates = self.get_all_weights_attr_candidates(attr)
         return all(candidate == candidates[0] for candidate in candidates[1:])
 
-    def has_kernel_weight_to_quantize(self):
+    def has_kernel_weight_to_quantize(self, fw_info):
         """
-        Checks whether the node has kernel attribute that need to be quantized according to the node's framework info.
+        Checks whether the node has kernel attribute that need to be quantized according to the framework info.
+
+        Args:
+            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
 
-        Returns: Whether the node's kernel need to be quantized.
+        Returns: Whether the node has weights that need to be quantized.
         """
-        return self.kernel_attr and self.get_weights_by_keys(self.kernel_attr) is not None
+        attrs = fw_info.get_kernel_op_attributes(self.type)
+        for attr in attrs:
+            if attr and self.get_weights_by_keys(attr) is not None:
+                return True
+        return False
 
     def has_any_weight_attr_to_quantize(self) -> bool:
         """
@@ -625,9 +568,8 @@ class BaseNode:
         Returns: True if the node has at list one quantization configuration candidate with activation quantization enabled.
         """
 
-        return (len(self.candidates_quantization_cfg) > 0 and
-                any([c.activation_quantization_cfg.enable_activation_quantization
-                     for c in self.candidates_quantization_cfg]))
+        return len(self.candidates_quantization_cfg) > 0 and \
+               any([c.activation_quantization_cfg.enable_activation_quantization for c in self.candidates_quantization_cfg])
 
     def get_all_weights_attr_candidates(self, attr: str) -> List[WeightsAttrQuantizationConfig]:
         """
@@ -643,6 +585,79 @@ class BaseNode:
         # the inner method would log an exception.
         return [c.weights_quantization_cfg.get_attr_config(attr) for c in self.candidates_quantization_cfg]
 
+    def get_qco(self, fqc: FrameworkQuantizationCapabilities) -> QuantizationConfigOptions:
+        """
+        Get the QuantizationConfigOptions of the node according
+        to the mappings from layers/LayerFilterParams to the OperatorsSet in the TargetPlatformCapabilities.
+
+        Args:
+            fqc: FQC to extract the QuantizationConfigOptions for the node.
+
+        Returns:
+            QuantizationConfigOptions of the node.
+        """
+
+        if fqc is None:
+            Logger.critical(f'Can not retrieve QC options for None FQC')  # pragma: no cover
+
+        for fl, qco in fqc.filterlayer2qco.items():
+            if self.is_match_filter_params(fl):
+                return qco
+        # Extract qco with is_match_type to overcome mismatch of function types in TF 2.15
+        matching_qcos = [_qco for _type, _qco in fqc.layer2qco.items() if self.is_match_type(_type)]
+        if matching_qcos:
+            if all([_qco == matching_qcos[0] for _qco in matching_qcos]):
+                return matching_qcos[0]
+            else:
+                Logger.critical(f"Found duplicate qco types for node '{self.name}' of type '{self.type}'!")  # pragma: no cover
+        return fqc.tpc.default_qco
+
+    def filter_node_qco_by_graph(self, fqc: FrameworkQuantizationCapabilities,
+                                 next_nodes: List, node_qc_options: QuantizationConfigOptions
+                                 ) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
+        """
+        Filter quantization config options that don't match the graph.
+        A node may have several quantization config options with 'activation_n_bits' values, and
+        the next nodes in the graph may support different bit-width as input activation. This function
+        filters out quantization config that don't comply to these attributes.
+
+        Args:
+            fqc: FQC to extract the QuantizationConfigOptions for the next nodes.
+            next_nodes: Output nodes of current node.
+            node_qc_options: Node's QuantizationConfigOptions.
+
+        Returns:
+
+        """
+        # Filter quantization config options that don't match the graph.
+        _base_config = node_qc_options.base_config
+        _node_qc_options = node_qc_options.quantization_configurations
+        if len(next_nodes):
+            next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
+            next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
+                                                       for qc_opts in next_nodes_qc_options
+                                                       for op_cfg in qc_opts.quantization_configurations])
+
+            # Filter node's QC options that match next nodes input bit-width.
+            _node_qc_options = [_option for _option in _node_qc_options
+                                if _option.activation_n_bits <= next_nodes_supported_input_bitwidth]
+            if len(_node_qc_options) == 0:
+                Logger.critical(f"Graph doesn't match FQC bit configurations: {self} -> {next_nodes}.")  # pragma: no cover
+
+            # Verify base config match
+            if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
+                    for qc_opt in next_nodes_qc_options]):
+                # base_config activation bits doesn't match next node supported input bit-width -> replace with
+                # a qco from quantization_configurations with maximum activation bit-width.
+                if len(_node_qc_options) > 0:
+                    output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
+                    _base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
+                    Logger.warning(f"Node {self} base quantization config changed to match Graph and FQC configuration.\nCause: {self} -> {next_nodes}.")
+                else:
+                    Logger.critical(f"Graph doesn't match FQC bit configurations: {self} -> {next_nodes}.")  # pragma: no cover
+
+        return _base_config, _node_qc_options
+
     def is_match_type(self, _type: Type) -> bool:
         """
         Check if input type matches the node type, either in instance type or in type name.
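Editor's note: the selection rule inside filter_node_qco_by_graph above reduces to a min/max comparison over bit-widths. A sketch with plain integers (OpQuantizationConfig objects replaced by ints; all values made up):

    node_candidate_bits = [2, 4, 8, 16]     # activation_n_bits of the node's QC options
    next_nodes_max_input_bits = [8, 16]     # max_input_activation_n_bits per successor's configs

    supported = min(next_nodes_max_input_bits)                 # tightest successor: 8
    kept = [b for b in node_candidate_bits if b <= supported]  # options that every successor accepts
    # If base_config exceeds `supported`, the code above swaps it for the kept
    # option with the maximum activation bit-width:
    base = max(kept)
    assert (kept, base) == ([2, 4, 8], 8)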
@@ -675,7 +690,7 @@ class BaseNode:
             return False
 
         # Get attributes from node to filter
-        layer_config = self.framework_attr.copy()
+        layer_config = self.framework_attr
         if hasattr(self, "op_call_kwargs"):
             layer_config.update(self.op_call_kwargs)
 
@@ -709,7 +724,7 @@ class BaseNode:
             Logger.critical(f"SIMD is expected to be a non-positive integer but found: {_simd}")
         return _simd
 
-    def sort_node_candidates(self):
+    def sort_node_candidates(self, fw_info):
         """
         Sorts the node candidates.
         We assume that the candidates are ordered in the following way (for mixed precision purposes):
@@ -718,12 +733,17 @@ class BaseNode:
         - If the node doesn't have a kernel we only consider the candidate activation number of bits to sort
           the candidates in descending order.
         The operation is done inplace.
+
+        Args:
+            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
+
         """
-        if self.quantization_cfg.candidates_quantization_cfg is not None:
-            if self.kernel_attr is not None:
-                self.quantization_cfg.candidates_quantization_cfg.sort(
-                    key=lambda c: (c.weights_quantization_cfg.get_attr_config(self.kernel_attr).weights_n_bits,
+        if self.candidates_quantization_cfg is not None:
+            kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
+            if kernel_attr is not None:
+                self.candidates_quantization_cfg.sort(
+                    key=lambda c: (c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits,
                                    c.activation_quantization_cfg.activation_n_bits), reverse=True)
             else:
-                self.quantization_cfg.candidates_quantization_cfg.sort(
-                    key=lambda c: c.activation_quantization_cfg.activation_n_bits, reverse=True)
+                self.candidates_quantization_cfg.sort(key=lambda c: c.activation_quantization_cfg.activation_n_bits,
+                                                      reverse=True)
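Editor's note: the candidate ordering that sort_node_candidates establishes is the same on both sides of this change: a descending sort on (weights_n_bits, activation_n_bits). Sketched with bare tuples in place of candidate configs:

    candidates = [(4, 8), (8, 8), (8, 4), (4, 4)]  # (weights_n_bits, activation_n_bits)
    candidates.sort(key=lambda c: (c[0], c[1]), reverse=True)
    assert candidates == [(8, 8), (8, 4), (4, 8), (4, 4)]  # weights bits dominate, then activation bits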
model_compression_toolkit/core/common/graph/functional_node.py

@@ -1,21 +1,6 @@
-# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
 from typing import Dict, Any, Tuple, Type, List, Union
 
-from model_compression_toolkit.core.common.framework_info import get_fw_info
+from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 import numpy as np
 
@@ -60,7 +45,6 @@ class FunctionalNode(BaseNode):
             inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
             has_activation: Whether the node has activations that we might want to quantize.
             tensor_input_allocs: A list of indices and strings for allocations input tensors in the node's args and kwargs.
-
         """
 
         super().__init__(name,
@@ -79,7 +63,6 @@ class FunctionalNode(BaseNode):
         self.op_call_args = list(op_call_args)
         self.functional_op = functional_op
         self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs
-        self.node_fw_info = self._get_fw_node_attrs(functional_op, framework_attr)
 
     @property
     def type(self):
@@ -103,4 +86,4 @@ class FunctionalNode(BaseNode):
 
         """
         names_match = _type.__name__ == self.type.__name__
-        return names_match or super().is_match_type(_type)
+        return super().is_match_type(_type) or names_match
model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py

@@ -15,11 +15,13 @@
 import abc
 import uuid
 
+from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.constants import VIRTUAL_ACTIVATION_WEIGHTS_NODE_PREFIX, \
     VIRTUAL_WEIGHTS_SUFFIX, VIRTUAL_ACTIVATION_SUFFIX, FLOAT_BITWIDTH
+from model_compression_toolkit.core.common.framework_info import DEFAULT_KERNEL_ATTRIBUTES
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
-    CandidateNodeQuantizationConfig, NodeQuantizationConfig
+    CandidateNodeQuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
 
 
@@ -75,11 +77,8 @@ class VirtualSplitWeightsNode(VirtualSplitNode):
 
         self.name = origin_node.name + VIRTUAL_WEIGHTS_SUFFIX
 
-        self.quantization_cfg = NodeQuantizationConfig(
-            candidates_quantization_cfg=origin_node.get_unique_weights_candidates(kernel_attr),
-            base_quantization_cfg=None, validate=False
-        )
-        for c in self.quantization_cfg.candidates_quantization_cfg:
+        self.candidates_quantization_cfg = origin_node.get_unique_weights_candidates(kernel_attr)
+        for c in self.candidates_quantization_cfg:
             c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
             c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH
 
@@ -108,9 +107,10 @@ class VirtualSplitActivationNode(VirtualSplitNode):
         self.weights = {}
         self.layer_class = activation_class
 
-        self.quantization_cfg = NodeQuantizationConfig(candidates_quantization_cfg=origin_node.get_unique_activation_candidates(),
-                                                       base_quantization_cfg=None, validate=False)
-        self.quantization_cfg.disable_weights_quantization()
+        self.candidates_quantization_cfg = origin_node.get_unique_activation_candidates()
+        for c in self.candidates_quantization_cfg:
+            c.weights_quantization_cfg.enable_weights_quantization = False
+            c.weights_quantization_cfg.weights_n_bits = FLOAT_BITWIDTH
 
 
 class VirtualActivationWeightsNode(VirtualNode):
@@ -128,23 +128,28 @@ class VirtualActivationWeightsNode(VirtualNode):
 
     def __init__(self,
                  act_node: BaseNode,
-                 weights_node: BaseNode):
+                 weights_node: BaseNode,
+                 fw_info: FrameworkInfo):
         """
         Init a VirtualActivationWeightsNode object.
 
         Args:
             act_node: The original activation node.
             weights_node: The original weights node.
+            fw_info: A FrameworkInfo object with framework specific information.
         """
         # Validate weights node
+        kernel_attrs = fw_info.get_kernel_op_attributes(weights_node.type)
+        assert len(kernel_attrs) == 1 and kernel_attrs[0] is not None, f'Expected exactly one kernel attr, {kernel_attrs}'
+        kernel_attr = kernel_attrs[0]
         conf_weights = [attr for attr in weights_node.weights if weights_node.is_configurable_weight(attr)]
-        if len(conf_weights) > 1 or len(conf_weights) == 1 and not weights_node.is_configurable_weight(weights_node.kernel_attr):
+        if len(conf_weights) > 1 or len(conf_weights) == 1 and not weights_node.is_configurable_weight(kernel_attr):
             raise NotImplementedError(f'Only kernel weight can be configurable. Got configurable {conf_weights}.')
 
         weights = weights_node.weights.copy()
         act_node_w_rename = {}
         if act_node.weights:
-            if act_node.kernel_attr:
+            if fw_info.get_kernel_op_attributes(act_node) != DEFAULT_KERNEL_ATTRIBUTES:
                 raise NotImplementedError(f'Node {act_node} with kernel cannot be used as activation for '
                                           f'VirtualActivationWeightsNode.')
             if act_node.has_any_configurable_weight():
@@ -152,7 +157,7 @@ class VirtualActivationWeightsNode(VirtualNode):
                                           'VirtualActivationWeightsNode.')
             # combine weights from activation and weights
             for w_id, w in act_node.weights.items():
-                if w_id not in weights and not (isinstance(w_id, str) and weights_node.kernel_attr in w_id):
+                if w_id not in weights and not (isinstance(w_id, str) and kernel_attr in w_id):
                     weights[w_id] = w
                     continue
                 # if same identifier is used as in weight nodes (or contains the kernel substring), generate a new
@@ -180,7 +185,7 @@ class VirtualActivationWeightsNode(VirtualNode):
         self.original_weights_node = weights_node
 
         v_candidates = []
-        weights_candidates_quantization_cfg = weights_node.get_unique_weights_candidates(weights_node.kernel_attr)
+        weights_candidates_quantization_cfg = weights_node.get_unique_weights_candidates(kernel_attr)
         for c_a in act_node.candidates_quantization_cfg:
             for c_w in weights_candidates_quantization_cfg:
                 composed_candidate = CandidateNodeQuantizationConfig(activation_quantization_cfg=c_a.activation_quantization_cfg,
@@ -198,8 +203,7 @@ class VirtualActivationWeightsNode(VirtualNode):
                 v_candidates.append(composed_candidate)
 
         # sorting the candidates by weights number of bits first and then by activation number of bits (reversed order)
-        v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.get_attr_config(weights_node.kernel_attr).weights_n_bits,
+        v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits,
                                          c.activation_quantization_cfg.activation_n_bits), reverse=True)
 
-        self.quantization_cfg = NodeQuantizationConfig(candidates_quantization_cfg=v_candidates,
-                                                       base_quantization_cfg=None, validate=False)
+        self.candidates_quantization_cfg = v_candidates
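Editor's note: on both sides of the diff, VirtualActivationWeightsNode composes one candidate per (activation candidate, weights candidate) pair and then applies the same descending sort as sort_node_candidates. The product-then-sort behavior, sketched with bit-widths standing in for config objects:

    act_bits = [8, 4]          # activation_n_bits of act_node's candidates
    weight_bits = [8, 4, 2]    # weights_n_bits of the kernel's unique candidates

    v_candidates = [(w, a) for a in act_bits for w in weight_bits]  # cartesian product
    v_candidates.sort(reverse=True)                                 # weights bits first, then activation
    assert v_candidates == [(8, 8), (8, 4), (4, 8), (4, 4), (2, 8), (2, 4)]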
model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py

@@ -37,18 +37,20 @@ def set_bit_widths(mixed_precision_enable: bool,
     """
     if mixed_precision_enable:
         assert all([len(n.candidates_quantization_cfg) > 0
-                    for n in graph.get_configurable_sorted_nodes()]), \
+                    for n in graph.get_configurable_sorted_nodes(graph.fw_info)]), \
             "All configurable nodes in graph should have at least one candidate configuration in mixed precision mode"
 
         # Get a list of nodes' names we need to finalize (that they have at least one weight qc candidate).
-        sorted_nodes_names = graph.get_configurable_sorted_nodes_names()
+        sorted_nodes_names = graph.get_configurable_sorted_nodes_names(graph.fw_info)
 
         for node in graph.nodes:  # set a specific node qc for each node final qc
             # If it's reused, take the configuration that the base node has
             node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2])
             if node_name in sorted_nodes_names:  # only configurable nodes are in this list
                 node_index_in_graph = sorted_nodes_names.index(node_name)
-                _set_node_final_qc(bit_widths_config[node_index_in_graph], node)
+                _set_node_final_qc(bit_widths_config[node_index_in_graph],
+                                   node,
+                                   graph.fw_info)
             else:
                 if node.is_activation_quantization_enabled():
                     # If we are here, this means that we are in weights-only mixed-precision
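Editor's note: set_bit_widths resolves each configurable node's final config by indexing the mixed-precision solution vector with the node's position in the sorted-names list; the fw_info argument threaded through in 2.4.2 appears only to change how that list is computed. The lookup itself, sketched with plain lists (names and indices are made up):

    sorted_nodes_names = ['conv1', 'conv2', 'fc']  # graph.get_configurable_sorted_nodes_names(...)
    bit_widths_config = [0, 2, 1]                  # chosen candidate index per configurable node

    def final_candidate(node_name, candidates):
        # Mirror of the index-based dispatch in the hunk above.
        idx = sorted_nodes_names.index(node_name)
        return candidates[bit_widths_config[idx]]

    assert final_candidate('conv2', ['8w8a', '8w4a', '4w4a']) == '4w4a'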
@@ -81,7 +83,8 @@ def set_bit_widths(mixed_precision_enable: bool,
 
 
 def _get_node_qc_by_bit_widths(node: BaseNode,
-                               node_bit_width_cfg: int) -> Any:
+                               node_bit_width_cfg: int,
+                               fw_info) -> Any:
     """
     Get the node's quantization configuration that
     matches to the bit width index as in the MP configuration bit_width_cfg.
@@ -90,18 +93,21 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
     Args:
         node: Node to get its quantization configuration candidate.
         node_bit_width_cfg: Configuration which determines the node's desired bit width.
+        fw_info: Information relevant to a specific framework about how layers should be quantized.
 
     Returns:
         Node quantization configuration if it was found, or None otherwise.
     """
     # only the weights kernel attribute is quantized in weights mixed precision at the moment
+    kernel_attr = fw_info.get_kernel_op_attributes(node.type)
+
     if node.is_activation_quantization_enabled():
         qc = node.candidates_quantization_cfg[node_bit_width_cfg]
 
         return qc
 
-    elif node.kernel_attr is not None:
-        if node.is_weights_quantization_enabled(node.kernel_attr):
+    elif kernel_attr is not None:
+        if node.is_weights_quantization_enabled(kernel_attr[0]):
             qc = node.candidates_quantization_cfg[node_bit_width_cfg]
 
             return qc
@@ -110,7 +116,8 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
 
 
 def _set_node_final_qc(node_bit_width_cfg: int,
-                       node: BaseNode):
+                       node: BaseNode,
+                       fw_info):
     """
     Get the node's quantization configuration that
     matches to the bit width index as in the MP configuration bit_width_cfg, and use it to finalize the node's
@@ -120,9 +127,12 @@ def _set_node_final_qc(node_bit_width_cfg: int,
     Args:
         node_bit_width_cfg: Configuration which determines the node's desired bit width.
        node: Node to set its node quantization configuration.
+        fw_info: Information relevant to a specific framework about how layers should be quantized.
 
     """
-    node_qc = _get_node_qc_by_bit_widths(node, node_bit_width_cfg)
+    node_qc = _get_node_qc_by_bit_widths(node,
+                                         node_bit_width_cfg,
+                                         fw_info)
 
     if node_qc is None:
         Logger.critical(f'Node {node.name} quantization configuration from configuration file'  # pragma: no cover