mct-nightly 2.4.0.20250924.535__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -34,7 +34,6 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
34
34
  NodeFrameworkAttrMatcher
35
35
  from model_compression_toolkit.core.common.substitutions.shift_negative_activation import \
36
36
  apply_shift_negative_correction
37
- from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
38
37
  from model_compression_toolkit.core.keras.constants import KERNEL_SIZE, STRIDES, ACTIVATION, SWISH, \
39
38
  SELU, GELU, FUNCTION, ADD, PAD
40
39
  from model_compression_toolkit.core.keras.constants import NEGATIVE_SLOPE, PADDING, PAD_SAME, PAD_VALID, BIAS, USE_BIAS
@@ -228,13 +227,15 @@ def is_padding_node_and_node_has_padding(pad_node_to_consider: BaseNode,
228
227
 
229
228
 
230
229
  def keras_apply_shift_negative_correction(graph: Graph,
231
- core_config: CoreConfig) -> Graph:
230
+ core_config: CoreConfig,
231
+ fw_info: FrameworkInfo) -> Graph:
232
232
  """
233
233
  Apply shift negative correction (SNC) on a graph built from a Keras model.
234
234
 
235
235
  Args:
236
236
  graph: Graph to apply SNC on.
237
237
  core_config: Quantization configuration.
238
+ fw_info: FrameworkInfo object with information about the specific framework's module.
238
239
 
239
240
  Returns:
240
241
  Graph after SNC.
@@ -243,6 +244,7 @@ def keras_apply_shift_negative_correction(graph: Graph,
243
244
 
244
245
  return apply_shift_negative_correction(graph,
245
246
  core_config,
247
+ fw_info,
246
248
  snc_node,
247
249
  linear_node,
248
250
  bypass_node,
@@ -253,6 +255,5 @@ def keras_apply_shift_negative_correction(graph: Graph,
253
255
  is_padding_node_and_node_has_padding,
254
256
  PADDING,
255
257
  BIAS,
256
- USE_BIAS,
257
- get_activation_quantization_fn_factory
258
+ USE_BIAS
258
259
  )
@@ -22,6 +22,7 @@ from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESS
22
22
  from model_compression_toolkit.core.common import Graph
23
23
  from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
24
24
  from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
25
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
25
26
  from model_compression_toolkit.core.keras.hessian.hessian_scores_calculator_keras import HessianScoresCalculatorKeras
26
27
  from model_compression_toolkit.logger import Logger
27
28
 
@@ -94,11 +95,20 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
94
95
  for i, ipt_node in enumerate(self.hessian_request.target_nodes): # Per Interest point weights tensor
95
96
 
96
97
  # Check if the target node's layer type is supported.
97
- if not ipt_node.kernel_attr:
98
+ if not DEFAULT_KERAS_INFO.is_kernel_op(ipt_node.type):
98
99
  Logger.critical(f"Hessian information with respect to weights is not supported for "
99
100
  f"{ipt_node.type} layers.") # pragma: no cover
100
101
 
101
- weight_tensor = getattr(model.get_layer(ipt_node.name), ipt_node.kernel_attr)
102
+ # Get the weight attributes for the target node type
103
+ weight_attributes = DEFAULT_KERAS_INFO.get_kernel_op_attributes(ipt_node.type)
104
+
105
+ # Get the weight tensor for the target node
106
+ if len(weight_attributes) != 1: # pragma: no cover
107
+ Logger.critical(
108
+ f"Hessian-based scoring with respect to weights is currently supported only for nodes with "
109
+ f"a single weight attribute. Found {len(weight_attributes)} attributes.")
110
+
111
+ weight_tensor = getattr(model.get_layer(ipt_node.name), weight_attributes[0])
102
112
 
103
113
  if j == 0:
104
114
  # On the first iteration we store the weight_tensor shape for later reshaping the results
@@ -106,7 +116,7 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
106
116
  tensors_original_shape.append(weight_tensor.shape)
107
117
 
108
118
  # Get the output channel index (needed for HessianInfoGranularity.PER_OUTPUT_CHANNEL case)
109
- output_channel_axis = ipt_node.channel_axis.output
119
+ output_channel_axis, _ = DEFAULT_KERAS_INFO.kernel_channels_mapping.get(ipt_node.type)
110
120
 
111
121
  # Get number of scores that should be calculated by the granularity.
112
122
  num_of_scores = self._get_num_scores_by_granularity(weight_tensor,
@@ -65,6 +65,7 @@ from model_compression_toolkit.core.common import Graph, BaseNode
65
65
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
66
66
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
67
67
  from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
68
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
68
69
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.activation_decomposition import \
69
70
  ActivationDecomposition
70
71
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.matmul_substitution import \
@@ -174,16 +175,18 @@ class KerasImplementation(FrameworkImplementation):
174
175
  graph: Graph,
175
176
  mode: ModelBuilderMode,
176
177
  append2output: List[Any] = None,
178
+ fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
177
179
  return_float_outputs: bool = False) -> Tuple:
178
180
  """
179
181
  Build a Keras model from a graph.
180
- The mode determines how the model should be built. append2output is a list of Nodes
182
+ The mode determines how the model should be build. append2output is a list of Nodes
181
183
  to set as the model outputs.
182
184
 
183
185
  Args:
184
186
  graph: Graph to build the model from it.
185
187
  mode: Mode for how to build the model.
186
188
  append2output: List of Nodes to set as the model's outputs.
189
+ fw_info: FrameworkInfo object with information about the specific framework's model
187
190
  return_float_outputs (bool): whether to return outputs before or after quantization nodes (default)
188
191
  Returns:
189
192
  A tuple with the model and additional relevant supporting objects.
@@ -192,6 +195,7 @@ class KerasImplementation(FrameworkImplementation):
192
195
  keras_model_builder = get_keras_model_builder(mode)
193
196
  return keras_model_builder(graph=graph,
194
197
  append2output=append2output,
198
+ fw_info=fw_info,
195
199
  return_float_outputs=return_float_outputs).build_model()
196
200
 
197
201
  def run_model_inference(self,
@@ -223,57 +227,65 @@ class KerasImplementation(FrameworkImplementation):
223
227
 
224
228
  def shift_negative_correction(self,
225
229
  graph: Graph,
226
- core_config: CoreConfig) -> Graph:
230
+ core_config: CoreConfig,
231
+ fw_info: FrameworkInfo) -> Graph:
227
232
  """
228
233
  Apply shift negative correction (SNC) on a graph.
229
234
 
230
235
  Args:
231
236
  graph: Graph to apply SNC on.
232
237
  core_config: Quantization configuration.
238
+ fw_info: FrameworkInfo object with information about the specific framework's model.
233
239
 
234
240
  Returns:
235
241
  Graph after SNC.
236
242
  """
237
243
  return keras_apply_shift_negative_correction(graph,
238
- core_config)
244
+ core_config,
245
+ fw_info)
239
246
 
240
247
  def compute_activation_bias_correction(self,
241
248
  graph: Graph,
242
- quant_config: QuantizationConfig):
249
+ quant_config: QuantizationConfig,
250
+ fw_info: FrameworkInfo):
243
251
  """
244
252
  Compute activation bias correction on a graph.
245
253
 
246
254
  Args:
247
255
  graph: Graph to apply activation bias correction on.
248
256
  quant_config: QuantizationConfig of how the model should be quantized.
257
+ fw_info: FrameworkInfo object with information about the specific framework's model.
249
258
 
250
259
  Returns:
251
260
  Graph after activation bias correction computing.
252
261
  """
253
262
  return keras_compute_activation_bias_correction_of_graph(graph=graph,
254
263
  quant_config=quant_config,
264
+ fw_info=fw_info,
255
265
  fw_impl=self)
256
266
 
257
267
  def get_substitutions_channel_equalization(self,
258
- quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
268
+ quant_config: QuantizationConfig,
269
+ fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
259
270
  """
260
271
  Return a list of the framework substitutions used for channel equalization.
261
272
 
262
273
  Args:
263
274
  quant_config: QuantizationConfig to determine which substitutions to return.
275
+ fw_info: FrameworkInfo object with information about the specific framework's model.
264
276
 
265
277
  Returns:
266
278
  A list of the framework substitutions used after we collect statistics.
267
279
  """
268
280
  substitutions_list = []
269
281
  if quant_config.activation_channel_equalization:
270
- substitutions_list.extend([ScaleEqualization(quant_config),
271
- ScaleEqualizationWithPad(quant_config),
272
- ScaleEqualizationMidActivation(quant_config),
273
- ScaleEqualizationMidActivationWithPad(quant_config)])
282
+ substitutions_list.extend([ScaleEqualization(quant_config, fw_info),
283
+ ScaleEqualizationWithPad(quant_config, fw_info),
284
+ ScaleEqualizationMidActivation(quant_config, fw_info),
285
+ ScaleEqualizationMidActivationWithPad(quant_config, fw_info)])
274
286
  return substitutions_list
275
287
 
276
- def get_substitutions_prepare_graph(self) -> List[common.BaseSubstitution]:
288
+ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List[common.BaseSubstitution]:
277
289
  """
278
290
 
279
291
  Returns: A list of the framework substitutions used to prepare the graph.
@@ -357,8 +369,8 @@ class KerasImplementation(FrameworkImplementation):
357
369
  if quant_config.softmax_shift:
358
370
  substitutions_list.append(keras_softmax_shift())
359
371
  if quant_config.input_scaling:
360
- substitutions_list.append(InputScaling(quant_config))
361
- substitutions_list.append(InputScalingWithPad(quant_config))
372
+ substitutions_list.append(InputScaling())
373
+ substitutions_list.append(InputScalingWithPad())
362
374
  if quant_config.concat_threshold_update:
363
375
  substitutions_list.append(ConcatThresholdUpdate())
364
376
  return substitutions_list
@@ -390,19 +402,22 @@ class KerasImplementation(FrameworkImplementation):
390
402
 
391
403
  def get_node_prior_info(self,
392
404
  node: BaseNode,
405
+ fw_info: FrameworkInfo,
393
406
  graph: Graph) -> NodePriorInfo:
394
407
  """
395
408
  Get a NodePriorInfo object for a node that represents a Keras layer.
396
409
 
397
410
  Args:
398
411
  node: Node to get its prior info.
412
+ fw_info: Framework specific information needed to create the prior info of the node.
399
413
  graph: Graph to check the next node type.
400
414
 
401
415
  Returns:
402
416
  NodePriorInfo with information about the node.
403
417
  """
404
418
 
405
- return create_node_prior_info(node=node, graph=graph)
419
+ return create_node_prior_info(node=node,
420
+ fw_info=fw_info, graph=graph)
406
421
 
407
422
  def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
408
423
  """
@@ -515,19 +530,23 @@ class KerasImplementation(FrameworkImplementation):
515
530
  return True
516
531
 
517
532
  def get_node_mac_operations(self,
518
- node: BaseNode) -> float:
533
+ node: BaseNode,
534
+ fw_info: FrameworkInfo) -> float:
519
535
  """
520
536
  Gets the MAC operation count for a given operation.
521
537
 
522
538
  Args:
523
539
  node: A graph node that wraps the operation for which the MAC count is computed.
540
+ fw_info: FrameworkInfo object with information about the Keras model.
524
541
 
525
542
  Returns: The MAC count og the operation
526
543
  """
527
- if node.kernel_attr is None:
544
+ kernels = fw_info.get_kernel_op_attributes(node.type)
545
+ if not kernels or kernels[0] is None:
528
546
  return 0
529
547
 
530
- kernel_shape = node.get_weights_by_keys(node.kernel_attr).shape
548
+ assert len(kernels) == 1
549
+ kernel_shape = node.get_weights_by_keys(kernels[0]).shape
531
550
 
532
551
  if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose) or node.is_match_type(DepthwiseConv2D):
533
552
  h, w = node.get_output_shapes_list()[0][-3:-1]
@@ -535,7 +554,8 @@ class KerasImplementation(FrameworkImplementation):
535
554
 
536
555
  if node.is_match_type(Dense):
537
556
  # IN * OUT * (all previous dims[:-1])
538
- return node.get_total_output_params() * kernel_shape[node.channel_axis.input]
557
+ _, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
558
+ return node.get_total_output_params() * kernel_shape[input_channel_axis]
539
559
 
540
560
  return 0
541
561
 
@@ -0,0 +1,38 @@
1
+ from tensorflow.keras.models import Model
2
+
3
+ from model_compression_toolkit.core import FrameworkInfo
4
+ from model_compression_toolkit.core.common.framework_info import ChannelAxis
5
+ from model_compression_toolkit.core.common.model_validation import ModelValidation
6
+ from model_compression_toolkit.core.keras.constants import CHANNELS_FORMAT, CHANNELS_FORMAT_LAST, CHANNELS_FORMAT_FIRST
7
+
8
+
9
+ class KerasModelValidation(ModelValidation):
10
+ """
11
+ Class to define validation methods in order to validate the received Keras model to quantize.
12
+ """
13
+
14
+ def __init__(self, model: Model, fw_info: FrameworkInfo):
15
+ """
16
+ Initialize a KerasModelValidation object.
17
+
18
+ Args:
19
+ model: Keras model to check its validity.
20
+ fw_info: Information about the framework of the model (Keras).
21
+ """
22
+
23
+ super(KerasModelValidation, self).__init__(model=model,
24
+ fw_info=fw_info)
25
+
26
+ def validate_output_channel_consistency(self):
27
+ """
28
+
29
+ Validate that output channels index in all layers of the model are the same.
30
+ If the model has layers with different output channels index, an exception is thrown.
31
+
32
+ """
33
+ for layer in self.model.layers:
34
+ data_format = layer.get_config().get(CHANNELS_FORMAT)
35
+ if data_format is not None:
36
+ assert (data_format == CHANNELS_FORMAT_LAST and self.fw_info.out_channel_axis_mapping.get(layer) == ChannelAxis.NHWC.value
37
+ or data_format == CHANNELS_FORMAT_FIRST and self.fw_info.out_channel_axis_mapping.get(layer) == ChannelAxis.NCHW.value), \
38
+ f'Model can not have layers with different data formats.'
@@ -17,19 +17,22 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph
17
17
 
18
18
 
19
19
  def create_node_prior_info(node: BaseNode,
20
+ fw_info: FrameworkInfo,
20
21
  graph: Graph):
21
22
  """
22
23
  Create a NodePriorInfo object for a given node.
23
24
 
24
25
  Args:
25
26
  node: Node to create its prior info.
27
+ fw_info: Information about a specific framework the node was generated from.
26
28
  graph: Graph to check the next node type.
27
29
 
28
30
  Returns:
29
31
  NodePriorInfo object with info about the node.
30
32
  """
31
33
 
32
- min_output, max_output = _get_min_max_outputs(node=node)
34
+ min_output, max_output = _get_min_max_outputs(node=node,
35
+ fw_info=fw_info)
33
36
 
34
37
  mean_output, std_output = _get_mean_std_outputs(node=node,
35
38
  graph=graph)
@@ -39,12 +42,14 @@ def create_node_prior_info(node: BaseNode,
39
42
  std_output=std_output)
40
43
 
41
44
 
42
- def _get_min_max_outputs(node: BaseNode) -> Tuple[Any, Any]:
45
+ def _get_min_max_outputs(node: BaseNode,
46
+ fw_info: FrameworkInfo) -> Tuple[Any, Any]:
43
47
  """
44
48
  Return the min/max output values of a node if known.
45
49
  If one of them (or both of them) is unknown - return None instead of a value.
46
50
  Args:
47
51
  node: Node to create its prior info.
52
+ fw_info: Information about a specific framework the node was generated from.
48
53
 
49
54
  Returns:
50
55
  Min/max output values if known.
@@ -53,8 +58,12 @@ def _get_min_max_outputs(node: BaseNode) -> Tuple[Any, Any]:
53
58
 
54
59
  if node.is_match_type(ReLU):
55
60
  min_output = node.framework_attr[THRESHOLD] if node.framework_attr[NEGATIVE_SLOPE] == 0 else None
56
- else:
57
- min_output, max_output = node.minmax
61
+
62
+ elif fw_info.layers_has_min_max(node.type):
63
+ min_output, max_output = fw_info.layer_min_max_mapping[node.type]
64
+
65
+ elif node.is_match_type(Activation) and fw_info.activation_has_min_max(node.framework_attr[ACTIVATION]):
66
+ min_output, max_output = fw_info.activation_min_max_mapping[node.framework_attr[ACTIVATION]]
58
67
 
59
68
  return min_output, max_output
60
69
 
@@ -23,7 +23,6 @@ from model_compression_toolkit.core.common.mixed_precision.configurable_quantize
23
23
  verify_candidates_descending_order, init_activation_quantizers
24
24
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
25
25
  CandidateNodeQuantizationConfig
26
- from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
27
26
  from model_compression_toolkit.logger import Logger
28
27
 
29
28
  import tensorflow as tf
@@ -68,7 +67,7 @@ class ConfigurableActivationQuantizer(BaseKerasInferableQuantizer):
68
67
  if qc.activation_quantization_cfg.quant_mode != node_q_cfg[0].activation_quantization_cfg.quant_mode:
69
68
  Logger.critical("Unsupported configuration: Mixing candidates with differing activation quantization states (enabled/disabled).") # pragma: no cover
70
69
 
71
- self.activation_quantizers = init_activation_quantizers(self.node_q_cfg, get_activation_quantization_fn_factory)
70
+ self.activation_quantizers = init_activation_quantizers(self.node_q_cfg)
72
71
  self.active_quantization_config_index = max_candidate_idx # initialize with first config as default
73
72
 
74
73
  def set_active_activation_quantizer(self, index: Optional[int]):
@@ -19,6 +19,7 @@ from model_compression_toolkit.core.common.pruning.pruning_framework_implementat
19
19
  PruningFrameworkImplementation
20
20
  from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
21
21
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
22
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
23
  from model_compression_toolkit.core.common import BaseNode
23
24
  from model_compression_toolkit.core.keras.constants import BIAS, GROUPS, FILTERS, UNITS, USE_BIAS
24
25
  import keras
@@ -28,10 +29,6 @@ import numpy as np
28
29
  from model_compression_toolkit.logger import Logger
29
30
 
30
31
 
31
- # default output channel axis to use when it's not defined in node's fw_info.
32
- _default_output_channel_axis = -1
33
-
34
-
35
32
  class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementation):
36
33
  """
37
34
  Implementation of the PruningFramework for the Keras framework. This class provides
@@ -41,23 +38,27 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
41
38
 
42
39
  def prune_entry_node(self,
43
40
  node: BaseNode,
44
- output_mask: np.ndarray):
41
+ output_mask: np.ndarray,
42
+ fw_info: FrameworkInfo):
45
43
  """
46
44
  Prunes the entry node of a model in Keras.
47
45
 
48
46
  Args:
49
47
  node (BaseNode): The entry node to be pruned.
50
48
  output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
49
+ fw_info (FrameworkInfo): Framework-specific information object.
51
50
 
52
51
  """
53
52
  return _prune_keras_edge_node(node=node,
54
53
  mask=output_mask,
54
+ fw_info=fw_info,
55
55
  is_exit_node=False)
56
56
 
57
57
  def prune_intermediate_node(self,
58
58
  node: BaseNode,
59
59
  input_mask: np.ndarray,
60
- output_mask: np.ndarray):
60
+ output_mask: np.ndarray,
61
+ fw_info: FrameworkInfo):
61
62
  """
62
63
  Prunes an intermediate node in a Keras model.
63
64
 
@@ -65,6 +66,7 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
65
66
  node (BaseNode): The intermediate node to be pruned.
66
67
  input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
67
68
  output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
69
+ fw_info (FrameworkInfo): Framework-specific information object.
68
70
 
69
71
  """
70
72
  _edit_node_input_shape(input_mask, node)
@@ -77,17 +79,20 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
77
79
 
78
80
  def prune_exit_node(self,
79
81
  node: BaseNode,
80
- input_mask: np.ndarray):
82
+ input_mask: np.ndarray,
83
+ fw_info: FrameworkInfo):
81
84
  """
82
85
  Prunes the exit node of a model in Keras.
83
86
 
84
87
  Args:
85
88
  node (BaseNode): The exit node to be pruned.
86
89
  input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
90
+ fw_info (FrameworkInfo): Framework-specific information object.
87
91
 
88
92
  """
89
93
  return _prune_keras_edge_node(node=node,
90
94
  mask=input_mask,
95
+ fw_info=fw_info,
91
96
  is_exit_node=True)
92
97
 
93
98
  def is_node_entry_node(self, node: BaseNode) -> bool:
@@ -104,19 +109,22 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
104
109
 
105
110
  def is_node_exit_node(self,
106
111
  node: BaseNode,
107
- corresponding_entry_node: BaseNode) -> bool:
112
+ corresponding_entry_node: BaseNode,
113
+ fw_info: FrameworkInfo) -> bool:
108
114
  """
109
115
  Determines whether a node is an exit node in a Keras model.
110
116
 
111
117
  Args:
112
118
  node (BaseNode): The node to be checked.
113
119
  corresponding_entry_node (BaseNode): The entry node of the pruning section that is checked.
120
+ fw_info (FrameworkInfo): Framework-specific information object.
114
121
 
115
122
  Returns:
116
123
  bool: Boolean indicating if the node is an exit node.
117
124
  """
118
125
  return _is_keras_node_pruning_section_edge(node) and PruningSection.has_matching_channel_count(node,
119
- corresponding_entry_node)
126
+ corresponding_entry_node,
127
+ fw_info)
120
128
 
121
129
  def is_node_intermediate_pruning_section(self, node: BaseNode) -> bool:
122
130
  """
@@ -135,7 +143,8 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
135
143
  keras.layers.Dense]
136
144
 
137
145
  def attrs_oi_channels_info_for_pruning(self,
138
- node: BaseNode) -> Dict[str, Tuple[int, int]]:
146
+ node: BaseNode,
147
+ fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
139
148
  """
140
149
  Retrieves the attributes of a given node along with the output/input (OI) channel axis
141
150
  for each attribute used to prune these attributes.
@@ -152,6 +161,7 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
152
161
 
153
162
  Args:
154
163
  node (BaseNode): The node from the computational graph.
164
+ fw_info (FrameworkInfo): Contains framework-specific information and utilities.
155
165
 
156
166
  Returns:
157
167
  Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'kernel' or 'bias')
@@ -159,8 +169,13 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
159
169
  """
160
170
 
161
171
  attributes_with_axis = {}
162
- if node.kernel_attr:
163
- attributes_with_axis[node.kernel_attr] = (node.channel_axis.output, node.channel_axis.input)
172
+ if fw_info.is_kernel_op(node.type):
173
+ kernel_attributes = fw_info.get_kernel_op_attributes(node.type)
174
+ if kernel_attributes is None or len(kernel_attributes)==0:
175
+ Logger.critical(f"Expected kernel attributes for operation for node type {node.type}, found None or empty.")
176
+
177
+ for attr in kernel_attributes:
178
+ attributes_with_axis[attr] = fw_info.kernel_channels_mapping.get(node.type)
164
179
 
165
180
  # Bias is a vector at the length of the number of output channels.
166
181
  # For this reason, input channel axis is irrelevant to the bias attribute.
@@ -176,10 +191,6 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
176
191
 
177
192
  return attributes_with_axis
178
193
 
179
- @property
180
- def default_output_channel_axis(self):
181
- return _default_output_channel_axis
182
-
183
194
 
184
195
  def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:
185
196
  """
@@ -205,6 +216,7 @@ def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:
205
216
 
206
217
  def _prune_keras_edge_node(node: BaseNode,
207
218
  mask: np.ndarray,
219
+ fw_info: FrameworkInfo,
208
220
  is_exit_node: bool):
209
221
  """
210
222
  Prunes the given Keras node by applying the mask to the node's weights (kernels and biases).
@@ -213,18 +225,21 @@ def _prune_keras_edge_node(node: BaseNode,
213
225
  Args:
214
226
  node: The node to be pruned.
215
227
  mask: The pruning mask to be applied.
228
+ fw_info: Framework-specific information object.
216
229
  is_exit_node: A boolean indicating whether the node is an exit node.
217
230
 
218
231
  """
219
232
 
220
233
  # Retrieve the kernel attribute and the axes to prune.
221
- axis_to_prune = node.channel_axis.input if is_exit_node else node.channel_axis.output
222
- kernel = node.get_weights_by_keys(node.kernel_attr)
234
+ kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
235
+ io_axis = fw_info.kernel_channels_mapping.get(node.type)
236
+ axis_to_prune = io_axis[int(is_exit_node)]
237
+ kernel = node.get_weights_by_keys(kernel_attr)
223
238
  # Convert mask to boolean.
224
239
  mask_bool = mask.astype(bool)
225
240
 
226
241
  pruned_kernel = kernel.compress(mask_bool, axis=axis_to_prune)
227
- node.set_weights_by_keys(name=node.kernel_attr, tensor=pruned_kernel)
242
+ node.set_weights_by_keys(name=kernel_attr, tensor=pruned_kernel)
228
243
 
229
244
  if not is_exit_node and node.framework_attr[USE_BIAS]:
230
245
  # Prune the bias if applicable and it's an entry node.
@@ -27,15 +27,14 @@ if FOUND_TF:
27
27
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
28
28
  AttachTpcToKeras
29
29
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
30
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
30
31
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
31
- from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
32
32
  from tensorflow.keras.models import Model
33
33
 
34
34
  from model_compression_toolkit import get_target_platform_capabilities
35
35
 
36
36
  KERAS_DEFAULT_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
37
37
 
38
- @set_keras_info
39
38
  def keras_resource_utilization_data(in_model: Model,
40
39
  representative_data_gen: Callable,
41
40
  core_config: CoreConfig = CoreConfig(
@@ -94,6 +93,7 @@ if FOUND_TF:
94
93
  representative_data_gen,
95
94
  core_config,
96
95
  target_platform_capabilities,
96
+ DEFAULT_KERAS_INFO,
97
97
  fw_impl)
98
98
 
99
99
  else:
@@ -25,7 +25,7 @@ else:
25
25
 
26
26
  from model_compression_toolkit.core import QuantizationConfig
27
27
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
28
- from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
28
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
29
29
  from model_compression_toolkit.core.common import Graph
30
30
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
31
31
  from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
@@ -43,6 +43,7 @@ def activation_bias_correction_node_matchers():
43
43
 
44
44
  def keras_compute_activation_bias_correction_of_graph(graph: Graph,
45
45
  quant_config: QuantizationConfig,
46
+ fw_info: FrameworkInfo,
46
47
  fw_impl: FrameworkImplementation) -> Graph:
47
48
  """
48
49
  Compute the activation bias correction term for graph based on a Keras model.
@@ -50,6 +51,7 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph,
50
51
  Args:
51
52
  graph: Graph with nodes to compute the activation bias correction.
52
53
  quant_config: QuantizationConfig of how the model should be quantized.
54
+ fw_info: Framework info, such as the lists of nodes whose kernel should be quantized.
53
55
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
54
56
 
55
57
  Returns:
@@ -57,9 +59,9 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph,
57
59
  """
58
60
  graph = compute_activation_bias_correction_of_graph(graph=graph,
59
61
  quant_config=quant_config,
62
+ fw_info=fw_info,
60
63
  fw_impl=fw_impl,
61
64
  activation_bias_correction_node_matchers=
62
65
  activation_bias_correction_node_matchers,
63
- kernel_size=KERNEL_SIZE,
64
- get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
66
+ kernel_size=KERNEL_SIZE)
65
67
  return graph