mct-nightly 2.4.0.20250617.613__py3-none-any.whl → 2.4.0.20250619.621__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/RECORD +123 -123
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +2 -5
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
  6. model_compression_toolkit/core/common/framework_implementation.py +10 -22
  7. model_compression_toolkit/core/common/framework_info.py +105 -68
  8. model_compression_toolkit/core/common/graph/base_graph.py +15 -42
  9. model_compression_toolkit/core/common/graph/base_node.py +103 -42
  10. model_compression_toolkit/core/common/graph/functional_node.py +18 -1
  11. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
  12. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
  16. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
  17. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
  18. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
  19. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
  20. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
  21. model_compression_toolkit/core/common/model_collector.py +10 -20
  22. model_compression_toolkit/core/common/model_validation.py +1 -4
  23. model_compression_toolkit/core/common/network_editors/actions.py +14 -38
  24. model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
  25. model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
  26. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
  27. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
  28. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
  29. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
  30. model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
  31. model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
  32. model_compression_toolkit/core/common/pruning/pruner.py +1 -6
  33. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
  34. model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
  35. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
  36. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
  37. model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
  38. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
  39. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
  40. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
  41. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
  42. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
  43. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
  44. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
  45. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
  46. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
  47. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
  48. model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
  49. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
  50. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
  51. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  52. model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
  53. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
  54. model_compression_toolkit/core/graph_prep_runner.py +2 -16
  55. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
  56. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
  57. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
  58. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
  59. model_compression_toolkit/core/keras/default_framework_info.py +138 -87
  60. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
  61. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
  62. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
  63. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
  64. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  65. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
  66. model_compression_toolkit/core/keras/keras_implementation.py +15 -35
  67. model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
  68. model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
  69. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
  70. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  71. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
  72. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
  73. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
  74. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
  75. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
  76. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
  77. model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
  78. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
  79. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
  80. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  81. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
  82. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
  83. model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
  84. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
  85. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  86. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
  87. model_compression_toolkit/core/quantization_prep_runner.py +4 -9
  88. model_compression_toolkit/core/runner.py +5 -15
  89. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  90. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  91. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -0
  92. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +19 -17
  93. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -0
  94. model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
  95. model_compression_toolkit/gptq/common/gptq_training.py +1 -8
  96. model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
  97. model_compression_toolkit/gptq/keras/graph_info.py +4 -6
  98. model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  100. model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
  101. model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
  102. model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
  103. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  104. model_compression_toolkit/gptq/runner.py +1 -7
  105. model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
  106. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
  107. model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
  108. model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
  109. model_compression_toolkit/ptq/runner.py +1 -4
  110. model_compression_toolkit/qat/common/qat_config.py +2 -6
  111. model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
  112. model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
  113. model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
  114. model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
  115. model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
  116. model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
  117. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
  118. model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
  119. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
  120. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
  121. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/WHEEL +0 -0
  122. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/licenses/LICENSE.md +0 -0
  123. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/top_level.txt +0 -0
model_compression_toolkit/core/common/pruning/memory_calculator.py
@@ -34,18 +34,16 @@ class MemoryCalculator:
     which is crucial for deploying models on memory-constrained devices or optimizing for computational efficiency.
     """
 
-    def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: PruningFrameworkImplementation):
+    def __init__(self, graph: Graph, fw_impl: PruningFrameworkImplementation):
        """
        Initializes the MemoryCalculator with necessary information about the model's graph,
        framework-specific details, and pruning implementation.
 
        Args:
            graph (Graph): Computational graph of the model.
-           fw_info (FrameworkInfo): Contains framework-specific information.
            fw_impl (PruningFrameworkImplementation): Implementation details for pruning.
        """
        self.graph = graph
-       self.fw_info = fw_info
        self.fw_impl = fw_impl
 
    def get_pruned_graph_memory(self,
@@ -204,19 +202,13 @@ class MemoryCalculator:
        if node == section.exit_node:
            return masks.get(section.entry_node)
 
-       kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)
-       # Ensure only one kernel attribute exists for the given node.
-       if len(kernel_attr) != 1:
-           Logger.critical(f"Expected a single attribute, but found {len(kernel_attr)} attributes for node '{node}'. Ensure the node configuration is correct.")
-       kernel_attr = kernel_attr[0]
-
        # Retrieve and validate the axis index for the output channels.
-       _, ic_axis = self.fw_info.kernel_channels_mapping.get(node.type)
+       ic_axis = node.channel_axis.input
        if ic_axis is None or int(ic_axis) != ic_axis:
            Logger.critical(f"Invalid input channel axis type for node '{node}': expected integer but got '{ic_axis}'.")
 
        # Get the number of output channels based on the kernel attribute and axis.
-       num_ic = node.get_weights_by_keys(kernel_attr).shape[ic_axis]
+       num_ic = node.get_weights_by_keys(node.kernel_attr).shape[ic_axis]
        mask = np.ones(num_ic, dtype=bool)
        return mask
 
@@ -289,7 +281,7 @@ class MemoryCalculator:
            int: The total number of parameters in the node after pruning.
        """
        total_params = 0
-       attributes_and_oc_axis = self.fw_impl.attrs_oi_channels_info_for_pruning(node, self.fw_info)
+       attributes_and_oc_axis = self.fw_impl.attrs_oi_channels_info_for_pruning(node)
 
        # Iterate over the node's weights and apply pruning based on the masks.
        for w_attr, w in node.weights.items():
@@ -311,7 +303,7 @@ class MemoryCalculator:
            num_oc = np.sum(output_mask)
        else:
            # Get the node channel axis from framework info
-           channel_axis = self.fw_info.out_channel_axis_mapping.get(node.type)
+           channel_axis = node.out_channel_axis
            if channel_axis is None:
                Logger.critical(f"The channel axis is undefined. Please ensure the channel axis is explicitly defined for node {node.type} in the framework info.")
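
The recurring change across most of the files listed above, and concretely in the hunks just shown, is the retirement of the FrameworkInfo plumbing: lookups that used to go through fw_info.kernel_channels_mapping.get(node.type) and fw_info.get_kernel_op_attributes(node.type) now read node-local properties (node.channel_axis, node.kernel_attr, node.out_channel_axis). Below is a minimal runnable sketch of the new access pattern. The stand-in types are hypothetical; the real ChannelAxisMapping lives in model_compression_toolkit/core/common/framework_info.py, and its NamedTuple-like shape here is only inferred from the .output/.input accesses and positional construction visible in these hunks.

import numpy as np
from typing import NamedTuple, Optional

class ChannelAxisMapping(NamedTuple):
    # Inferred shape, not MCT's actual definition.
    output: Optional[int]  # kernel axis holding output channels
    input: Optional[int]   # kernel axis holding input channels

class StubNode:
    """Hypothetical node exposing the properties the updated code reads."""
    def __init__(self, kernel: np.ndarray):
        self._weights = {'kernel': kernel}
        self.kernel_attr = 'kernel'                   # was fw_info.get_kernel_op_attributes(node.type)[0]
        self.channel_axis = ChannelAxisMapping(3, 2)  # was fw_info.kernel_channels_mapping.get(node.type)

    def get_weights_by_keys(self, key):
        return self._weights[key]

# The updated input-channel mask construction from the second hunk above:
node = StubNode(np.zeros((3, 3, 16, 32)))  # e.g. an HWIO conv kernel
ic_axis = node.channel_axis.input
num_ic = node.get_weights_by_keys(node.kernel_attr).shape[ic_axis]
mask = np.ones(num_ic, dtype=bool)
assert mask.sum() == 16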
 
model_compression_toolkit/core/common/pruning/prune_graph.py
@@ -27,7 +27,6 @@ from model_compression_toolkit.logger import Logger
 
 def build_pruned_graph(graph: Graph,
                        masks: Dict[BaseNode, np.ndarray],
-                       fw_info: FrameworkInfo,
                        fw_impl: FrameworkImplementation) -> Graph:
    """
    Prunes the provided graph according to the given pruning output-channels masks.
@@ -35,7 +34,6 @@ def build_pruned_graph(graph: Graph,
    Args:
        graph: The original computational graph to be pruned.
        masks: A dictionary mapping each prunable node to its pruning mask.
-       fw_info: Framework-specific information object.
        fw_impl: Framework-specific implementation object.
 
    Returns:
@@ -66,8 +64,7 @@ def build_pruned_graph(graph: Graph,
        section_mask = PruningSectionMask(entry_node_oc_mask=mask,
                                          exit_node_ic_mask=mask)
        pruning_section.apply_inner_section_mask(section_mask,
-                                                fw_impl,
-                                                fw_info)
+                                                fw_impl)
 
    return graph_to_prune
 
model_compression_toolkit/core/common/pruning/pruner.py
@@ -40,7 +40,6 @@ class Pruner:
    """
    def __init__(self,
                 float_graph: Graph,
-                fw_info: FrameworkInfo,
                 fw_impl: PruningFrameworkImplementation,
                 target_resource_utilization: ResourceUtilization,
                 representative_data_gen: Callable,
@@ -49,7 +48,6 @@ class Pruner:
        """
        Args:
            float_graph (Graph): The floating-point representation of the model's computation graph.
-           fw_info (FrameworkInfo): Contains metadata and helper functions for the framework.
            fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.
            target_resource_utilization (ResourceUtilization): The target resource utilization to be achieved after pruning.
            representative_data_gen (Callable): Generator function for representative dataset used in pruning analysis.
@@ -57,7 +55,6 @@ class Pruner:
            target_platform_capabilities (FrameworkQuantizationCapabilities): Object encapsulating the capabilities of the target hardware platform.
        """
        self.float_graph = float_graph
-       self.fw_info = fw_info
        self.fw_impl = fw_impl
        self.target_resource_utilization = target_resource_utilization
        self.representative_data_gen = representative_data_gen
@@ -84,7 +81,6 @@ class Pruner:
        # Apply Greedy strategy to compute masks based on importance scores.
        if self.pruning_config.channels_filtering_strategy == ChannelsFilteringStrategy.GREEDY:
            mask_calculator = GreedyMaskCalculator(entry_nodes,
-                                                  self.fw_info,
                                                   self.simd_scores,
                                                   self.target_resource_utilization,
                                                   self.float_graph,
@@ -99,7 +95,6 @@ class Pruner:
        Logger.info("Start pruning graph...")
        _pruned_graph = build_pruned_graph(self.float_graph,
                                           self.per_oc_mask,
-                                          self.fw_info,
                                           self.fw_impl)
        return _pruned_graph
 
@@ -116,7 +111,7 @@ class Pruner:
        # Retrieve and initialize the importance metric.
        im = get_importance_metric(self.pruning_config.importance_metric, graph=self.float_graph,
                                   representative_data_gen=self.representative_data_gen, fw_impl=self.fw_impl,
-                                  pruning_config=self.pruning_config, fw_info=self.fw_info)
+                                  pruning_config=self.pruning_config)
        entry_node_to_simd_score, simd_groups_indices = im.get_entry_node_to_simd_score(entry_nodes)
        return entry_node_to_simd_score, simd_groups_indices
 
model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py
@@ -28,15 +28,13 @@ class PruningFrameworkImplementation(FrameworkImplementation):
    @abstractmethod
    def prune_entry_node(self,
                         node: BaseNode,
-                        output_mask: np.ndarray,
-                        fw_info: FrameworkInfo):
+                        output_mask: np.ndarray):
        """
        Abstract method to prune an entry node in the model.
 
        Args:
            node: The node to be pruned.
            output_mask: A numpy array representing the mask to be applied to the output channels.
-           fw_info: Framework-specific information.
 
        Raises:
            NotImplemented: If the method is not implemented in the subclass.
@@ -48,8 +46,7 @@ class PruningFrameworkImplementation(FrameworkImplementation):
    def prune_intermediate_node(self,
                                node: BaseNode,
                                input_mask: np.ndarray,
-                               output_mask: np.ndarray,
-                               fw_info: FrameworkInfo):
+                               output_mask: np.ndarray):
        """
        Abstract method to prune an intermediate node in the model.
 
@@ -57,7 +54,6 @@ class PruningFrameworkImplementation(FrameworkImplementation):
            node: The node to be pruned.
            input_mask: Mask to be applied to the input channels.
            output_mask: Mask to be applied to the output channels.
-           fw_info: Framework-specific information.
 
        Raises:
            NotImplemented: If the method is not implemented in the subclass.
@@ -68,15 +64,13 @@ class PruningFrameworkImplementation(FrameworkImplementation):
    @abstractmethod
    def prune_exit_node(self,
                        node: BaseNode,
-                       input_mask: np.ndarray,
-                       fw_info: FrameworkInfo):
+                       input_mask: np.ndarray):
        """
        Abstract method to prune an exit node in the model.
 
        Args:
            node: The node to be pruned.
            input_mask: Mask to be applied to the input channels.
-           fw_info: Framework-specific information.
 
        Raises:
            NotImplemented: If the method is not implemented in the subclass.
@@ -105,8 +99,7 @@ class PruningFrameworkImplementation(FrameworkImplementation):
    @abstractmethod
    def is_node_exit_node(self,
                          node: BaseNode,
-                         corresponding_entry_node: BaseNode,
-                         fw_info: FrameworkInfo) -> bool:
+                         corresponding_entry_node: BaseNode) -> bool:
 
        raise NotImplemented(f'{self.__class__.__name__} have to implement the '
                             f'framework\'s is_node_exit_node method.')  # pragma: no cover
@@ -129,7 +122,7 @@ class PruningFrameworkImplementation(FrameworkImplementation):
        raise NotImplemented(f'{self.__class__.__name__} have to implement the '
                             f'framework\'s is_node_intermediate_pruning_section method.')  # pragma: no cover
 
-   def attrs_oi_channels_info_for_pruning(self, node: BaseNode, fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
+   def attrs_oi_channels_info_for_pruning(self, node: BaseNode) -> Dict[str, Tuple[int, int]]:
        """
        Retrieves the attributes of a given node along with the output/input (OI) channel axis
        for each attribute used to prune these attributes.
@@ -146,7 +139,6 @@ class PruningFrameworkImplementation(FrameworkImplementation):
 
        Args:
            node (BaseNode): The node from the computational graph.
-           fw_info (FrameworkInfo): Contains framework-specific information and utilities.
 
        Returns:
            Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'kernel' or 'bias')
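
With fw_info dropped from the abstract pruning hooks, a framework implementation now needs only the node and the masks. Below is a sketch of what a conforming subclass could look like under the new signatures; the mask logic is a toy stand-in using np.compress along the node's channel axes, and set_weights_by_keys is assumed to be the setter counterpart of the get_weights_by_keys calls seen in the hunks above, not a confirmed MCT API.

import numpy as np

class ToyPruningImpl:
    """Illustrative only: matches the new (node, mask) signatures, not MCT's real logic."""

    def prune_entry_node(self, node, output_mask: np.ndarray):
        # Keep only the output channels selected by the mask.
        kernel = node.get_weights_by_keys(node.kernel_attr)
        pruned = np.compress(output_mask, kernel, axis=node.channel_axis.output)
        node.set_weights_by_keys(node.kernel_attr, pruned)  # assumed setter

    def prune_intermediate_node(self, node, input_mask: np.ndarray, output_mask: np.ndarray):
        # Intermediate nodes are masked on both sides (apply_inner_section_mask
        # passes the entry node's oc mask for both input and output).
        self.prune_exit_node(node, input_mask)
        self.prune_entry_node(node, output_mask)

    def prune_exit_node(self, node, input_mask: np.ndarray):
        # Keep only the input channels selected by the mask.
        kernel = node.get_weights_by_keys(node.kernel_attr)
        pruned = np.compress(input_mask, kernel, axis=node.channel_axis.input)
        node.set_weights_by_keys(node.kernel_attr, pruned)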

model_compression_toolkit/core/common/pruning/pruning_section.py
@@ -76,34 +76,28 @@ class PruningSection:
 
    def apply_inner_section_mask(self,
                                 pruning_section_mask: PruningSectionMask,
-                                fw_impl: Any,
-                                fw_info: FrameworkInfo):
+                                fw_impl: Any):
        """
        Apply the provided pruning section mask to all nodes within the pruning section.
 
        Args:
            pruning_section_mask (PruningSectionMask): The mask to be applied to the pruning section.
            fw_impl (PruningFrameworkImplementation): Framework-specific implementation for applying the mask.
-           fw_info (FrameworkInfo): Framework-specific information needed to apply the mask.
        """
        fw_impl.prune_entry_node(node=self.entry_node,
-                                output_mask=pruning_section_mask.entry_node_oc_mask,
-                                fw_info=fw_info)
+                                output_mask=pruning_section_mask.entry_node_oc_mask)
 
        for inter_node in self.intermediate_nodes:
            fw_impl.prune_intermediate_node(node=inter_node,
                                            input_mask=pruning_section_mask.entry_node_oc_mask,
-                                           output_mask=pruning_section_mask.entry_node_oc_mask,
-                                           fw_info=fw_info)
+                                           output_mask=pruning_section_mask.entry_node_oc_mask)
 
        fw_impl.prune_exit_node(self.exit_node,
-                               input_mask=pruning_section_mask.exit_node_ic_mask,
-                               fw_info=fw_info)
+                               input_mask=pruning_section_mask.exit_node_ic_mask)
 
    @staticmethod
    def has_matching_channel_count(exit_node: BaseNode,
-                                  corresponding_entry_node: BaseNode,
-                                  fw_info: FrameworkInfo) -> bool:
+                                  corresponding_entry_node: BaseNode) -> bool:
        """
        Checks if the number of input channels of the exit node matches the number of output channels
        of its corresponding entry node.
@@ -115,13 +109,10 @@ class PruningSection:
        Returns:
            bool: True if the channel counts match, False otherwise.
        """
-       _, exit_input_channel_axis = fw_info.kernel_channels_mapping.get(exit_node.type)
-       entry_output_channel_axis, _ = fw_info.kernel_channels_mapping.get(corresponding_entry_node.type)
+       exit_input_channel_axis = exit_node.channel_axis.input
+       entry_output_channel_axis = corresponding_entry_node.channel_axis.output
 
-       exit_node_attr = fw_info.get_kernel_op_attributes(exit_node.type)[0]
-       entry_node_attr = fw_info.get_kernel_op_attributes(corresponding_entry_node.type)[0]
-
-       exit_input_channels = exit_node.get_weights_by_keys(exit_node_attr).shape[exit_input_channel_axis]
-       entry_output_channels = corresponding_entry_node.get_weights_by_keys(entry_node_attr).shape[entry_output_channel_axis]
+       exit_input_channels = exit_node.get_weights_by_keys(exit_node.kernel_attr).shape[exit_input_channel_axis]
+       entry_output_channels = corresponding_entry_node.get_weights_by_keys(corresponding_entry_node.kernel_attr).shape[entry_output_channel_axis]
 
        return exit_input_channels == entry_output_channels

model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py
@@ -15,6 +15,7 @@
 from typing import Callable, List, Tuple
 
 from model_compression_toolkit.core import QuantizationConfig
+from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
 from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
     NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig, \
@@ -40,7 +41,7 @@ class CandidateNodeQuantizationConfig(BaseNodeQuantizationConfig):
                 activation_quantization_fn: Callable = None,
                 activation_quantization_params_fn: Callable = None,
                 weights_quantization_cfg: NodeWeightsQuantizationConfig = None,
-                weights_channels_axis: Tuple[int, int] = None,
+                weights_channels_axis: ChannelAxisMapping = None,
                 node_attrs_list: List[str] = None):
        """
 
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py
@@ -34,7 +34,7 @@ def filter_nodes_candidates(graph: Graph):
    """
    nodes = list(graph.nodes)
    for n in nodes:
-       n.candidates_quantization_cfg = filter_node_candidates(node=n, fw_info=graph.fw_info)
+       n.candidates_quantization_cfg = filter_node_candidates(node=n)
 
    return graph
 
@@ -71,7 +71,7 @@ def _filter_bit_method_dups(candidates: List[CandidateNodeQuantizationConfig],
    return final_candidates
 
 
-def filter_node_candidates(node: BaseNode, fw_info) -> List[CandidateNodeQuantizationConfig]:
+def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConfig]:
    """
    Updates a node's candidates configuration list.
    If the node's weights quantization is disabled (or it only has activations to quantize), then the updated list
@@ -81,15 +81,13 @@ def filter_node_candidates(node: BaseNode, fw_info) -> List[CandidateNodeQuantiz
 
    Args:
        node: Node to set its quantization configurations.
-       fw_info: FrameworkInfo object with information about the specific framework's model.
 
    """
 
    filtered_candidates = copy.deepcopy(node.candidates_quantization_cfg)
    final_candidates = copy.deepcopy(node.candidates_quantization_cfg)
-   kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
 
-   if (kernel_attr is None or not node.is_weights_quantization_enabled(kernel_attr)) and not node.is_activation_quantization_enabled():
+   if (node.kernel_attr is None or not node.is_weights_quantization_enabled(node.kernel_attr)) and not node.is_activation_quantization_enabled():
        # If activation quantization is disabled and the node doesn't have a kernel or doesn't quantize the kernel,
        # but for some reason the node has multiple candidates then replace it with a single dummy candidate with
        # default bit-width values.
@@ -97,8 +95,8 @@ def filter_node_candidates(node: BaseNode, fw_info) -> List[CandidateNodeQuantiz
        single_dummy_candidate.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH
        single_dummy_candidate.activation_quantization_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO
 
-       if kernel_attr is not None:
-           kernel_config = single_dummy_candidate.weights_quantization_cfg.get_attr_config(kernel_attr)
+       if node.kernel_attr is not None:
+           kernel_config = single_dummy_candidate.weights_quantization_cfg.get_attr_config(node.kernel_attr)
            kernel_config.weights_n_bits = FLOAT_BITWIDTH
            kernel_config.weights_quantization_method = QuantizationMethod.POWER_OF_TWO
 
@@ -116,9 +114,9 @@ def filter_node_candidates(node: BaseNode, fw_info) -> List[CandidateNodeQuantiz
            c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH
            c.activation_quantization_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO
 
-       final_candidates = _filter_bit_method_dups(filtered_candidates, kernel_attr)
+       final_candidates = _filter_bit_method_dups(filtered_candidates, node.kernel_attr)
 
-   elif kernel_attr is None or not node.is_weights_quantization_enabled(kernel_attr):
+   elif node.kernel_attr is None or not node.is_weights_quantization_enabled(node.kernel_attr):
        # TODO:
        # To allow MP on positional weights we need to modify this to consider all weights not only kernel.
        # Remove candidates that have duplicated activation candidates for node with disabled weights quantization.
@@ -129,11 +127,11 @@ def filter_node_candidates(node: BaseNode, fw_info) -> List[CandidateNodeQuantiz
                               and not seen_candidates.add(candidate.activation_quantization_cfg)]
 
        for c in filtered_candidates:
-           if kernel_attr is not None:
-               kernel_config = c.weights_quantization_cfg.get_attr_config(kernel_attr)
+           if node.kernel_attr is not None:
+               kernel_config = c.weights_quantization_cfg.get_attr_config(node.kernel_attr)
                kernel_config.weights_n_bits = FLOAT_BITWIDTH
                kernel_config.weights_quantization_method = QuantizationMethod.POWER_OF_TWO
 
-       final_candidates = _filter_bit_method_dups(filtered_candidates, kernel_attr)
+       final_candidates = _filter_bit_method_dups(filtered_candidates, node.kernel_attr)
 
    return final_candidates
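
One context line in the last hunk uses a compact dedup idiom worth unpacking: x not in seen and not seen.add(x) keeps only the first occurrence of each activation config within a single list comprehension, because set.add returns None (falsy) and is only evaluated when the membership test passes. A self-contained illustration:

seen = set()
items = [3, 1, 3, 2, 1]
unique = [x for x in items if x not in seen and not seen.add(x)]
assert unique == [3, 1, 2]  # first occurrences, original order preserved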
model_compression_toolkit/core/common/quantization/node_quantization_config.py
@@ -18,6 +18,7 @@ from typing import Callable, Any, List, Tuple, Union, Dict, TYPE_CHECKING
 from enum import Enum, auto
 import numpy as np
 
+from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
 from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
@@ -262,7 +263,7 @@ class WeightsAttrQuantizationConfig:
    def __init__(self,
                 qc: QuantizationConfig,
                 weights_attr_cfg: AttributeQuantizationConfig,
-                weights_channels_axis: Tuple[int, int] = None):
+                weights_channels_axis: ChannelAxisMapping = None):
        """
 
        Args:
@@ -352,7 +353,7 @@ class WeightsAttrQuantizationConfig:
                                                      p=self.l_p_value,
                                                      n_bits=self.weights_n_bits,
                                                      per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
-                                                     channel_axis=self.weights_channels_axis[0],  # output channel axis
+                                                     channel_axis=self.weights_channels_axis.output,  # output channel axis
                                                      min_threshold=min_threshold)[0]  # Take only first output, the q-params, as axis is already chosen.
            )
        else:
@@ -400,7 +401,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
    """
    def __init__(self, qc: QuantizationConfig,
                 op_cfg: OpQuantizationConfig,
-                weights_channels_axis: Tuple[int, int],
+                weights_channels_axis: ChannelAxisMapping,
                 node_attrs_list: List[str]):
        """
 
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py
@@ -20,6 +20,7 @@ from typing import List, Callable, Generator
 from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
 from model_compression_toolkit.core import QuantizationErrorMethod
 from model_compression_toolkit.core.common import Graph, BaseNode
+from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
     HessianScoresGranularity
@@ -44,11 +45,8 @@ def _collect_nodes_for_hmse(nodes_list: List[BaseNode], graph: Graph) -> List[Ba
    """
    hmse_nodes = []
    for n in nodes_list:
-       kernel_attr_name = graph.fw_info.get_kernel_op_attributes(n.type)
-       kernel_attr_name = None if kernel_attr_name is None or len(kernel_attr_name) == 0 else kernel_attr_name[0]
-
-       if kernel_attr_name is not None and n.is_weights_quantization_enabled(kernel_attr_name) and \
-               all([c.weights_quantization_cfg.get_attr_config(kernel_attr_name).weights_error_method ==
+       if n.kernel_attr is not None and n.is_weights_quantization_enabled(n.kernel_attr) and \
+               all([c.weights_quantization_cfg.get_attr_config(n.kernel_attr).weights_error_method ==
                    QuantizationErrorMethod.HMSE for c in n.candidates_quantization_cfg]):
            hmse_nodes.append(n)
 
@@ -114,11 +112,7 @@ def calculate_quantization_params(graph: Graph,
                if attr_cfg.weights_error_method == QuantizationErrorMethod.HMSE:
                    # Although we collected nodes for HMSE before running the loop, we keep this verification to
                    # notify the user in case of HMSE configured for node that is not compatible for this method
-                   kernel_attr_name = graph.fw_info.get_kernel_op_attributes(n.type)
-                   if len(kernel_attr_name) > 0:
-                       kernel_attr_name = kernel_attr_name[0]
-
-                   if kernel_attr_name is None or kernel_attr_name not in attr:
+                   if n.kernel_attr is None or n.kernel_attr not in attr:
                        Logger.warning(f"The HMSE error method for parameters selection is only supported for "
                                       f"kernel weights attributes. Running parameters selection for attribute "
                                       f"'{attr}' in node '{n.name}' with the default MSE error method instead.")
@@ -132,7 +126,7 @@ def calculate_quantization_params(graph: Graph,
                                                                  node=n,
                                                                  hessian_info_service=hessian_info_service,
                                                                  num_hessian_samples=num_hessian_samples)
-                   attr_cfg.weights_channels_axis = (output_channels_axis, attr_cfg.weights_channels_axis[1])
+                   attr_cfg.weights_channels_axis = ChannelAxisMapping(output_channels_axis, attr_cfg.weights_channels_axis.input)
                    attr_cfg.set_weights_quantization_param(weights_params)
 
            if n.is_activation_quantization_enabled():
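
Note why the last hunk constructs a new mapping instead of assigning to one field: if ChannelAxisMapping is a NamedTuple, as its usage throughout this diff suggests, instances are immutable, so the old in-place tuple rebuild (output_channels_axis, axis[1]) becomes a new-instance construction. Reusing the stub class from the earlier sketch:

axes = ChannelAxisMapping(output=None, input=2)
axes = ChannelAxisMapping(3, axes.input)  # same pattern as the updated line above
assert axes == (3, 2)                     # still tuple-compatible, which eases the migration
# axes.output = 5                         # would raise AttributeError: NamedTuples are immutable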
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
@@ -20,7 +20,7 @@ from model_compression_toolkit.constants import WEIGHTS, ACTIVATION
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.core.common.framework_info import get_fw_info, ChannelAxisMapping
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
@@ -73,7 +73,6 @@ def set_quantization_configuration_to_graph(graph: Graph,
        set_quantization_configs_to_node(node=n,
                                         graph=graph,
                                         quant_config=quant_config,
-                                        fw_info=graph.fw_info,
                                         fqc=graph.fqc,
                                         mixed_precision_enable=mixed_precision_enable,
                                         manual_bit_width_override=manual_bit_width_override)
@@ -154,7 +153,6 @@ def filter_node_qco_by_graph(node: BaseNode,
 def set_quantization_configs_to_node(node: BaseNode,
                                      graph: Graph,
                                      quant_config: QuantizationConfig,
-                                     fw_info: FrameworkInfo,
                                      fqc: FrameworkQuantizationCapabilities,
                                      mixed_precision_enable: bool = False,
                                      manual_bit_width_override: Optional[Dict] = None):
@@ -165,7 +163,6 @@ def set_quantization_configs_to_node(node: BaseNode,
        node (BaseNode): Node to set its quantization configurations.
        graph (Graph): Model's internal representation graph.
        quant_config (QuantizationConfig): Quantization configuration to generate the node's configurations from.
-       fw_info (FrameworkInfo): Information needed for quantization about the specific framework.
        fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to get default OpQuantizationConfig.
        mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
        manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
@@ -186,10 +183,8 @@ def set_quantization_configs_to_node(node: BaseNode,
                                                 mixed_precision_enable=mixed_precision_enable)
 
    # Create QC candidates for weights and activation combined
-   weight_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
    node.candidates_quantization_cfg = _create_node_candidates_qc(quant_config,
-                                                                 fw_info,
-                                                                 weight_channel_axis,
+                                                                 node.channel_axis,
                                                                  node_qc_options_list,
                                                                  base_config,
                                                                  node,
@@ -198,7 +193,7 @@ def set_quantization_configs_to_node(node: BaseNode,
    # sorting the candidates by kernel attribute weights number of bits first and then by activation number of bits
    # (in reversed order). since only kernel attribute is quantized in weights mixed precision,
    # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
-   node.sort_node_candidates(fw_info)
+   node.sort_node_candidates()
 
    for candidate_qc in node.candidates_quantization_cfg:
        if candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.QUANT and \
@@ -217,14 +212,12 @@ def set_quantization_configs_to_node(node: BaseNode,
 
 
 def create_node_activation_qc(qc: QuantizationConfig,
-                              fw_info: FrameworkInfo,
                               op_cfg: OpQuantizationConfig) -> NodeActivationQuantizationConfig:
    """
    Create an activation quantization configuration from a QuantizationConfig object.
 
    Args:
        qc: QuantizationConfig to create the node's config from.
-       fw_info: Information about the specific framework the node was created from (e.g., whether or not its
        weights/activations should be quantized)
        op_cfg: OpQuantizationConfig with quantizers types to set in node quantization configuration.
 
@@ -232,7 +225,7 @@ def create_node_activation_qc(qc: QuantizationConfig,
        Activation quantization configuration of a node.
    """
 
-   activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
+   activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
    if activation_quantization_fn is None:
        Logger.critical('Unknown activation quantization method specified.')  # pragma: no cover
 
@@ -245,8 +238,7 @@ def create_node_activation_qc(qc: QuantizationConfig,
 
 
 def _create_node_single_candidate_qc(qc: QuantizationConfig,
-                                     fw_info: FrameworkInfo,
-                                     weight_channel_axis: Tuple[int, int],
+                                     weight_channel_axis: ChannelAxisMapping,
                                      op_cfg: OpQuantizationConfig,
                                      node_attrs_list: List[str]) -> CandidateNodeQuantizationConfig:
    """
@@ -256,8 +248,6 @@ def _create_node_single_candidate_qc(qc: QuantizationConfig,
 
    Args:
        qc: QuantizationConfig to create the node's config from.
-       fw_info: Information about the specific framework the node was created from (e.g., whether its
-           weights/activations should be quantized)
        weight_channel_axis: (Output, Input) channel index of the node's kernel.
        op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
        node_attrs_list: A list of the node's weights attributes names.
@@ -269,7 +259,7 @@ def _create_node_single_candidate_qc(qc: QuantizationConfig,
    # parameters for weights attributes quantization are set within CandidateNodeQuantizationConfig initialization
 
    # get parameters for activation quantization
-   activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
+   activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
    if activation_quantization_fn is None:
        Logger.critical('Unknown activation quantization method specified.')  # pragma: no cover
 
@@ -293,8 +283,7 @@ def _create_node_single_candidate_qc(qc: QuantizationConfig,
 
 
 def _create_node_candidates_qc(qc: QuantizationConfig,
-                               fw_info: FrameworkInfo,
-                               weight_channel_axis: Tuple[int, int],
+                               weight_channel_axis: ChannelAxisMapping,
                                node_qc_options_list: List[OpQuantizationConfig],
                                base_config: OpQuantizationConfig,
                                node: BaseNode,
@@ -304,8 +293,7 @@ def _create_node_candidates_qc(qc: QuantizationConfig,
 
    Args:
        qc (QuantizationConfig): Quantization configuration the quantization process should follow.
-       fw_info (FrameworkInfo): Framework information (e.g., which layers should have their kernels quantized).
-       weight_channel_axis (Tuple[int, int]): (Output, Input) channel index of the node's kernel.
+       weight_channel_axis (ChannelAxisMapping): (Output, Input) channel index of the node's kernel.
        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs of node.
        base_config (OpQuantizationConfig): Base quantization config for node.
        node (BaseNode): A node to set quantization configuration candidates to.
@@ -322,14 +310,12 @@ def _create_node_candidates_qc(qc: QuantizationConfig,
        for op_cfg in node_qc_options_list:
            candidate_qc = copy.deepcopy(qc)
            candidates.append(_create_node_single_candidate_qc(candidate_qc,
-                                                              fw_info,
                                                               weight_channel_axis,
                                                               op_cfg,
                                                               node_attrs_list))
 
    else:
        candidates.append(_create_node_single_candidate_qc(qc,
-                                                          fw_info,
                                                           weight_channel_axis,
                                                           base_config,
                                                           node_attrs_list))
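
The import change at the top of this file (FrameworkInfo replaced by get_fw_info, ChannelAxisMapping) together with the two get_fw_info().activation_quantizer_mapping call sites suggests FrameworkInfo is now resolved through a module-level accessor instead of being threaded through every signature. A hedged sketch of that accessor pattern follows; the function name mirrors the diff, but the body is an assumption, not MCT's actual code.

_fw_info = None  # process-wide FrameworkInfo, presumably set once per framework (Keras/PyTorch)

def set_fw_info(fw_info):
    # Hypothetical setter: the diff only shows the getter being imported.
    global _fw_info
    _fw_info = fw_info

def get_fw_info():
    if _fw_info is None:
        raise RuntimeError('FrameworkInfo has not been initialized for this run.')
    return _fw_info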
model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py
@@ -38,8 +38,7 @@ def apply_activation_bias_correction_to_graph(graph: Graph,
 
    for n in graph.nodes:
        # Activation bias correction is only relevant for nodes with kernel op
-       kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
-       if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \
+       if core_config.quantization_config.activation_bias_correction and n.kernel_attr is not None and \
                n.final_activation_quantization_cfg.activation_bias_correction_term is not None:
            # If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was
            # calculated during model preparation, and is used now in the node's bias term.

model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py
@@ -41,9 +41,8 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
    graph = copy.deepcopy(graph_to_apply_bias_correction)
    for n in graph.nodes:
        # bias correction is only relevant for nodes with kernel op
-       kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
-       if core_config.quantization_config.weights_bias_correction and kernel_attr is not None and \
-               n.is_weights_quantization_enabled(kernel_attr) and \
+       if core_config.quantization_config.weights_bias_correction and n.kernel_attr is not None and \
+               n.is_weights_quantization_enabled(n.kernel_attr) and \
                not n.final_weights_quantization_cfg.weights_second_moment_correction:
            # If a kernel was quantized and weights bias correction is enabled in n.quantization_cfg,
            # a bias correction term was calculated during model preparation, and is used now in the node's bias term.
  # a bias correction term was calculated during model preparation, and is used now in the node's bias term.