mct-nightly 2.4.0.20250616.616__py3-none-any.whl → 2.4.0.20250618.606__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/RECORD +120 -120
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +2 -5
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
  6. model_compression_toolkit/core/common/framework_implementation.py +10 -22
  7. model_compression_toolkit/core/common/framework_info.py +105 -68
  8. model_compression_toolkit/core/common/graph/base_graph.py +15 -42
  9. model_compression_toolkit/core/common/graph/base_node.py +103 -42
  10. model_compression_toolkit/core/common/graph/functional_node.py +18 -1
  11. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
  12. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
  16. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
  17. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
  18. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
  19. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
  20. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
  21. model_compression_toolkit/core/common/model_collector.py +10 -20
  22. model_compression_toolkit/core/common/model_validation.py +1 -4
  23. model_compression_toolkit/core/common/network_editors/actions.py +14 -38
  24. model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
  25. model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
  26. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
  27. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
  28. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
  29. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
  30. model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
  31. model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
  32. model_compression_toolkit/core/common/pruning/pruner.py +1 -6
  33. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
  34. model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
  35. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
  36. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
  37. model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
  38. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
  39. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
  40. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
  41. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
  42. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
  43. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
  44. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
  45. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
  46. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
  47. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
  48. model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
  49. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
  50. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
  51. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  52. model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
  53. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
  54. model_compression_toolkit/core/graph_prep_runner.py +2 -16
  55. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
  56. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
  57. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
  58. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
  59. model_compression_toolkit/core/keras/default_framework_info.py +138 -87
  60. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
  61. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
  62. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
  63. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
  64. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  65. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
  66. model_compression_toolkit/core/keras/keras_implementation.py +15 -35
  67. model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
  68. model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
  69. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
  70. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +0 -2
  71. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
  72. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
  73. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
  74. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
  75. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
  76. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
  77. model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
  78. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
  79. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
  80. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  81. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
  82. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
  83. model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
  84. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
  85. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  86. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
  87. model_compression_toolkit/core/quantization_prep_runner.py +4 -9
  88. model_compression_toolkit/core/runner.py +5 -15
  89. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  90. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  91. model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
  92. model_compression_toolkit/gptq/common/gptq_training.py +1 -8
  93. model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
  94. model_compression_toolkit/gptq/keras/graph_info.py +4 -6
  95. model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
  96. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  97. model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
  98. model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
  99. model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
  100. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  101. model_compression_toolkit/gptq/runner.py +1 -7
  102. model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
  103. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
  104. model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
  105. model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
  106. model_compression_toolkit/ptq/runner.py +1 -4
  107. model_compression_toolkit/qat/common/qat_config.py +2 -6
  108. model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
  109. model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
  110. model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
  111. model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
  112. model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
  113. model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
  114. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
  115. model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
  116. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
  117. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
  118. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/WHEEL +0 -0
  119. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/licenses/LICENSE.md +0 -0
  120. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/top_level.txt +0 -0
@@ -227,15 +227,13 @@ def is_padding_node_and_node_has_padding(pad_node_to_consider: BaseNode,
227
227
 
228
228
 
229
229
  def keras_apply_shift_negative_correction(graph: Graph,
230
- core_config: CoreConfig,
231
- fw_info: FrameworkInfo) -> Graph:
230
+ core_config: CoreConfig) -> Graph:
232
231
  """
233
232
  Apply shift negative correction (SNC) on a graph built from a Keras model.
234
233
 
235
234
  Args:
236
235
  graph: Graph to apply SNC on.
237
236
  core_config: Quantization configuration.
238
- fw_info: FrameworkInfo object with information about the specific framework's module.
239
237
 
240
238
  Returns:
241
239
  Graph after SNC.
@@ -244,7 +242,6 @@ def keras_apply_shift_negative_correction(graph: Graph,
244
242
 
245
243
  return apply_shift_negative_correction(graph,
246
244
  core_config,
247
- fw_info,
248
245
  snc_node,
249
246
  linear_node,
250
247
  bypass_node,
@@ -22,7 +22,6 @@ from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESS
22
22
  from model_compression_toolkit.core.common import Graph
23
23
  from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
24
24
  from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
25
- from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
26
25
  from model_compression_toolkit.core.keras.hessian.hessian_scores_calculator_keras import HessianScoresCalculatorKeras
27
26
  from model_compression_toolkit.logger import Logger
28
27
 
@@ -95,20 +94,11 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
95
94
  for i, ipt_node in enumerate(self.hessian_request.target_nodes): # Per Interest point weights tensor
96
95
 
97
96
  # Check if the target node's layer type is supported.
98
- if not DEFAULT_KERAS_INFO.is_kernel_op(ipt_node.type):
97
+ if not ipt_node.is_kernel_op:
99
98
  Logger.critical(f"Hessian information with respect to weights is not supported for "
100
99
  f"{ipt_node.type} layers.") # pragma: no cover
101
100
 
102
- # Get the weight attributes for the target node type
103
- weight_attributes = DEFAULT_KERAS_INFO.get_kernel_op_attributes(ipt_node.type)
104
-
105
- # Get the weight tensor for the target node
106
- if len(weight_attributes) != 1: # pragma: no cover
107
- Logger.critical(
108
- f"Hessian-based scoring with respect to weights is currently supported only for nodes with "
109
- f"a single weight attribute. Found {len(weight_attributes)} attributes.")
110
-
111
- weight_tensor = getattr(model.get_layer(ipt_node.name), weight_attributes[0])
101
+ weight_tensor = getattr(model.get_layer(ipt_node.name), ipt_node.kernel_attr)
112
102
 
113
103
  if j == 0:
114
104
  # On the first iteration we store the weight_tensor shape for later reshaping the results
@@ -116,7 +106,7 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
116
106
  tensors_original_shape.append(weight_tensor.shape)
117
107
 
118
108
  # Get the output channel index (needed for HessianInfoGranularity.PER_OUTPUT_CHANNEL case)
119
- output_channel_axis, _ = DEFAULT_KERAS_INFO.kernel_channels_mapping.get(ipt_node.type)
109
+ output_channel_axis = ipt_node.channel_axis.output
120
110
 
121
111
  # Get number of scores that should be calculated by the granularity.
122
112
  num_of_scores = self._get_num_scores_by_granularity(weight_tensor,
@@ -65,7 +65,6 @@ from model_compression_toolkit.core.common import Graph, BaseNode
65
65
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
66
66
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
67
67
  from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
68
- from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
69
68
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.activation_decomposition import \
70
69
  ActivationDecomposition
71
70
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.matmul_substitution import \
@@ -175,18 +174,16 @@ class KerasImplementation(FrameworkImplementation):
175
174
  graph: Graph,
176
175
  mode: ModelBuilderMode,
177
176
  append2output: List[Any] = None,
178
- fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
179
177
  return_float_outputs: bool = False) -> Tuple:
180
178
  """
181
179
  Build a Keras model from a graph.
182
- The mode determines how the model should be build. append2output is a list of Nodes
180
+ The mode determines how the model should be built. append2output is a list of Nodes
183
181
  to set as the model outputs.
184
182
 
185
183
  Args:
186
184
  graph: Graph to build the model from it.
187
185
  mode: Mode for how to build the model.
188
186
  append2output: List of Nodes to set as the model's outputs.
189
- fw_info: FrameworkInfo object with information about the specific framework's model
190
187
  return_float_outputs (bool): whether to return outputs before or after quantization nodes (default)
191
188
  Returns:
192
189
  A tuple with the model and additional relevant supporting objects.
@@ -195,7 +192,6 @@ class KerasImplementation(FrameworkImplementation):
195
192
  keras_model_builder = get_keras_model_builder(mode)
196
193
  return keras_model_builder(graph=graph,
197
194
  append2output=append2output,
198
- fw_info=fw_info,
199
195
  return_float_outputs=return_float_outputs).build_model()
200
196
 
201
197
  def run_model_inference(self,
@@ -227,65 +223,57 @@ class KerasImplementation(FrameworkImplementation):
227
223
 
228
224
  def shift_negative_correction(self,
229
225
  graph: Graph,
230
- core_config: CoreConfig,
231
- fw_info: FrameworkInfo) -> Graph:
226
+ core_config: CoreConfig) -> Graph:
232
227
  """
233
228
  Apply shift negative correction (SNC) on a graph.
234
229
 
235
230
  Args:
236
231
  graph: Graph to apply SNC on.
237
232
  core_config: Quantization configuration.
238
- fw_info: FrameworkInfo object with information about the specific framework's model.
239
233
 
240
234
  Returns:
241
235
  Graph after SNC.
242
236
  """
243
237
  return keras_apply_shift_negative_correction(graph,
244
- core_config,
245
- fw_info)
238
+ core_config)
246
239
 
247
240
  def compute_activation_bias_correction(self,
248
241
  graph: Graph,
249
- quant_config: QuantizationConfig,
250
- fw_info: FrameworkInfo):
242
+ quant_config: QuantizationConfig):
251
243
  """
252
244
  Compute activation bias correction on a graph.
253
245
 
254
246
  Args:
255
247
  graph: Graph to apply activation bias correction on.
256
248
  quant_config: QuantizationConfig of how the model should be quantized.
257
- fw_info: FrameworkInfo object with information about the specific framework's model.
258
249
 
259
250
  Returns:
260
251
  Graph after activation bias correction computing.
261
252
  """
262
253
  return keras_compute_activation_bias_correction_of_graph(graph=graph,
263
254
  quant_config=quant_config,
264
- fw_info=fw_info,
265
255
  fw_impl=self)
266
256
 
267
257
  def get_substitutions_channel_equalization(self,
268
- quant_config: QuantizationConfig,
269
- fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
258
+ quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
270
259
  """
271
260
  Return a list of the framework substitutions used for channel equalization.
272
261
 
273
262
  Args:
274
263
  quant_config: QuantizationConfig to determine which substitutions to return.
275
- fw_info: FrameworkInfo object with information about the specific framework's model.
276
264
 
277
265
  Returns:
278
266
  A list of the framework substitutions used after we collect statistics.
279
267
  """
280
268
  substitutions_list = []
281
269
  if quant_config.activation_channel_equalization:
282
- substitutions_list.extend([ScaleEqualization(quant_config, fw_info),
283
- ScaleEqualizationWithPad(quant_config, fw_info),
284
- ScaleEqualizationMidActivation(quant_config, fw_info),
285
- ScaleEqualizationMidActivationWithPad(quant_config, fw_info)])
270
+ substitutions_list.extend([ScaleEqualization(quant_config),
271
+ ScaleEqualizationWithPad(quant_config),
272
+ ScaleEqualizationMidActivation(quant_config),
273
+ ScaleEqualizationMidActivationWithPad(quant_config)])
286
274
  return substitutions_list
287
275
 
288
- def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List[common.BaseSubstitution]:
276
+ def get_substitutions_prepare_graph(self) -> List[common.BaseSubstitution]:
289
277
  """
290
278
 
291
279
  Returns: A list of the framework substitutions used to prepare the graph.
@@ -402,22 +390,19 @@ class KerasImplementation(FrameworkImplementation):
402
390
 
403
391
  def get_node_prior_info(self,
404
392
  node: BaseNode,
405
- fw_info: FrameworkInfo,
406
393
  graph: Graph) -> NodePriorInfo:
407
394
  """
408
395
  Get a NodePriorInfo object for a node that represents a Keras layer.
409
396
 
410
397
  Args:
411
398
  node: Node to get its prior info.
412
- fw_info: Framework specific information needed to create the prior info of the node.
413
399
  graph: Graph to check the next node type.
414
400
 
415
401
  Returns:
416
402
  NodePriorInfo with information about the node.
417
403
  """
418
404
 
419
- return create_node_prior_info(node=node,
420
- fw_info=fw_info, graph=graph)
405
+ return create_node_prior_info(node=node, graph=graph)
421
406
 
422
407
  def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
423
408
  """
@@ -530,23 +515,19 @@ class KerasImplementation(FrameworkImplementation):
530
515
  return True
531
516
 
532
517
  def get_node_mac_operations(self,
533
- node: BaseNode,
534
- fw_info: FrameworkInfo) -> float:
518
+ node: BaseNode) -> float:
535
519
  """
536
520
  Gets the MAC operation count for a given operation.
537
521
 
538
522
  Args:
539
523
  node: A graph node that wraps the operation for which the MAC count is computed.
540
- fw_info: FrameworkInfo object with information about the Keras model.
541
524
 
542
525
  Returns: The MAC count og the operation
543
526
  """
544
- kernels = fw_info.get_kernel_op_attributes(node.type)
545
- if not kernels or kernels[0] is None:
527
+ if node.kernel_attr is None:
546
528
  return 0
547
529
 
548
- assert len(kernels) == 1
549
- kernel_shape = node.get_weights_by_keys(kernels[0]).shape
530
+ kernel_shape = node.get_weights_by_keys(node.kernel_attr).shape
550
531
 
551
532
  if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose) or node.is_match_type(DepthwiseConv2D):
552
533
  h, w = node.get_output_shapes_list()[0][-3:-1]
@@ -554,8 +535,7 @@ class KerasImplementation(FrameworkImplementation):
554
535
 
555
536
  if node.is_match_type(Dense):
556
537
  # IN * OUT * (all previous dims[:-1])
557
- _, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
558
- return node.get_total_output_params() * kernel_shape[input_channel_axis]
538
+ return node.get_total_output_params() * kernel_shape[node.channel_axis.input]
559
539
 
560
540
  return 0
561
541
 
@@ -1,6 +1,6 @@
1
1
  from tensorflow.keras.models import Model
2
2
 
3
- from model_compression_toolkit.core import FrameworkInfo
3
+ from model_compression_toolkit.core.common.framework_info import get_fw_info
4
4
  from model_compression_toolkit.core.common.framework_info import ChannelAxis
5
5
  from model_compression_toolkit.core.common.model_validation import ModelValidation
6
6
  from model_compression_toolkit.core.keras.constants import CHANNELS_FORMAT, CHANNELS_FORMAT_LAST, CHANNELS_FORMAT_FIRST
@@ -11,17 +11,15 @@ class KerasModelValidation(ModelValidation):
11
11
  Class to define validation methods in order to validate the received Keras model to quantize.
12
12
  """
13
13
 
14
- def __init__(self, model: Model, fw_info: FrameworkInfo):
14
+ def __init__(self, model: Model):
15
15
  """
16
16
  Initialize a KerasModelValidation object.
17
17
 
18
18
  Args:
19
19
  model: Keras model to check its validity.
20
- fw_info: Information about the framework of the model (Keras).
21
20
  """
22
21
 
23
- super(KerasModelValidation, self).__init__(model=model,
24
- fw_info=fw_info)
22
+ super(KerasModelValidation, self).__init__(model=model)
25
23
 
26
24
  def validate_output_channel_consistency(self):
27
25
  """
@@ -30,9 +28,10 @@ class KerasModelValidation(ModelValidation):
30
28
  If the model has layers with different output channels index, an exception is thrown.
31
29
 
32
30
  """
31
+ fw_info = get_fw_info()
33
32
  for layer in self.model.layers:
34
33
  data_format = layer.get_config().get(CHANNELS_FORMAT)
35
34
  if data_format is not None:
36
- assert (data_format == CHANNELS_FORMAT_LAST and self.fw_info.out_channel_axis_mapping.get(layer) == ChannelAxis.NHWC.value
37
- or data_format == CHANNELS_FORMAT_FIRST and self.fw_info.out_channel_axis_mapping.get(layer) == ChannelAxis.NCHW.value), \
35
+ assert (data_format == CHANNELS_FORMAT_LAST and fw_info.get_out_channel_axis(layer) == ChannelAxis.NHWC.value
36
+ or data_format == CHANNELS_FORMAT_FIRST and fw_info.get_out_channel_axis(layer) == ChannelAxis.NCHW.value), \
38
37
  f'Model can not have layers with different data formats.'
@@ -17,22 +17,19 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph
17
17
 
18
18
 
19
19
  def create_node_prior_info(node: BaseNode,
20
- fw_info: FrameworkInfo,
21
20
  graph: Graph):
22
21
  """
23
22
  Create a NodePriorInfo object for a given node.
24
23
 
25
24
  Args:
26
25
  node: Node to create its prior info.
27
- fw_info: Information about a specific framework the node was generated from.
28
26
  graph: Graph to check the next node type.
29
27
 
30
28
  Returns:
31
29
  NodePriorInfo object with info about the node.
32
30
  """
33
31
 
34
- min_output, max_output = _get_min_max_outputs(node=node,
35
- fw_info=fw_info)
32
+ min_output, max_output = _get_min_max_outputs(node=node)
36
33
 
37
34
  mean_output, std_output = _get_mean_std_outputs(node=node,
38
35
  graph=graph)
@@ -42,14 +39,12 @@ def create_node_prior_info(node: BaseNode,
42
39
  std_output=std_output)
43
40
 
44
41
 
45
- def _get_min_max_outputs(node: BaseNode,
46
- fw_info: FrameworkInfo) -> Tuple[Any, Any]:
42
+ def _get_min_max_outputs(node: BaseNode) -> Tuple[Any, Any]:
47
43
  """
48
44
  Return the min/max output values of a node if known.
49
45
  If one of them (or both of them) is unknown - return None instead of a value.
50
46
  Args:
51
47
  node: Node to create its prior info.
52
- fw_info: Information about a specific framework the node was generated from.
53
48
 
54
49
  Returns:
55
50
  Min/max output values if known.
@@ -58,12 +53,8 @@ def _get_min_max_outputs(node: BaseNode,
58
53
 
59
54
  if node.is_match_type(ReLU):
60
55
  min_output = node.framework_attr[THRESHOLD] if node.framework_attr[NEGATIVE_SLOPE] == 0 else None
61
-
62
- elif fw_info.layers_has_min_max(node.type):
63
- min_output, max_output = fw_info.layer_min_max_mapping[node.type]
64
-
65
- elif node.is_match_type(Activation) and fw_info.activation_has_min_max(node.framework_attr[ACTIVATION]):
66
- min_output, max_output = fw_info.activation_min_max_mapping[node.framework_attr[ACTIVATION]]
56
+ else:
57
+ min_output, max_output = node.minmax
67
58
 
68
59
  return min_output, max_output
69
60
 
@@ -19,7 +19,6 @@ from model_compression_toolkit.core.common.pruning.pruning_framework_implementat
19
19
  PruningFrameworkImplementation
20
20
  from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
21
21
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
22
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
23
22
  from model_compression_toolkit.core.common import BaseNode
24
23
  from model_compression_toolkit.core.keras.constants import BIAS, GROUPS, FILTERS, UNITS, USE_BIAS
25
24
  import keras
@@ -38,27 +37,23 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
38
37
 
39
38
  def prune_entry_node(self,
40
39
  node: BaseNode,
41
- output_mask: np.ndarray,
42
- fw_info: FrameworkInfo):
40
+ output_mask: np.ndarray):
43
41
  """
44
42
  Prunes the entry node of a model in Keras.
45
43
 
46
44
  Args:
47
45
  node (BaseNode): The entry node to be pruned.
48
46
  output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
49
- fw_info (FrameworkInfo): Framework-specific information object.
50
47
 
51
48
  """
52
49
  return _prune_keras_edge_node(node=node,
53
50
  mask=output_mask,
54
- fw_info=fw_info,
55
51
  is_exit_node=False)
56
52
 
57
53
  def prune_intermediate_node(self,
58
54
  node: BaseNode,
59
55
  input_mask: np.ndarray,
60
- output_mask: np.ndarray,
61
- fw_info: FrameworkInfo):
56
+ output_mask: np.ndarray):
62
57
  """
63
58
  Prunes an intermediate node in a Keras model.
64
59
 
@@ -66,7 +61,6 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
66
61
  node (BaseNode): The intermediate node to be pruned.
67
62
  input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
68
63
  output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
69
- fw_info (FrameworkInfo): Framework-specific information object.
70
64
 
71
65
  """
72
66
  _edit_node_input_shape(input_mask, node)
@@ -79,20 +73,17 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
79
73
 
80
74
  def prune_exit_node(self,
81
75
  node: BaseNode,
82
- input_mask: np.ndarray,
83
- fw_info: FrameworkInfo):
76
+ input_mask: np.ndarray):
84
77
  """
85
78
  Prunes the exit node of a model in Keras.
86
79
 
87
80
  Args:
88
81
  node (BaseNode): The exit node to be pruned.
89
82
  input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
90
- fw_info (FrameworkInfo): Framework-specific information object.
91
83
 
92
84
  """
93
85
  return _prune_keras_edge_node(node=node,
94
86
  mask=input_mask,
95
- fw_info=fw_info,
96
87
  is_exit_node=True)
97
88
 
98
89
  def is_node_entry_node(self, node: BaseNode) -> bool:
@@ -109,22 +100,19 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
109
100
 
110
101
  def is_node_exit_node(self,
111
102
  node: BaseNode,
112
- corresponding_entry_node: BaseNode,
113
- fw_info: FrameworkInfo) -> bool:
103
+ corresponding_entry_node: BaseNode) -> bool:
114
104
  """
115
105
  Determines whether a node is an exit node in a Keras model.
116
106
 
117
107
  Args:
118
108
  node (BaseNode): The node to be checked.
119
109
  corresponding_entry_node (BaseNode): The entry node of the pruning section that is checked.
120
- fw_info (FrameworkInfo): Framework-specific information object.
121
110
 
122
111
  Returns:
123
112
  bool: Boolean indicating if the node is an exit node.
124
113
  """
125
114
  return _is_keras_node_pruning_section_edge(node) and PruningSection.has_matching_channel_count(node,
126
- corresponding_entry_node,
127
- fw_info)
115
+ corresponding_entry_node)
128
116
 
129
117
  def is_node_intermediate_pruning_section(self, node: BaseNode) -> bool:
130
118
  """
@@ -143,8 +131,7 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
143
131
  keras.layers.Dense]
144
132
 
145
133
  def attrs_oi_channels_info_for_pruning(self,
146
- node: BaseNode,
147
- fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
134
+ node: BaseNode) -> Dict[str, Tuple[int, int]]:
148
135
  """
149
136
  Retrieves the attributes of a given node along with the output/input (OI) channel axis
150
137
  for each attribute used to prune these attributes.
@@ -161,7 +148,6 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
161
148
 
162
149
  Args:
163
150
  node (BaseNode): The node from the computational graph.
164
- fw_info (FrameworkInfo): Contains framework-specific information and utilities.
165
151
 
166
152
  Returns:
167
153
  Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'kernel' or 'bias')
@@ -169,13 +155,8 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
169
155
  """
170
156
 
171
157
  attributes_with_axis = {}
172
- if fw_info.is_kernel_op(node.type):
173
- kernel_attributes = fw_info.get_kernel_op_attributes(node.type)
174
- if kernel_attributes is None or len(kernel_attributes)==0:
175
- Logger.critical(f"Expected kernel attributes for operation for node type {node.type}, found None or empty.")
176
-
177
- for attr in kernel_attributes:
178
- attributes_with_axis[attr] = fw_info.kernel_channels_mapping.get(node.type)
158
+ if node.is_kernel_op:
159
+ attributes_with_axis[node.kernel_attr] = (node.channel_axis.output, node.channel_axis.input)
179
160
 
180
161
  # Bias is a vector at the length of the number of output channels.
181
162
  # For this reason, input channel axis is irrelevant to the bias attribute.
@@ -216,7 +197,6 @@ def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:
216
197
 
217
198
  def _prune_keras_edge_node(node: BaseNode,
218
199
  mask: np.ndarray,
219
- fw_info: FrameworkInfo,
220
200
  is_exit_node: bool):
221
201
  """
222
202
  Prunes the given Keras node by applying the mask to the node's weights (kernels and biases).
@@ -225,21 +205,18 @@ def _prune_keras_edge_node(node: BaseNode,
225
205
  Args:
226
206
  node: The node to be pruned.
227
207
  mask: The pruning mask to be applied.
228
- fw_info: Framework-specific information object.
229
208
  is_exit_node: A boolean indicating whether the node is an exit node.
230
209
 
231
210
  """
232
211
 
233
212
  # Retrieve the kernel attribute and the axes to prune.
234
- kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
235
- io_axis = fw_info.kernel_channels_mapping.get(node.type)
236
- axis_to_prune = io_axis[int(is_exit_node)]
237
- kernel = node.get_weights_by_keys(kernel_attr)
213
+ axis_to_prune = node.channel_axis.input if is_exit_node else node.channel_axis.output
214
+ kernel = node.get_weights_by_keys(node.kernel_attr)
238
215
  # Convert mask to boolean.
239
216
  mask_bool = mask.astype(bool)
240
217
 
241
218
  pruned_kernel = kernel.compress(mask_bool, axis=axis_to_prune)
242
- node.set_weights_by_keys(name=kernel_attr, tensor=pruned_kernel)
219
+ node.set_weights_by_keys(name=node.kernel_attr, tensor=pruned_kernel)
243
220
 
244
221
  if not is_exit_node and node.framework_attr[USE_BIAS]:
245
222
  # Prune the bias if applicable and it's an entry node.
@@ -27,7 +27,6 @@ if FOUND_TF:
27
27
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
28
28
  AttachTpcToKeras
29
29
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
30
- from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
31
30
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
32
31
  from tensorflow.keras.models import Model
33
32
 
@@ -93,7 +92,6 @@ if FOUND_TF:
93
92
  representative_data_gen,
94
93
  core_config,
95
94
  target_platform_capabilities,
96
- DEFAULT_KERAS_INFO,
97
95
  fw_impl)
98
96
 
99
97
  else:
@@ -43,7 +43,6 @@ def activation_bias_correction_node_matchers():
43
43
 
44
44
  def keras_compute_activation_bias_correction_of_graph(graph: Graph,
45
45
  quant_config: QuantizationConfig,
46
- fw_info: FrameworkInfo,
47
46
  fw_impl: FrameworkImplementation) -> Graph:
48
47
  """
49
48
  Compute the activation bias correction term for graph based on a Keras model.
@@ -51,7 +50,6 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph,
51
50
  Args:
52
51
  graph: Graph with nodes to compute the activation bias correction.
53
52
  quant_config: QuantizationConfig of how the model should be quantized.
54
- fw_info: Framework info like lists of nodes their kernel should quantized.
55
53
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
56
54
 
57
55
  Returns:
@@ -59,7 +57,6 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph,
59
57
  """
60
58
  graph = compute_activation_bias_correction_of_graph(graph=graph,
61
59
  quant_config=quant_config,
62
- fw_info=fw_info,
63
60
  fw_impl=fw_impl,
64
61
  activation_bias_correction_node_matchers=
65
62
  activation_bias_correction_node_matchers,
@@ -24,8 +24,6 @@ from model_compression_toolkit.core.common.user_info import UserInformation
24
24
  from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, \
25
25
  PytorchModel
26
26
 
27
- from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
28
-
29
27
 
30
28
  class FloatPyTorchModel(PytorchModel):
31
29
  """
@@ -34,19 +32,16 @@ class FloatPyTorchModel(PytorchModel):
34
32
 
35
33
  def __init__(self,
36
34
  graph: common.Graph,
37
- append2output=None,
38
- fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO):
35
+ append2output=None):
39
36
  """
40
37
 
41
38
  Args:
42
39
  graph: Graph to build its corresponding Pytorch model.
43
40
  append2output: List of nodes or OutTensor objects.
44
- fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).
45
41
  """
46
42
 
47
43
  super().__init__(graph,
48
- append2output,
49
- fw_info)
44
+ append2output)
50
45
 
51
46
  def _quantize_node_activations(self,
52
47
  node: BaseNode,
@@ -71,20 +66,17 @@ class FloatPyTorchModelBuilder(PyTorchModelBuilder):
71
66
  def __init__(self,
72
67
  graph: common.Graph,
73
68
  append2output=None,
74
- fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
75
69
  return_float_outputs: bool = False):
76
70
  """
77
71
 
78
72
  Args:
79
73
  graph: Graph to build the model from.
80
74
  append2output: Nodes to append to model's output.
81
- fw_info: Information about the specific framework of the model that is built.
82
75
  return_float_outputs: Whether the model returns float tensors or not.
83
76
  """
84
77
 
85
78
  super().__init__(graph,
86
79
  append2output,
87
- fw_info,
88
80
  return_float_outputs)
89
81
 
90
82
  def build_model(self) -> Tuple[PytorchModel, UserInformation]:
@@ -94,5 +86,4 @@ class FloatPyTorchModelBuilder(PyTorchModelBuilder):
94
86
 
95
87
  """
96
88
  return FloatPyTorchModel(self.graph,
97
- self.append2output,
98
- self.fw_info), self.graph.user_info
89
+ self.append2output), self.graph.user_info