mct-nightly 2.4.0.20250924.535__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -20,9 +20,15 @@ from typing import Callable
20
20
  from mct_quantizers import QuantizationMethod
21
21
  from model_compression_toolkit.core.common import Graph
22
22
  from model_compression_toolkit.logger import Logger
23
- from model_compression_toolkit.core.common.graph.base_node import BaseNode
24
23
 
25
24
 
25
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
26
+ from model_compression_toolkit.core.common.graph.base_node import BaseNode
27
+ from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
28
+ get_activation_quantization_params_fn, get_weights_quantization_params_fn
29
+ from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
30
+ get_weights_quantization_fn
31
+
26
32
  _EditRule = namedtuple('EditRule', 'filter action')
27
33
 
28
34
 
@@ -58,13 +64,15 @@ class BaseAction(ABC):
58
64
  """
59
65
 
60
66
  @abstractmethod
61
- def apply(self, node: BaseNode, graph):
67
+ def apply(self, node: BaseNode, graph, fw_info):
62
68
  """
63
69
  Apply an action on the node after matching the node with a node filter.
64
70
 
65
71
  Args:
66
72
  node: Node to apply the action on.
67
73
  graph: Graph to apply the action on.
74
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
75
+ groups of layers by how they should be quantized, etc.)
68
76
 
69
77
  Returns:
70
78
  Node after action is applied.
@@ -87,13 +95,15 @@ class ChangeCandidatesWeightsQuantConfigAttr(BaseAction):
87
95
  self.kwargs = kwargs
88
96
  self.attr_name = attr_name
89
97
 
90
- def apply(self, node: BaseNode, graph):
98
+ def apply(self, node: BaseNode, graph, fw_info):
91
99
  """
92
100
  Change the attribute 'attr_name' in weights quantization config candidates with 'attr_value'.
93
101
 
94
102
  Args:
95
103
  node: Node object to change its quant_config.
96
104
  graph: Graph to apply the action on.
105
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
106
+ groups of layers by how they should be quantized, etc.)
97
107
  Returns:
98
108
  The node after its weights' quantization config candidates have been modified.
99
109
  """
@@ -118,7 +128,7 @@ class ChangeFinalWeightsQuantConfigAttr(BaseAction):
118
128
  self.kwargs = kwargs
119
129
  self.attr_name = attr_name
120
130
 
121
- def apply(self, node: BaseNode, graph):
131
+ def apply(self, node: BaseNode, graph, fw_info):
122
132
  if node.final_weights_quantization_cfg is not None:
123
133
  for parameter_name, parameter_value in self.kwargs.items():
124
134
  node.final_weights_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value,
@@ -137,13 +147,17 @@ class ChangeCandidatesActivationQuantConfigAttr(BaseAction):
137
147
  """
138
148
  self.kwargs = kwargs
139
149
 
140
- def apply(self, node: BaseNode, graph):
150
+ def apply(self, node: BaseNode, graph, fw_info):
141
151
  """
142
152
  Change the attribute 'attr_name' in activation quantization configuration candidates with 'attr_value'.
143
153
 
144
154
  Args:
145
155
  node: Node object to change its quant_config.
146
156
  graph: Graph to apply the action on.
157
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
158
+ groups of layers by how they should be quantized, etc.)
159
+ Returns:
160
+ The node after its activation quantization configuration candidates have been modified.
147
161
  """
148
162
  for nqc in node.candidates_quantization_cfg:
149
163
  for parameter_name, parameter_value in self.kwargs.items():
@@ -162,12 +176,55 @@ class ChangeFinalActivationQuantConfigAttr(BaseAction):
162
176
  """
163
177
  self.kwargs = kwargs
164
178
 
165
- def apply(self, node: BaseNode, graph):
179
+ def apply(self, node: BaseNode, graph, fw_info):
166
180
  if node.final_activation_quantization_cfg is not None:
167
181
  for parameter_name, parameter_value in self.kwargs.items():
168
182
  node.final_activation_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value)
169
183
 
170
184
 
185
+ class ChangeQuantizationParamFunction(BaseAction):
186
+ """
187
+ Class ChangeQuantizationParamFunction to change a node's weights/activations quantization params function.
188
+ """
189
+
190
+ def __init__(self,
191
+ attr_name: str = None,
192
+ activation_quantization_params_fn: Callable = None,
193
+ weights_quantization_params_fn: Callable = None):
194
+ """
195
+ Init a ChangeQuantizationParamFunction object.
196
+
197
+ Args:
198
+ attr_name: The weights attribute's name to set the weights quantization params function for (if setting weights params).
199
+ activation_quantization_params_fn: a params function for a node's activations.
200
+ weights_quantization_params_fn: a params function for a node's weights.
201
+ """
202
+ self.activation_quantization_params_fn = activation_quantization_params_fn
203
+ self.weights_quantization_params_fn = weights_quantization_params_fn
204
+ self.attr_name = attr_name
205
+
206
+ def apply(self, node: BaseNode, graph, fw_info):
207
+ """
208
+ Change the node's weights/activations quantization params function.
209
+
210
+ Args:
211
+ node: Node object to change its quantization params function.
212
+ graph: Graph to apply the action on.
213
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
214
+ groups of layers by how they should be quantized, etc.)
215
+
216
+ Returns:
217
+ The node after its quantization params function has been modified.
218
+ """
219
+ for nqc in node.candidates_quantization_cfg:
220
+ if self.activation_quantization_params_fn is not None:
221
+ nqc.activation_quantization_cfg.set_activation_quantization_params_fn(
222
+ self.activation_quantization_params_fn)
223
+ if self.weights_quantization_params_fn is not None:
224
+ (nqc.weights_quantization_cfg.get_attr_config(self.attr_name)
225
+ .set_weights_quantization_params_fn(self.weights_quantization_params_fn))
226
+
227
+
171
228
  class ChangeFinalActivationQuantizationMethod(BaseAction):
172
229
  """
173
230
  Class ChangeFinalActivationQuantizationMethod to change a node's weights/activations quantizer function.
@@ -183,19 +240,31 @@ class ChangeFinalActivationQuantizationMethod(BaseAction):
183
240
 
184
241
  self.activation_quantization_method = activation_quantization_method
185
242
 
186
- def apply(self, node: BaseNode, graph):
243
+ def apply(self, node: BaseNode, graph, fw_info):
187
244
  """
188
245
  Change the node's activations quantization function.
189
246
 
190
247
  Args:
191
248
  node: Node object to change its threshold selection function.
192
249
  graph: Graph to apply the action on.
250
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
251
+ groups of layers by how they should be quantized, etc.)
193
252
 
194
253
  Returns:
195
254
  The node after its quantization function has been modified.
196
255
  """
197
256
 
198
257
  if self.activation_quantization_method is not None and node.final_activation_quantization_cfg is not None:
258
+
259
+ activation_quantization_params_fn = get_activation_quantization_params_fn(
260
+ self.activation_quantization_method)
261
+
262
+ node.final_activation_quantization_cfg.set_activation_quantization_params_fn(
263
+ activation_quantization_params_fn)
264
+
265
+ activation_quantization_fn = fw_info.activation_quantizer_mapping.get(self.activation_quantization_method)
266
+
267
+ node.final_activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
199
268
  node.final_activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
200
269
 
201
270
 
@@ -213,23 +282,38 @@ class ChangeCandidatesActivationQuantizationMethod(BaseAction):
213
282
  """
214
283
  self.activation_quantization_method = activation_quantization_method
215
284
 
216
- def apply(self, node: BaseNode, graph):
285
+ def apply(self, node: BaseNode, graph, fw_info):
217
286
  """
218
287
  Change the node's activations quantization function.
219
288
 
220
289
  Args:
221
290
  node: Node object to change its threshold selection function.
222
291
  graph: Graph to apply the action on.
292
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
293
+ groups of layers by how they should be quantized, etc.)
223
294
 
295
+ Returns:
296
+ The node after its quantization function has been modified.
224
297
  """
225
298
  if self.activation_quantization_method is not None:
226
299
  for qc in node.candidates_quantization_cfg:
300
+ activation_quantization_params_fn = get_activation_quantization_params_fn(
301
+ self.activation_quantization_method)
302
+
303
+ qc.activation_quantization_cfg.set_activation_quantization_params_fn(activation_quantization_params_fn)
304
+ activation_quantization_fn = fw_info.activation_quantizer_mapping.get(
305
+ self.activation_quantization_method)
306
+
307
+ if activation_quantization_fn is None:
308
+ Logger.critical('Unknown activation quantization method specified.') # pragma: no cover
309
+
310
+ qc.activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
227
311
  qc.activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
228
312
 
229
313
 
230
314
  class ChangeFinalWeightsQuantizationMethod(BaseAction):
231
315
  """
232
- Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer method.
316
+ Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer function.
233
317
  """
234
318
 
235
319
  def __init__(self, attr_name: str, weights_quantization_method=None):
@@ -244,19 +328,36 @@ class ChangeFinalWeightsQuantizationMethod(BaseAction):
244
328
  self.weights_quantization_method = weights_quantization_method
245
329
  self.attr_name = attr_name
246
330
 
247
- def apply(self, node: BaseNode, graph):
331
+ def apply(self, node: BaseNode, graph, fw_info):
248
332
  """
249
333
  Change the node's weights quantization function.
250
334
 
251
335
  Args:
252
336
  node: Node object to change its threshold selection function.
253
337
  graph: Graph to apply the action on.
338
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
339
+ groups of layers by how they should be quantized, etc.)
254
340
 
341
+ Returns:
342
+ The node after its quantization function has been modified.
255
343
  """
256
344
 
257
345
  if self.weights_quantization_method is not None and node.final_weights_quantization_cfg is not None:
258
- attr_config = node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
259
- attr_config.weights_quantization_method = self.weights_quantization_method
346
+
347
+ weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
348
+
349
+ (node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
350
+ .set_weights_quantization_params_fn(weights_quantization_params_fn))
351
+
352
+ weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
353
+
354
+ if weights_quantization_fn is None:
355
+ Logger.critical('Unknown weights quantization method specified.') # pragma: no cover
356
+
357
+ (node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
358
+ .set_weights_quantization_fn(weights_quantization_fn))
359
+ node.final_weights_quantization_cfg.get_attr_config(self.attr_name).weights_quantization_method = \
360
+ self.weights_quantization_method
260
361
 
261
362
 
262
363
  class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
@@ -275,13 +376,15 @@ class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
275
376
  self.weights_quantization_method = weights_quantization_method
276
377
  self.attr_name = attr_name
277
378
 
278
- def apply(self, node: BaseNode, graph: Graph):
379
+ def apply(self, node: BaseNode, graph: Graph, fw_info: FrameworkInfo):
279
380
  """
280
381
  Change the node's weights quantization function.
281
382
 
282
383
  Args:
283
384
  node: Node object to change its threshold selection function.
284
385
  graph: Graph to apply the action on.
386
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
387
+ groups of layers by how they should be quantized, etc.)
285
388
 
286
389
  Returns:
287
390
  The node after its quantization function has been modified.
@@ -289,7 +392,18 @@ class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
289
392
 
290
393
  if self.weights_quantization_method is not None:
291
394
  for qc in node.candidates_quantization_cfg:
395
+
396
+ weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
397
+
292
398
  attr_qc = qc.weights_quantization_cfg.get_attr_config(self.attr_name)
399
+ attr_qc.set_weights_quantization_params_fn(weights_quantization_params_fn)
400
+
401
+ weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
402
+
403
+ if weights_quantization_fn is None:
404
+ Logger.critical('Unknown weights quantization method specified.') # pragma: no cover
405
+
406
+ attr_qc.set_weights_quantization_fn(weights_quantization_fn)
293
407
  attr_qc.weights_quantization_method = self.weights_quantization_method
294
408
 
295
409
 
@@ -308,13 +422,15 @@ class ReplaceLayer(BaseAction):
308
422
  self.layer_type = layer_type
309
423
  self.get_params_and_weights_fn = get_params_and_weights_fn
310
424
 
311
- def apply(self, node: BaseNode, graph: Graph):
425
+ def apply(self, node: BaseNode, graph: Graph, fw_info: FrameworkInfo):
312
426
  """
313
427
  Replacing node's layer type and configurations
314
428
 
315
429
  Args:
316
430
  node: Node object to replace or modify
317
431
  graph: Graph to apply the action on.
432
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
433
+ groups of layers by how they should be quantized, etc.)
318
434
 
319
435
  Returns:
320
436
  The node after its layer functionality has been modified.
@@ -14,17 +14,20 @@
14
14
  # ==============================================================================
15
15
  from typing import List
16
16
 
17
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
17
18
  from model_compression_toolkit.core.common.graph.base_graph import Graph
18
19
  from model_compression_toolkit.core.common.network_editors import EditRule
19
20
 
20
21
 
21
22
  def edit_network_graph(graph: Graph,
23
+ fw_info: FrameworkInfo,
22
24
  network_editor: List[EditRule]):
23
25
  """
24
26
  Apply a list of edit rules on a graph.
25
27
 
26
28
  Args:
27
29
  graph: The graph to edit.
30
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
28
31
  groups of layers by how they should be quantized, etc.)
29
32
  network_editor: List of edit rules to apply to the graph.
30
33
 
@@ -35,5 +38,5 @@ def edit_network_graph(graph: Graph,
35
38
  for edit_rule in network_editor:
36
39
  filtered_nodes = graph.filter(edit_rule.filter)
37
40
  for node in filtered_nodes:
38
- edit_rule.action.apply(node, graph)
41
+ edit_rule.action.apply(node, graph, fw_info)
39
42
  # return graph
@@ -26,14 +26,18 @@ class ChannelGrouping:
26
26
  based on their importance scores and SIMD group sizes.
27
27
  """
28
28
 
29
- def __init__(self, prunable_nodes: List[BaseNode]):
29
+ def __init__(self,
30
+ prunable_nodes: List[BaseNode],
31
+ fw_info: FrameworkInfo):
30
32
  """
31
33
  Initializes the ChannelGrouping with necessary information.
32
34
 
33
35
  Args:
34
36
  prunable_nodes: List of nodes that can be pruned.
37
+ fw_info: Framework-specific information and utilities.
35
38
  """
36
39
  self.prunable_nodes = prunable_nodes
40
+ self.fw_info = fw_info
37
41
  # Store for each node a list of numpy arrays. Each numpy array represents the
38
42
  # indices of the channels in an SIMD group.
39
43
  self._simd_groups_indices = {}
@@ -38,6 +38,7 @@ class GreedyMaskCalculator:
38
38
  """
39
39
  def __init__(self,
40
40
  prunable_nodes: List[BaseNode],
41
+ fw_info: FrameworkInfo,
41
42
  simd_groups_scores: Dict[BaseNode, np.ndarray],
42
43
  target_resource_utilization: ResourceUtilization,
43
44
  graph: Graph,
@@ -47,6 +48,7 @@ class GreedyMaskCalculator:
47
48
  """
48
49
  Args:
49
50
  prunable_nodes (List[BaseNode]): Nodes that are eligible for pruning.
51
+ fw_info (FrameworkInfo): Framework-specific information and utilities.
50
52
  simd_groups_scores (Dict[BaseNode, np.ndarray]): Importance scores for each SIMD group in a prunable node.
51
53
  target_resource_utilization (ResourceUtilization): The target resource utilization to achieve.
52
54
  graph (Graph): The computational graph of the model.
@@ -55,6 +57,7 @@ class GreedyMaskCalculator:
55
57
  simd_groups_indices (Dict[BaseNode, List[List[int]]]): Indices of SIMD groups in each node.
56
58
  """
57
59
  self.prunable_nodes = prunable_nodes
60
+ self.fw_info = fw_info
58
61
  self.target_resource_utilization = target_resource_utilization
59
62
  self.graph = graph
60
63
  self.fw_impl = fw_impl
@@ -64,11 +67,14 @@ class GreedyMaskCalculator:
64
67
  self.simd_groups_scores = simd_groups_scores
65
68
 
66
69
  self.oc_pruning_mask = PerSIMDGroupMask(prunable_nodes=prunable_nodes,
70
+ fw_info=fw_info,
67
71
  simd_groups_indices=simd_groups_indices)
68
72
 
69
73
  self.memory_calculator = MemoryCalculator(graph=graph,
74
+ fw_info=fw_info,
70
75
  fw_impl=fw_impl)
71
76
 
77
+
72
78
  def get_mask(self) -> Dict[BaseNode, np.ndarray]:
73
79
  """
74
80
  Retrieves the current pruning mask for each prunable node.
@@ -38,7 +38,8 @@ class LFHImportanceMetric(BaseImportanceMetric):
38
38
  graph: Graph,
39
39
  representative_data_gen: Callable,
40
40
  fw_impl: PruningFrameworkImplementation,
41
- pruning_config: PruningConfig):
41
+ pruning_config: PruningConfig,
42
+ fw_info: FrameworkInfo):
42
43
  """
43
44
  Initialize the LFHImportanceMetric instance.
44
45
 
@@ -47,11 +48,13 @@ class LFHImportanceMetric(BaseImportanceMetric):
47
48
  representative_data_gen (Callable): Function to generate representative data.
48
49
  fw_impl (PruningFrameworkImplementation): Implementation of pruning for the framework.
49
50
  pruning_config (PruningConfig): Configuration for pruning.
51
+ fw_info (FrameworkInfo): Framework-specific information.
50
52
  """
51
53
  self.float_graph = graph
52
54
  self.representative_data_gen = representative_data_gen
53
55
  self.fw_impl = fw_impl
54
56
  self.pruning_config = pruning_config
57
+ self.fw_info = fw_info
55
58
 
56
59
  # Initialize internal dictionaries for storing intermediate computations.
57
60
  self._entry_node_to_hessian_score = {}
@@ -155,7 +158,8 @@ class LFHImportanceMetric(BaseImportanceMetric):
155
158
  Dict[BaseNode, List[np.ndarray]]: Dictionary of entry nodes mapped to their SIMD group indices.
156
159
  """
157
160
  # Initialize channel grouping utility.
158
- channel_grouping = ChannelGrouping(prunable_nodes=list(entry_node_to_score.keys()))
161
+ channel_grouping = ChannelGrouping(prunable_nodes=list(entry_node_to_score.keys()),
162
+ fw_info=self.fw_info)
159
163
 
160
164
  channel_grouping.group_scores_by_simd_groups(entry_node_to_score)
161
165
  grouped_indices = channel_grouping.simd_groups_indices
@@ -245,14 +249,20 @@ class LFHImportanceMetric(BaseImportanceMetric):
245
249
  Returns:
246
250
  tuple: A tuple containing the kernel attribute, the number of output channels, and the axis of the output channels.
247
251
  """
252
+ kernel_attr = self.fw_info.get_kernel_op_attributes(entry_node.type)
253
+ # Ensure only one kernel attribute exists for the given node.
254
+ if len(kernel_attr) != 1:
255
+ Logger.critical(f"Expected a single attribute but found multiple attributes ({len(kernel_attr)}) for node {entry_node}.")
256
+ kernel_attr = kernel_attr[0]
257
+
248
258
  # Retrieve and validate the axis index for the output channels.
249
- oc_axis = entry_node.channel_axis.output
259
+ oc_axis, _ = self.fw_info.kernel_channels_mapping.get(entry_node.type)
250
260
  if oc_axis is None or int(oc_axis) != oc_axis:
251
261
  Logger.critical(f"Invalid output channel axis type for node {entry_node}: expected integer but got {oc_axis}.")
252
262
 
253
263
  # Get the number of output channels based on the kernel attribute and axis.
254
- num_oc = entry_node.get_weights_by_keys(entry_node.kernel_attr).shape[oc_axis]
255
- return entry_node.kernel_attr, num_oc, oc_axis
264
+ num_oc = entry_node.get_weights_by_keys(kernel_attr[0]).shape[oc_axis]
265
+ return kernel_attr, num_oc, oc_axis
256
266
 
257
267
  def _concatenate_tensors_by_indices(self,
258
268
  channels: List[np.ndarray],
@@ -35,8 +35,9 @@ class MaskIndicator(Enum):
35
35
  REMAINED = 1
36
36
 
37
37
 
38
+
38
39
  class PerChannelMask:
39
- def __init__(self, prunable_nodes: List[BaseNode]):
40
+ def __init__(self, prunable_nodes: List[BaseNode], fw_info: FrameworkInfo):
40
41
  """
41
42
  Initializes the PerChannelMask with prunable nodes and framework information.
42
43
  This class is responsible for maintaining and updating the pruning masks for each
@@ -45,8 +46,10 @@ class PerChannelMask:
45
46
 
46
47
  Args:
47
48
  prunable_nodes: List of nodes in the model that are subject to pruning.
49
+ fw_info: Framework-specific information required for pruning operations.
48
50
  """
49
51
  self.prunable_nodes = prunable_nodes
52
+ self.fw_info = fw_info
50
53
  self._mask = None # Initialize the mask dictionary
51
54
  self._init_masks() # Call to initialize masks for each prunable node
52
55
 
@@ -103,7 +106,8 @@ class PerChannelMask:
103
106
  Returns:
104
107
  int: Number of output channels for the node.
105
108
  """
106
- oc_axis = node.channel_axis.output
107
- num_oc = node.get_weights_by_keys(node.kernel_attr).shape[oc_axis]
109
+ kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)[0]
110
+ oc_axis = self.fw_info.kernel_channels_mapping.get(node.type)[0]
111
+ num_oc = node.get_weights_by_keys(kernel_attr).shape[oc_axis]
108
112
  return num_oc
109
113
 
@@ -24,10 +24,10 @@ from model_compression_toolkit.core.common.pruning.memory_calculator import Memo
24
24
  from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation
25
25
  from model_compression_toolkit.logger import Logger
26
26
 
27
-
28
27
  class PerSIMDGroupMask:
29
28
  def __init__(self,
30
29
  prunable_nodes: List[BaseNode],
30
+ fw_info: FrameworkInfo,
31
31
  simd_groups_indices: Dict[BaseNode, List[List[int]]]):
32
32
  """
33
33
  Initializes a mask calculator for SIMD groups in prunable nodes.
@@ -35,11 +35,13 @@ class PerSIMDGroupMask:
35
35
 
36
36
  Args:
37
37
  prunable_nodes: List of nodes that can be pruned.
38
+ fw_info: Framework-specific information.
38
39
  simd_groups_indices: A dictionary mapping each node to its SIMD groups' indices.
39
40
  """
40
41
  # Initialize the per-channel mask
41
- self.per_channel_mask = PerChannelMask(prunable_nodes=prunable_nodes)
42
+ self.per_channel_mask = PerChannelMask(prunable_nodes=prunable_nodes, fw_info=fw_info)
42
43
  self.prunable_nodes = prunable_nodes
44
+ self.fw_info = fw_info
43
45
  self.simd_groups_indices = simd_groups_indices
44
46
  self._mask_simd = None # Initialize the SIMD group mask dictionary
45
47
  self._init_masks() # Initialize masks for each prunable node
@@ -34,16 +34,18 @@ class MemoryCalculator:
34
34
  which is crucial for deploying models on memory-constrained devices or optimizing for computational efficiency.
35
35
  """
36
36
 
37
- def __init__(self, graph: Graph, fw_impl: PruningFrameworkImplementation):
37
+ def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: PruningFrameworkImplementation):
38
38
  """
39
39
  Initializes the MemoryCalculator with necessary information about the model's graph,
40
40
  framework-specific details, and pruning implementation.
41
41
 
42
42
  Args:
43
43
  graph (Graph): Computational graph of the model.
44
+ fw_info (FrameworkInfo): Contains framework-specific information.
44
45
  fw_impl (PruningFrameworkImplementation): Implementation details for pruning.
45
46
  """
46
47
  self.graph = graph
48
+ self.fw_info = fw_info
47
49
  self.fw_impl = fw_impl
48
50
 
49
51
  def get_pruned_graph_memory(self,
@@ -202,13 +204,19 @@ class MemoryCalculator:
202
204
  if node == section.exit_node:
203
205
  return masks.get(section.entry_node)
204
206
 
207
+ kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)
208
+ # Ensure only one kernel attribute exists for the given node.
209
+ if len(kernel_attr) != 1:
210
+ Logger.critical(f"Expected a single attribute, but found {len(kernel_attr)} attributes for node '{node}'. Ensure the node configuration is correct.")
211
+ kernel_attr = kernel_attr[0]
212
+
205
213
  # Retrieve and validate the axis index for the input channels.
206
- ic_axis = node.channel_axis.input
214
+ _, ic_axis = self.fw_info.kernel_channels_mapping.get(node.type)
207
215
  if ic_axis is None or int(ic_axis) != ic_axis:
208
216
  Logger.critical(f"Invalid input channel axis type for node '{node}': expected integer but got '{ic_axis}'.")
209
217
 
210
218
  # Get the number of input channels based on the kernel attribute and axis.
211
- num_ic = node.get_weights_by_keys(node.kernel_attr).shape[ic_axis]
219
+ num_ic = node.get_weights_by_keys(kernel_attr).shape[ic_axis]
212
220
  mask = np.ones(num_ic, dtype=bool)
213
221
  return mask
214
222
 
@@ -281,7 +289,7 @@ class MemoryCalculator:
281
289
  int: The total number of parameters in the node after pruning.
282
290
  """
283
291
  total_params = 0
284
- attributes_and_oc_axis = self.fw_impl.attrs_oi_channels_info_for_pruning(node)
292
+ attributes_and_oc_axis = self.fw_impl.attrs_oi_channels_info_for_pruning(node, self.fw_info)
285
293
 
286
294
  # Iterate over the node's weights and apply pruning based on the masks.
287
295
  for w_attr, w in node.weights.items():
@@ -303,7 +311,7 @@ class MemoryCalculator:
303
311
  num_oc = np.sum(output_mask)
304
312
  else:
305
313
  # Get the node channel axis from framework info
306
- channel_axis = self.fw_impl.default_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis
314
+ channel_axis = self.fw_info.out_channel_axis_mapping.get(node.type)
307
315
  if channel_axis is None:
308
316
  Logger.critical(f"The channel axis is undefined. Please ensure the channel axis is explicitly defined for node {node.type} in the framework info.")
309
317
 
@@ -27,6 +27,7 @@ from model_compression_toolkit.logger import Logger
27
27
 
28
28
  def build_pruned_graph(graph: Graph,
29
29
  masks: Dict[BaseNode, np.ndarray],
30
+ fw_info: FrameworkInfo,
30
31
  fw_impl: FrameworkImplementation) -> Graph:
31
32
  """
32
33
  Prunes the provided graph according to the given pruning output-channels masks.
@@ -34,6 +35,7 @@ def build_pruned_graph(graph: Graph,
34
35
  Args:
35
36
  graph: The original computational graph to be pruned.
36
37
  masks: A dictionary mapping each prunable node to its pruning mask.
38
+ fw_info: Framework-specific information object.
37
39
  fw_impl: Framework-specific implementation object.
38
40
 
39
41
  Returns:
@@ -64,7 +66,8 @@ def build_pruned_graph(graph: Graph,
64
66
  section_mask = PruningSectionMask(entry_node_oc_mask=mask,
65
67
  exit_node_ic_mask=mask)
66
68
  pruning_section.apply_inner_section_mask(section_mask,
67
- fw_impl)
69
+ fw_impl,
70
+ fw_info)
68
71
 
69
72
  return graph_to_prune
70
73
 
@@ -40,6 +40,7 @@ class Pruner:
40
40
  """
41
41
  def __init__(self,
42
42
  float_graph: Graph,
43
+ fw_info: FrameworkInfo,
43
44
  fw_impl: PruningFrameworkImplementation,
44
45
  target_resource_utilization: ResourceUtilization,
45
46
  representative_data_gen: Callable,
@@ -48,6 +49,7 @@ class Pruner:
48
49
  """
49
50
  Args:
50
51
  float_graph (Graph): The floating-point representation of the model's computation graph.
52
+ fw_info (FrameworkInfo): Contains metadata and helper functions for the framework.
51
53
  fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.
52
54
  target_resource_utilization (ResourceUtilization): The target resource utilization to be achieved after pruning.
53
55
  representative_data_gen (Callable): Generator function for representative dataset used in pruning analysis.
@@ -55,6 +57,7 @@ class Pruner:
55
57
  target_platform_capabilities (FrameworkQuantizationCapabilities): Object encapsulating the capabilities of the target hardware platform.
56
58
  """
57
59
  self.float_graph = float_graph
60
+ self.fw_info = fw_info
58
61
  self.fw_impl = fw_impl
59
62
  self.target_resource_utilization = target_resource_utilization
60
63
  self.representative_data_gen = representative_data_gen
@@ -81,6 +84,7 @@ class Pruner:
81
84
  # Apply Greedy strategy to compute masks based on importance scores.
82
85
  if self.pruning_config.channels_filtering_strategy == ChannelsFilteringStrategy.GREEDY:
83
86
  mask_calculator = GreedyMaskCalculator(entry_nodes,
87
+ self.fw_info,
84
88
  self.simd_scores,
85
89
  self.target_resource_utilization,
86
90
  self.float_graph,
@@ -95,6 +99,7 @@ class Pruner:
95
99
  Logger.info("Start pruning graph...")
96
100
  _pruned_graph = build_pruned_graph(self.float_graph,
97
101
  self.per_oc_mask,
102
+ self.fw_info,
98
103
  self.fw_impl)
99
104
  return _pruned_graph
100
105
 
@@ -111,7 +116,7 @@ class Pruner:
111
116
  # Retrieve and initialize the importance metric.
112
117
  im = get_importance_metric(self.pruning_config.importance_metric, graph=self.float_graph,
113
118
  representative_data_gen=self.representative_data_gen, fw_impl=self.fw_impl,
114
- pruning_config=self.pruning_config)
119
+ pruning_config=self.pruning_config, fw_info=self.fw_info)
115
120
  entry_node_to_simd_score, simd_groups_indices = im.get_entry_node_to_simd_score(entry_nodes)
116
121
  return entry_node_to_simd_score, simd_groups_indices
117
122