mct-nightly 2.4.0.20250924.535__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -19,6 +19,7 @@ from model_compression_toolkit.core.common.pruning.pruning_framework_implementat
19
19
  PruningFrameworkImplementation
20
20
  from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
21
21
  from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
22
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
23
  from model_compression_toolkit.core.common import BaseNode
23
24
  from model_compression_toolkit.core.pytorch.constants import BIAS, GROUPS, OUT_CHANNELS, OUT_FEATURES, NUM_FEATURES, \
24
25
  IN_CHANNELS, IN_FEATURES, NUM_PARAMETERS
@@ -29,10 +30,6 @@ import numpy as np
29
30
  from model_compression_toolkit.logger import Logger
30
31
 
31
32
 
32
- # default output channel axis to use when it's not defined in node's fw_info.
33
- _default_output_channel_axis = 1
34
-
35
-
36
33
  class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplementation):
37
34
  """
38
35
  Implementation of the PruningFramework for the Pytorch framework. This class provides
@@ -42,23 +39,27 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
42
39
 
43
40
  def prune_entry_node(self,
44
41
  node: BaseNode,
45
- output_mask: np.ndarray):
42
+ output_mask: np.ndarray,
43
+ fw_info: FrameworkInfo):
46
44
  """
47
45
  Prunes the entry node of a model in Pytorch.
48
46
 
49
47
  Args:
50
48
  node (BaseNode): The entry node to be pruned.
51
49
  output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
50
+ fw_info (FrameworkInfo): Framework-specific information object.
52
51
 
53
52
  """
54
53
  return _prune_pytorch_edge_node(node=node,
55
54
  mask=output_mask,
55
+ fw_info=fw_info,
56
56
  is_exit_node=False)
57
57
 
58
58
  def prune_intermediate_node(self,
59
59
  node: BaseNode,
60
60
  input_mask: np.ndarray,
61
- output_mask: np.ndarray):
61
+ output_mask: np.ndarray,
62
+ fw_info: FrameworkInfo):
62
63
  """
63
64
  Prunes an intermediate node in a Pytorch model.
64
65
 
@@ -66,11 +67,12 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
66
67
  node (BaseNode): The intermediate node to be pruned.
67
68
  input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
68
69
  output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
70
+ fw_info (FrameworkInfo): Framework-specific information object.
69
71
 
70
72
  """
71
73
  # TODO (reuvenp/liord): Address handling of node parameters that can be either a single value across all channels or distinct per channel, e.g., PReLU. Consider developing a structured approach.
72
74
  pruning_en = True
73
- _edit_node_input_shape(node, input_mask)
75
+ _edit_node_input_shape(node, input_mask, fw_info)
74
76
  pruned_parameters = {}
75
77
  mask_bool = output_mask.astype(bool)
76
78
  node.weights = pruned_parameters
@@ -89,17 +91,20 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
89
91
 
90
92
  def prune_exit_node(self,
91
93
  node: BaseNode,
92
- input_mask: np.ndarray):
94
+ input_mask: np.ndarray,
95
+ fw_info: FrameworkInfo):
93
96
  """
94
97
  Prunes the exit node of a model in Pytorch.
95
98
 
96
99
  Args:
97
100
  node (BaseNode): The exit node to be pruned.
98
101
  input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
102
+ fw_info (FrameworkInfo): Framework-specific information object.
99
103
 
100
104
  """
101
105
  return _prune_pytorch_edge_node(node=node,
102
106
  mask=input_mask,
107
+ fw_info=fw_info,
103
108
  is_exit_node=True)
104
109
 
105
110
  def is_node_entry_node(self, node: BaseNode) -> bool:
@@ -116,19 +121,22 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
116
121
 
117
122
  def is_node_exit_node(self,
118
123
  node: BaseNode,
119
- corresponding_entry_node: BaseNode) -> bool:
124
+ corresponding_entry_node: BaseNode,
125
+ fw_info: FrameworkInfo) -> bool:
120
126
  """
121
127
  Determines whether a node is an exit node in a Pytorch model.
122
128
 
123
129
  Args:
124
130
  node (BaseNode): The node to be checked.
125
131
  corresponding_entry_node (BaseNode): The entry node of the pruning section that is checked.
132
+ fw_info (FrameworkInfo) Framework-specific information object.
126
133
 
127
134
  Returns:
128
135
  bool: Boolean indicating if the node is an exit node.
129
136
  """
130
137
  return _is_pytorch_node_pruning_section_edge(node) and PruningSection.has_matching_channel_count(node,
131
- corresponding_entry_node)
138
+ corresponding_entry_node,
139
+ fw_info)
132
140
 
133
141
  def is_node_intermediate_pruning_section(self, node: BaseNode) -> bool:
134
142
  """
@@ -147,7 +155,8 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
147
155
  torch.nn.Linear]
148
156
 
149
157
  def attrs_oi_channels_info_for_pruning(self,
150
- node: BaseNode) -> Dict[str, Tuple[int, int]]:
158
+ node: BaseNode,
159
+ fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
151
160
  """
152
161
  Retrieves the attributes of a given node along with the output/input (OI) channel axis
153
162
  for each attribute used to prune these attributes.
@@ -164,6 +173,7 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
164
173
 
165
174
  Args:
166
175
  node (BaseNode): The node from the computational graph.
176
+ fw_info (FrameworkInfo): Contains framework-specific information and utilities.
167
177
 
168
178
  Returns:
169
179
  Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'weight' or 'bias')
@@ -171,8 +181,13 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
171
181
  """
172
182
 
173
183
  attributes_with_axis = {}
174
- if node.kernel_attr:
175
- attributes_with_axis[node.kernel_attr] = (node.channel_axis.output, node.channel_axis.input)
184
+ if fw_info.is_kernel_op(node.type):
185
+ kernel_attributes = fw_info.get_kernel_op_attributes(node.type)
186
+ if kernel_attributes is None or len(kernel_attributes) == 0:
187
+ Logger.critical(f"Expected to find kernel attributes but none were identified for node '{node.name}' of type {node.type}.")
188
+
189
+ for attr in kernel_attributes:
190
+ attributes_with_axis[attr] = fw_info.kernel_channels_mapping.get(node.type)
176
191
 
177
192
  # Bias is a vector at the length of the number of output channels.
178
193
  # For this reason, input channel axis is irrelevant to the bias attribute.
@@ -187,17 +202,13 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
187
202
  # If the number of float parameters is 1 or less - is the case where
188
203
  # we have one parameter for all channels. For this case, we don't
189
204
  # want to prune the parameter.
190
- if node.get_num_parameters()[1] <= 1:
205
+ if node.get_num_parameters(fw_info)[1] <= 1:
191
206
  attributes_with_axis[attr] = (None, None)
192
207
  else:
193
208
  attributes_with_axis[attr] = (-1, None)
194
209
 
195
210
  return attributes_with_axis
196
211
 
197
- @property
198
- def default_output_channel_axis(self):
199
- return _default_output_channel_axis
200
-
201
212
 
202
213
  def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool:
203
214
  """
@@ -223,6 +234,7 @@ def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool:
223
234
 
224
235
  def _prune_pytorch_edge_node(node: BaseNode,
225
236
  mask: np.ndarray,
237
+ fw_info: FrameworkInfo,
226
238
  is_exit_node: bool):
227
239
  """
228
240
  Prunes the given Pytorch node by applying the mask to the node's weights (weights and biases).
@@ -231,18 +243,21 @@ def _prune_pytorch_edge_node(node: BaseNode,
231
243
  Args:
232
244
  node (BaseNode): The node to be pruned.
233
245
  mask (np.ndarray): The pruning mask to be applied.
246
+ fw_info (FrameworkInfo): Framework-specific information object.
234
247
  is_exit_node (bool): A boolean indicating whether the node is an exit node.
235
248
 
236
249
  """
237
250
 
238
251
  # Retrieve the kernel attribute and the axes to prune.
239
- axis_to_prune = node.channel_axis.input if is_exit_node else node.channel_axis.output
240
- kernel = node.get_weights_by_keys(node.kernel_attr)
252
+ kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
253
+ io_axis = fw_info.kernel_channels_mapping.get(node.type)
254
+ axis_to_prune = io_axis[int(is_exit_node)]
255
+ kernel = node.get_weights_by_keys(kernel_attr)
241
256
  # Convert mask to boolean.
242
257
  mask_bool = mask.astype(bool)
243
258
 
244
259
  pruned_kernel = kernel.compress(mask_bool, axis=axis_to_prune)
245
- node.set_weights_by_keys(name=node.kernel_attr, tensor=pruned_kernel)
260
+ node.set_weights_by_keys(name=kernel_attr, tensor=pruned_kernel)
246
261
 
247
262
  if not is_exit_node and node.framework_attr[BIAS]:
248
263
  # Prune the bias if applicable and it's an entry node.
@@ -270,11 +285,12 @@ def _prune_pytorch_edge_node(node: BaseNode,
270
285
  Logger.critical(f"{node.type} is currently not supported"
271
286
  f"as an edge node in a pruning section")
272
287
  # Adjust the input shape for the last node in the section.
273
- _edit_node_input_shape(node, mask_bool)
288
+ _edit_node_input_shape(node, mask_bool, fw_info)
274
289
 
275
290
 
276
291
  def _edit_node_input_shape(node: BaseNode,
277
- input_mask: np.ndarray):
292
+ input_mask: np.ndarray,
293
+ fw_info: FrameworkInfo):
278
294
  """
279
295
  Adjusts the input shape of a node based on the given input mask.
280
296
 
@@ -285,13 +301,14 @@ def _edit_node_input_shape(node: BaseNode,
285
301
  Args:
286
302
  node (BaseNode): The node whose input shape needs to be adjusted.
287
303
  input_mask (np.ndarray): A binary array where 1 indicates the channel is kept and 0 means pruned.
304
+ fw_info (FrameworkInfo): Framework-specific information object.
288
305
  """
289
306
  # Start with the current input shape of the node.
290
307
  new_input_shape = list(node.input_shape)
291
308
 
292
309
  # Adjust the last dimension of the shape to match the number of unpruned (retained) channels.
293
310
  # This is done by summing the mask, as each '1' in the mask represents a retained channel.
294
- channel_axis = _default_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis
311
+ channel_axis = fw_info.out_channel_axis_mapping.get(node.type)
295
312
  new_input_shape[0][channel_axis] = int(np.sum(input_mask))
296
313
 
297
314
  # Update the node's input shape with the new dimensions.
@@ -26,7 +26,7 @@ from torch.nn import Module, Sigmoid, Softmax
26
26
 
27
27
  import model_compression_toolkit.core.pytorch.constants as pytorch_constants
28
28
  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
29
- from model_compression_toolkit.core import QuantizationConfig, CoreConfig
29
+ from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig
30
30
  from model_compression_toolkit.core import common
31
31
  from model_compression_toolkit.core.common import Graph, BaseNode
32
32
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -37,6 +37,7 @@ from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
37
37
  from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_kl_divergence, compute_cs
38
38
  from model_compression_toolkit.core.pytorch.back2framework import get_pytorch_model_builder
39
39
  from model_compression_toolkit.core.pytorch.data_util import data_gen_to_dataloader
40
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
40
41
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.batchnorm_folding import \
41
42
  pytorch_batchnorm_folding, pytorch_batchnorm_forward_folding
42
43
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.batchnorm_reconstruction import \
@@ -177,6 +178,7 @@ class PytorchImplementation(FrameworkImplementation):
177
178
  graph: Graph,
178
179
  mode: ModelBuilderMode,
179
180
  append2output: List[Any] = None,
181
+ fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
180
182
  return_float_outputs: bool = False) -> Tuple:
181
183
  """
182
184
  Build a Pytorch module from a graph.
@@ -187,6 +189,7 @@ class PytorchImplementation(FrameworkImplementation):
187
189
  graph: Graph to build the module from it.
188
190
  mode: Mode for how to build the module.
189
191
  append2output: List of Nodes to set as the module's outputs.
192
+ fw_info: FrameworkInfo object with information about the specific framework's module
190
193
  return_float_outputs (bool): whether to return outputs before or after quantization nodes (default)
191
194
 
192
195
  Returns:
@@ -195,6 +198,7 @@ class PytorchImplementation(FrameworkImplementation):
195
198
  pytorch_model_builder = get_pytorch_model_builder(mode)
196
199
  return pytorch_model_builder(graph=graph,
197
200
  append2output=append2output,
201
+ fw_info=fw_info,
198
202
  return_float_outputs=return_float_outputs).build_model()
199
203
 
200
204
  def run_model_inference(self,
@@ -228,55 +232,63 @@ class PytorchImplementation(FrameworkImplementation):
228
232
 
229
233
  def shift_negative_correction(self,
230
234
  graph: Graph,
231
- core_config: CoreConfig) -> Graph:
235
+ core_config: CoreConfig,
236
+ fw_info: FrameworkInfo) -> Graph:
232
237
  """
233
238
  Apply shift negative correction (SNC) on a graph.
234
239
 
235
240
  Args:
236
241
  graph: Graph to apply SNC on.
237
242
  core_config: Quantization configuration.
243
+ fw_info: FrameworkInfo object with information about the specific framework's module.
238
244
 
239
245
  Returns:
240
246
  Graph after SNC.
241
247
  """
242
248
  return pytorch_apply_shift_negative_correction(graph,
243
- core_config)
249
+ core_config,
250
+ fw_info)
244
251
 
245
252
  def compute_activation_bias_correction(self,
246
253
  graph: Graph,
247
- quant_config: QuantizationConfig):
254
+ quant_config: QuantizationConfig,
255
+ fw_info: FrameworkInfo):
248
256
  """
249
257
  Compute activation bias correction on a graph.
250
258
 
251
259
  Args:
252
260
  graph: Graph to apply activation bias correction on.
253
261
  quant_config: QuantizationConfig of how the model should be quantized.
262
+ fw_info: FrameworkInfo object with information about the specific framework's model.
254
263
 
255
264
  Returns:
256
265
  Graph after activation bias correction computing.
257
266
  """
258
267
  return pytorch_compute_activation_bias_correction_of_graph(graph=graph,
259
268
  quant_config=quant_config,
269
+ fw_info=fw_info,
260
270
  fw_impl=self)
261
271
 
262
272
  def get_substitutions_channel_equalization(self,
263
- quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
273
+ quant_config: QuantizationConfig,
274
+ fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
264
275
  """
265
276
  Return a list of the framework substitutions used for channel equalization.
266
277
 
267
278
  Args:
268
279
  quant_config: QuantizationConfig to determine which substitutions to return.
280
+ fw_info: FrameworkInfo object with information about the specific framework's model.
269
281
 
270
282
  Returns:
271
283
  A list of the framework substitutions used after we collect statistics.
272
284
  """
273
285
  substitutions_list = []
274
286
  if quant_config.activation_channel_equalization:
275
- substitutions_list.extend([ScaleEqualization(quant_config),
276
- ScaleEqualizationWithPad(quant_config)])
287
+ substitutions_list.extend([ScaleEqualization(quant_config, fw_info),
288
+ ScaleEqualizationWithPad(quant_config, fw_info)])
277
289
  return substitutions_list
278
290
 
279
- def get_substitutions_prepare_graph(self) -> List[common.BaseSubstitution]:
291
+ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List[common.BaseSubstitution]:
280
292
  """
281
293
 
282
294
  Returns: A list of the framework substitutions used before we collect the prior information.
@@ -287,7 +299,7 @@ class PytorchImplementation(FrameworkImplementation):
287
299
  ScaledDotProductDecomposition(),
288
300
  MatMulDecomposition(),
289
301
  TransformFunctionCallMethod(),
290
- FunctionalConvSubstitution(),
302
+ FunctionalConvSubstitution(fw_info),
291
303
  FunctionalBatchNorm(),
292
304
  FunctionalLayerNorm(),
293
305
  FunctionalLinear(),
@@ -389,17 +401,20 @@ class PytorchImplementation(FrameworkImplementation):
389
401
 
390
402
  def get_node_prior_info(self,
391
403
  node: BaseNode,
404
+ fw_info: FrameworkInfo,
392
405
  graph: Graph) -> NodePriorInfo:
393
406
  """
394
407
  Get a NodePriorInfo object for a node that represents a Pytorch layer.
395
408
  Args:
396
409
  node: Node to get its prior info.
410
+ fw_info: Framework specific information needed to create the prior info of the node.
397
411
  graph: Graph to check the next node type.
398
412
  Returns:
399
413
  NodePriorInfo with information about the node.
400
414
  """
401
415
 
402
416
  return create_node_prior_info(node=node,
417
+ fw_info=fw_info,
403
418
  graph=graph)
404
419
 
405
420
  def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
@@ -461,19 +476,23 @@ class PytorchImplementation(FrameworkImplementation):
461
476
  return node.layer_class not in [argmax, softmax, Softmax]
462
477
 
463
478
  def get_node_mac_operations(self,
464
- node: BaseNode) -> float:
479
+ node: BaseNode,
480
+ fw_info: FrameworkInfo) -> float:
465
481
  """
466
482
  Gets the MAC operation count for a given operation.
467
483
 
468
484
  Args:
469
485
  node: A graph node that wraps the operation for which the MAC count is computed.
486
+ fw_info: FrameworkInfo object with information about the Pytorch model.
470
487
 
471
488
  Returns: The MAC count of the operation
472
489
  """
473
- if node.kernel_attr is None:
490
+ kernels = fw_info.get_kernel_op_attributes(node.type)
491
+ if not kernels or kernels[0] is None:
474
492
  return 0
475
493
 
476
- kernel_shape = node.get_weights_by_keys(node.kernel_attr).shape
494
+ assert len(kernels) == 1
495
+ kernel_shape = node.get_weights_by_keys(kernels[0]).shape
477
496
 
478
497
  if node.is_match_type(Conv2d) or node.is_match_type(ConvTranspose2d):
479
498
  h, w = node.get_output_shapes_list()[0][-2:]
@@ -481,7 +500,8 @@ class PytorchImplementation(FrameworkImplementation):
481
500
 
482
501
  if node.is_match_type(Linear):
483
502
  # IN * OUT * (all previous dims[:-1])
484
- return node.get_total_output_params() * kernel_shape[node.channel_axis.input]
503
+ _, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
504
+ return node.get_total_output_params() * kernel_shape[input_channel_axis]
485
505
 
486
506
  return 0
487
507
 
@@ -23,19 +23,23 @@ from model_compression_toolkit.core.pytorch.constants import MOVING_MEAN, MOVING
23
23
 
24
24
 
25
25
  def create_node_prior_info(node: BaseNode,
26
+ fw_info: FrameworkInfo,
26
27
  graph: Graph):
27
28
  """
28
29
  Create a NodePriorInfo object for a given node.
29
30
 
30
31
  Args:
31
32
  node: Node to create its prior info.
33
+ fw_info: Information about a specific framework the node was generated from.
32
34
  graph: Graph to check the next node type.
33
35
 
34
36
  Returns:
35
37
  NodePriorInfo object with info about the node.
36
38
  """
37
39
 
38
- min_output, max_output = node.minmax
40
+ min_output, max_output = None, None
41
+ if fw_info.layers_has_min_max(node.type):
42
+ min_output, max_output = fw_info.layer_min_max_mapping[node.type]
39
43
  mean_output, std_output = _get_mean_std_outputs(node=node,
40
44
  graph=graph)
41
45
  return NodePriorInfo(min_output=min_output,
@@ -27,7 +27,7 @@ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler impor
27
27
  from model_compression_toolkit.verify_packages import FOUND_TORCH
28
28
 
29
29
  if FOUND_TORCH:
30
- from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
30
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
31
31
  from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
32
32
  from torch.nn import Module
33
33
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
@@ -38,7 +38,6 @@ if FOUND_TORCH:
38
38
  PYTORCH_DEFAULT_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
39
39
 
40
40
 
41
- @set_pytorch_info
42
41
  def pytorch_resource_utilization_data(in_model: Module,
43
42
  representative_data_gen: Callable,
44
43
  core_config: CoreConfig = CoreConfig(),
@@ -94,6 +93,7 @@ if FOUND_TORCH:
94
93
  representative_data_gen,
95
94
  core_config,
96
95
  target_platform_capabilities,
96
+ DEFAULT_PYTORCH_INFO,
97
97
  fw_impl)
98
98
 
99
99
  else:
@@ -18,7 +18,7 @@ from torch.nn import Conv2d, Linear, ConvTranspose2d
18
18
  from model_compression_toolkit.core import QuantizationConfig
19
19
  from model_compression_toolkit.core.common import Graph
20
20
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
21
- from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
21
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
22
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
23
23
  from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
24
24
  compute_activation_bias_correction_of_graph
@@ -33,6 +33,7 @@ def activation_bias_correction_node_matchers():
33
33
 
34
34
  def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
35
35
  quant_config: QuantizationConfig,
36
+ fw_info: FrameworkInfo,
36
37
  fw_impl: FrameworkImplementation) -> Graph:
37
38
  """
38
39
  Compute the activation bias correction term for graph based on a PyTorch model.
@@ -40,6 +41,7 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
40
41
  Args:
41
42
  graph: Graph with nodes to compute the activation bias correction.
42
43
  quant_config: QuantizationConfig of how the model should be quantized.
44
+ fw_info: Framework info like lists of nodes their kernel should quantized.
43
45
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
44
46
 
45
47
  Returns:
@@ -47,9 +49,9 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
47
49
  """
48
50
  graph = compute_activation_bias_correction_of_graph(graph=graph,
49
51
  quant_config=quant_config,
52
+ fw_info=fw_info,
50
53
  fw_impl=fw_impl,
51
54
  activation_bias_correction_node_matchers=
52
55
  activation_bias_correction_node_matchers,
53
- kernel_size=KERNEL_SIZE,
54
- get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
56
+ kernel_size=KERNEL_SIZE)
55
57
  return graph
@@ -37,6 +37,7 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
37
37
  def quantization_preparation_runner(graph: Graph,
38
38
  representative_data_gen: Callable,
39
39
  core_config: CoreConfig,
40
+ fw_info: FrameworkInfo,
40
41
  fw_impl: FrameworkImplementation,
41
42
  tb_w: TensorboardWriter = None,
42
43
  hessian_info_service: HessianInfoService = None, ) -> Graph:
@@ -52,6 +53,8 @@ def quantization_preparation_runner(graph: Graph,
52
53
  graph: A graph representation of the model to be quantized.
53
54
  representative_data_gen: Dataset used for calibration.
54
55
  core_config: CoreConfig containing parameters of how the model should be quantized
56
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
57
+ groups of layers by how they should be quantized, etc.).
55
58
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
56
59
  tb_w: TensorboardWriter object for logging
57
60
  hessian_info_service: HessianInfoService object for retrieving Hessian-based scores.
@@ -65,6 +68,7 @@ def quantization_preparation_runner(graph: Graph,
65
68
  ######################################
66
69
  mi = ModelCollector(graph,
67
70
  fw_impl,
71
+ fw_info,
68
72
  hessian_info_service,
69
73
  core_config.quantization_config) # Mark points for statistics collection
70
74
 
@@ -81,14 +85,14 @@ def quantization_preparation_runner(graph: Graph,
81
85
  # Notice that not all actions affect at this stage (for example, actions that edit the final configuration as
82
86
  # there are no final configurations at this stage of the optimization). For this reason we edit the graph
83
87
  # again at the end of the optimization process.
84
- edit_network_graph(graph, core_config.debug_config.network_editor)
88
+ edit_network_graph(graph, fw_info, core_config.debug_config.network_editor)
85
89
 
86
90
  ######################################
87
91
  # Calculate quantization params
88
92
  ######################################
89
93
 
90
- calculate_quantization_params(graph, core_config.quantization_config, fw_impl=fw_impl,
91
- repr_data_gen_fn=representative_data_gen, hessian_info_service=hessian_info_service)
94
+ calculate_quantization_params(graph, fw_impl=fw_impl, repr_data_gen_fn=representative_data_gen,
95
+ hessian_info_service=hessian_info_service)
92
96
 
93
97
  if tb_w is not None:
94
98
  tb_w.add_graph(graph, 'thresholds_selection')
@@ -105,7 +109,8 @@ def quantization_preparation_runner(graph: Graph,
105
109
  ######################################
106
110
  if core_config.quantization_config.shift_negative_activation_correction:
107
111
  transformed_graph = fw_impl.shift_negative_correction(transformed_graph,
108
- core_config)
112
+ core_config,
113
+ fw_info)
109
114
  if tb_w is not None:
110
115
  tb_w.add_graph(transformed_graph, 'after_shift_negative_correction')
111
116
  tb_w.add_all_statistics(transformed_graph, 'after_shift_negative_correction')
@@ -117,9 +122,9 @@ def quantization_preparation_runner(graph: Graph,
117
122
  ######################################
118
123
  # Statistics Correction
119
124
  ######################################
120
- tg_with_bias = statistics_correction_runner(transformed_graph, core_config, fw_impl, tb_w)
125
+ tg_with_bias = statistics_correction_runner(transformed_graph, core_config, fw_info, fw_impl, tb_w)
121
126
 
122
127
  for n in tg_with_bias.nodes:
123
128
  assert n.final_weights_quantization_cfg is None
124
129
 
125
- return tg_with_bias
130
+ return tg_with_bias
@@ -16,6 +16,7 @@
16
16
  import copy
17
17
  from typing import Callable, Any, List, Optional
18
18
 
19
+ from model_compression_toolkit.core.common import FrameworkInfo
19
20
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
20
21
  from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser
21
22
  from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -45,6 +46,7 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
45
46
  def core_runner(in_model: Any,
46
47
  representative_data_gen: Callable,
47
48
  core_config: CoreConfig,
49
+ fw_info: FrameworkInfo,
48
50
  fw_impl: FrameworkImplementation,
49
51
  fqc: FrameworkQuantizationCapabilities,
50
52
  target_resource_utilization: ResourceUtilization = None,
@@ -63,6 +65,7 @@ def core_runner(in_model: Any,
63
65
  in_model: Model to quantize.
64
66
  representative_data_gen: Dataset used for calibration.
65
67
  core_config: CoreConfig containing parameters of how the model should be quantized
68
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
66
69
  groups of layers by how they should be quantized, etc.).
67
70
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
68
71
  fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
@@ -96,6 +99,7 @@ def core_runner(in_model: Any,
96
99
  graph = graph_preparation_runner(in_model,
97
100
  representative_data_gen,
98
101
  core_config.quantization_config,
102
+ fw_info,
99
103
  fw_impl,
100
104
  fqc,
101
105
  core_config.bit_width_config,
@@ -108,6 +112,7 @@ def core_runner(in_model: Any,
108
112
  tg = quantization_preparation_runner(graph=graph,
109
113
  representative_data_gen=representative_data_gen,
110
114
  core_config=core_config,
115
+ fw_info=fw_info,
111
116
  fw_impl=fw_impl,
112
117
  tb_w=tb_w,
113
118
  hessian_info_service=hessian_info_service)
@@ -118,8 +123,9 @@ def core_runner(in_model: Any,
118
123
  if core_config.is_mixed_precision_enabled:
119
124
  if core_config.mixed_precision_config.configuration_overwrite is None:
120
125
 
121
- filter_candidates_for_mixed_precision(graph, target_resource_utilization)
126
+ filter_candidates_for_mixed_precision(graph, target_resource_utilization, fw_info, fqc)
122
127
  bit_widths_config = search_bit_width(tg,
128
+ fw_info,
123
129
  fw_impl,
124
130
  target_resource_utilization,
125
131
  core_config.mixed_precision_config,
@@ -147,20 +153,22 @@ def core_runner(in_model: Any,
147
153
  ######################################
148
154
  if core_config.quantization_config.activation_bias_correction:
149
155
  tg = fw_impl.compute_activation_bias_correction(graph=tg,
150
- quant_config=core_config.quantization_config)
156
+ quant_config=core_config.quantization_config,
157
+ fw_info=fw_info)
151
158
 
152
159
  # Edit the graph again after finalizing the configurations.
153
160
  # This is since some actions regard the final configuration and should be edited.
154
- edit_network_graph(tg, core_config.debug_config.network_editor)
161
+ edit_network_graph(tg, fw_info, core_config.debug_config.network_editor)
155
162
 
156
163
  _set_final_resource_utilization(graph=tg,
157
164
  final_bit_widths_config=bit_widths_config,
158
165
  target_resource_utilization=target_resource_utilization,
166
+ fw_info=fw_info,
159
167
  fw_impl=fw_impl)
160
168
 
161
169
  if core_config.is_mixed_precision_enabled:
162
170
  # Retrieve lists of tuples (node, node's final weights/activation bitwidth)
163
- weights_conf_nodes_bitwidth = tg.get_final_weights_config()
171
+ weights_conf_nodes_bitwidth = tg.get_final_weights_config(fw_info)
164
172
  activation_conf_nodes_bitwidth = tg.get_final_activation_config()
165
173
 
166
174
  if len(weights_conf_nodes_bitwidth) > 0:
@@ -192,6 +200,7 @@ def core_runner(in_model: Any,
192
200
  def _set_final_resource_utilization(graph: Graph,
193
201
  final_bit_widths_config: List[int],
194
202
  target_resource_utilization: Optional[ResourceUtilization],
203
+ fw_info: FrameworkInfo,
195
204
  fw_impl: FrameworkImplementation):
196
205
  """
197
206
  Computing the resource utilization of the model according to the final bit-width configuration,
@@ -201,13 +210,14 @@ def _set_final_resource_utilization(graph: Graph,
201
210
  graph: Graph to compute the resource utilization for.
202
211
  final_bit_widths_config: The final bit-width configuration to quantize the model accordingly.
203
212
  target_resource_utilization: Requested target resource utilization if relevant.
213
+ fw_info: A FrameworkInfo object.
204
214
  fw_impl: FrameworkImplementation object with specific framework methods implementation.
205
215
 
206
216
  """
207
217
  ru_targets = target_resource_utilization.get_restricted_targets() if target_resource_utilization else None
208
218
  final_ru = None
209
219
  if ru_targets:
210
- ru_calculator = ResourceUtilizationCalculator(graph, fw_impl)
220
+ ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
211
221
  w_qcs = {n.name: n.final_weights_quantization_cfg for n in graph.nodes}
212
222
  a_qcs = {n.name: n.final_activation_quantization_cfg for n in graph.nodes}
213
223
  final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused,