mct-nightly 2.4.0.20250617.613__py3-none-any.whl → 2.4.0.20250619.621__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/RECORD +123 -123
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +2 -5
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
  6. model_compression_toolkit/core/common/framework_implementation.py +10 -22
  7. model_compression_toolkit/core/common/framework_info.py +105 -68
  8. model_compression_toolkit/core/common/graph/base_graph.py +15 -42
  9. model_compression_toolkit/core/common/graph/base_node.py +103 -42
  10. model_compression_toolkit/core/common/graph/functional_node.py +18 -1
  11. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
  12. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
  16. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
  17. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
  18. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
  19. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
  20. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
  21. model_compression_toolkit/core/common/model_collector.py +10 -20
  22. model_compression_toolkit/core/common/model_validation.py +1 -4
  23. model_compression_toolkit/core/common/network_editors/actions.py +14 -38
  24. model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
  25. model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
  26. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
  27. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
  28. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
  29. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
  30. model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
  31. model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
  32. model_compression_toolkit/core/common/pruning/pruner.py +1 -6
  33. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
  34. model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
  35. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
  36. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
  37. model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
  38. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
  39. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
  40. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
  41. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
  42. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
  43. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
  44. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
  45. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
  46. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
  47. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
  48. model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
  49. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
  50. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
  51. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  52. model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
  53. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
  54. model_compression_toolkit/core/graph_prep_runner.py +2 -16
  55. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
  56. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
  57. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
  58. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
  59. model_compression_toolkit/core/keras/default_framework_info.py +138 -87
  60. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
  61. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
  62. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
  63. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
  64. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  65. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
  66. model_compression_toolkit/core/keras/keras_implementation.py +15 -35
  67. model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
  68. model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
  69. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
  70. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  71. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
  72. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
  73. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
  74. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
  75. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
  76. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
  77. model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
  78. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
  79. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
  80. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  81. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
  82. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
  83. model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
  84. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
  85. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  86. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
  87. model_compression_toolkit/core/quantization_prep_runner.py +4 -9
  88. model_compression_toolkit/core/runner.py +5 -15
  89. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  90. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  91. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -0
  92. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +19 -17
  93. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -0
  94. model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
  95. model_compression_toolkit/gptq/common/gptq_training.py +1 -8
  96. model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
  97. model_compression_toolkit/gptq/keras/graph_info.py +4 -6
  98. model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  100. model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
  101. model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
  102. model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
  103. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  104. model_compression_toolkit/gptq/runner.py +1 -7
  105. model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
  106. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
  107. model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
  108. model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
  109. model_compression_toolkit/ptq/runner.py +1 -4
  110. model_compression_toolkit/qat/common/qat_config.py +2 -6
  111. model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
  112. model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
  113. model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
  114. model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
  115. model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
  116. model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
  117. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
  118. model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
  119. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
  120. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
  121. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/WHEEL +0 -0
  122. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/licenses/LICENSE.md +0 -0
  123. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/top_level.txt +0 -0
model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py
@@ -19,7 +19,6 @@ from model_compression_toolkit.core.common.pruning.pruning_framework_implementat
     PruningFrameworkImplementation
 from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
 from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.pytorch.constants import BIAS, GROUPS, OUT_CHANNELS, OUT_FEATURES, NUM_FEATURES, \
     IN_CHANNELS, IN_FEATURES, NUM_PARAMETERS
@@ -39,27 +38,23 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
 
     def prune_entry_node(self,
                          node: BaseNode,
-                         output_mask: np.ndarray,
-                         fw_info: FrameworkInfo):
+                         output_mask: np.ndarray):
         """
         Prunes the entry node of a model in Pytorch.
 
         Args:
             node (BaseNode): The entry node to be pruned.
             output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
-            fw_info (FrameworkInfo): Framework-specific information object.
 
         """
         return _prune_pytorch_edge_node(node=node,
                                         mask=output_mask,
-                                        fw_info=fw_info,
                                         is_exit_node=False)
 
     def prune_intermediate_node(self,
                                 node: BaseNode,
                                 input_mask: np.ndarray,
-                                output_mask: np.ndarray,
-                                fw_info: FrameworkInfo):
+                                output_mask: np.ndarray):
         """
         Prunes an intermediate node in a Pytorch model.
 
@@ -67,12 +62,11 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
             node (BaseNode): The intermediate node to be pruned.
             input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
             output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
-            fw_info (FrameworkInfo): Framework-specific information object.
 
         """
         # TODO (reuvenp/liord): Address handling of node parameters that can be either a single value across all channels or distinct per channel, e.g., PReLU. Consider developing a structured approach.
         pruning_en = True
-        _edit_node_input_shape(node, input_mask, fw_info)
+        _edit_node_input_shape(node, input_mask)
         pruned_parameters = {}
         mask_bool = output_mask.astype(bool)
         node.weights = pruned_parameters
@@ -91,20 +85,17 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
 
     def prune_exit_node(self,
                         node: BaseNode,
-                        input_mask: np.ndarray,
-                        fw_info: FrameworkInfo):
+                        input_mask: np.ndarray):
         """
         Prunes the exit node of a model in Pytorch.
 
         Args:
             node (BaseNode): The exit node to be pruned.
             input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
-            fw_info (FrameworkInfo): Framework-specific information object.
 
         """
         return _prune_pytorch_edge_node(node=node,
                                         mask=input_mask,
-                                        fw_info=fw_info,
                                         is_exit_node=True)
 
     def is_node_entry_node(self, node: BaseNode) -> bool:
@@ -121,22 +112,19 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
 
     def is_node_exit_node(self,
                           node: BaseNode,
-                          corresponding_entry_node: BaseNode,
-                          fw_info: FrameworkInfo) -> bool:
+                          corresponding_entry_node: BaseNode) -> bool:
         """
         Determines whether a node is an exit node in a Pytorch model.
 
         Args:
             node (BaseNode): The node to be checked.
             corresponding_entry_node (BaseNode): The entry node of the pruning section that is checked.
-            fw_info (FrameworkInfo) Framework-specific information object.
 
         Returns:
             bool: Boolean indicating if the node is an exit node.
         """
         return _is_pytorch_node_pruning_section_edge(node) and PruningSection.has_matching_channel_count(node,
-                                                                                                         corresponding_entry_node,
-                                                                                                         fw_info)
+                                                                                                         corresponding_entry_node)
 
     def is_node_intermediate_pruning_section(self, node: BaseNode) -> bool:
         """
@@ -155,8 +143,7 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
                               torch.nn.Linear]
 
     def attrs_oi_channels_info_for_pruning(self,
-                                           node: BaseNode,
-                                           fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
+                                           node: BaseNode) -> Dict[str, Tuple[int, int]]:
         """
         Retrieves the attributes of a given node along with the output/input (OI) channel axis
         for each attribute used to prune these attributes.
@@ -173,7 +160,6 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
 
         Args:
             node (BaseNode): The node from the computational graph.
-            fw_info (FrameworkInfo): Contains framework-specific information and utilities.
 
         Returns:
             Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'weight' or 'bias')
@@ -181,13 +167,8 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
         """
 
         attributes_with_axis = {}
-        if fw_info.is_kernel_op(node.type):
-            kernel_attributes = fw_info.get_kernel_op_attributes(node.type)
-            if kernel_attributes is None or len(kernel_attributes) == 0:
-                Logger.critical(f"Expected to find kernel attributes but none were identified for node '{node.name}' of type {node.type}.")
-
-            for attr in kernel_attributes:
-                attributes_with_axis[attr] = fw_info.kernel_channels_mapping.get(node.type)
+        if node.is_kernel_op:
+            attributes_with_axis[node.kernel_attr] = (node.channel_axis.output, node.channel_axis.input)
 
         # Bias is a vector at the length of the number of output channels.
         # For this reason, input channel axis is irrelevant to the bias attribute.
@@ -202,7 +183,7 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
                 # If the number of float parameters is 1 or less - is the case where
                 # we have one parameter for all channels. For this case, we don't
                 # want to prune the parameter.
-                if node.get_num_parameters(fw_info)[1] <= 1:
+                if node.get_num_parameters()[1] <= 1:
                     attributes_with_axis[attr] = (None, None)
                 else:
                     attributes_with_axis[attr] = (-1, None)
@@ -234,7 +215,6 @@ def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool:
 
 def _prune_pytorch_edge_node(node: BaseNode,
                              mask: np.ndarray,
-                             fw_info: FrameworkInfo,
                              is_exit_node: bool):
     """
     Prunes the given Pytorch node by applying the mask to the node's weights (weights and biases).
@@ -243,21 +223,18 @@ def _prune_pytorch_edge_node(node: BaseNode,
     Args:
         node (BaseNode): The node to be pruned.
         mask (np.ndarray): The pruning mask to be applied.
-        fw_info (FrameworkInfo): Framework-specific information object.
        is_exit_node (bool): A boolean indicating whether the node is an exit node.
 
     """
 
     # Retrieve the kernel attribute and the axes to prune.
-    kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
-    io_axis = fw_info.kernel_channels_mapping.get(node.type)
-    axis_to_prune = io_axis[int(is_exit_node)]
-    kernel = node.get_weights_by_keys(kernel_attr)
+    axis_to_prune = node.channel_axis.input if is_exit_node else node.channel_axis.output
+    kernel = node.get_weights_by_keys(node.kernel_attr)
     # Convert mask to boolean.
     mask_bool = mask.astype(bool)
 
     pruned_kernel = kernel.compress(mask_bool, axis=axis_to_prune)
-    node.set_weights_by_keys(name=kernel_attr, tensor=pruned_kernel)
+    node.set_weights_by_keys(name=node.kernel_attr, tensor=pruned_kernel)
 
     if not is_exit_node and node.framework_attr[BIAS]:
         # Prune the bias if applicable and it's an entry node.
@@ -285,12 +262,11 @@ def _prune_pytorch_edge_node(node: BaseNode,
         Logger.critical(f"{node.type} is currently not supported"
                         f"as an edge node in a pruning section")
     # Adjust the input shape for the last node in the section.
-    _edit_node_input_shape(node, mask_bool, fw_info)
+    _edit_node_input_shape(node, mask_bool)
 
 
 def _edit_node_input_shape(node: BaseNode,
-                           input_mask: np.ndarray,
-                           fw_info: FrameworkInfo):
+                           input_mask: np.ndarray):
     """
     Adjusts the input shape of a node based on the given input mask.
 
@@ -301,14 +277,13 @@ def _edit_node_input_shape(node: BaseNode,
     Args:
         node (BaseNode): The node whose input shape needs to be adjusted.
         input_mask (np.ndarray): A binary array where 1 indicates the channel is kept and 0 means pruned.
-        fw_info (FrameworkInfo): Framework-specific information object.
     """
     # Start with the current input shape of the node.
    new_input_shape = list(node.input_shape)
 
     # Adjust the last dimension of the shape to match the number of unpruned (retained) channels.
     # This is done by summing the mask, as each '1' in the mask represents a retained channel.
-    channel_axis = fw_info.out_channel_axis_mapping.get(node.type)
+    channel_axis = node.out_channel_axis
     new_input_shape[0][channel_axis] = int(np.sum(input_mask))
 
     # Update the node's input shape with the new dimensions.
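
The recurring change above drops the FrameworkInfo parameter: the pruning helpers now read the kernel name and channel axes from the node itself (node.kernel_attr, node.channel_axis.output, node.channel_axis.input). A minimal, self-contained sketch of the masking arithmetic these helpers apply, using plain numpy and made-up shapes:

import numpy as np

# Toy kernel for a Linear-like layer, laid out (out_features, in_features).
kernel = np.arange(12.0).reshape(4, 3)

# Entry node: the output mask prunes along the output-channel axis (0 here),
# mirroring kernel.compress(mask_bool, axis=axis_to_prune) in the hunks above.
output_mask = np.array([1, 0, 1, 1], dtype=bool)
assert kernel.compress(output_mask, axis=0).shape == (3, 3)

# Exit node: the same compress runs along the input-channel axis instead.
input_mask = np.array([1, 1, 0], dtype=bool)
assert kernel.compress(input_mask, axis=1).shape == (4, 2)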
model_compression_toolkit/core/pytorch/pytorch_implementation.py
@@ -37,7 +37,6 @@ from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
 from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_kl_divergence, compute_cs
 from model_compression_toolkit.core.pytorch.back2framework import get_pytorch_model_builder
 from model_compression_toolkit.core.pytorch.data_util import data_gen_to_dataloader
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.batchnorm_folding import \
     pytorch_batchnorm_folding, pytorch_batchnorm_forward_folding
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.batchnorm_reconstruction import \
@@ -178,7 +177,6 @@ class PytorchImplementation(FrameworkImplementation):
                     graph: Graph,
                     mode: ModelBuilderMode,
                     append2output: List[Any] = None,
-                    fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                     return_float_outputs: bool = False) -> Tuple:
         """
         Build a Pytorch module from a graph.
@@ -189,7 +187,6 @@ class PytorchImplementation(FrameworkImplementation):
             graph: Graph to build the module from it.
             mode: Mode for how to build the module.
             append2output: List of Nodes to set as the module's outputs.
-            fw_info: FrameworkInfo object with information about the specific framework's module
             return_float_outputs (bool): whether to return outputs before or after quantization nodes (default)
 
         Returns:
@@ -198,7 +195,6 @@ class PytorchImplementation(FrameworkImplementation):
         pytorch_model_builder = get_pytorch_model_builder(mode)
         return pytorch_model_builder(graph=graph,
                                      append2output=append2output,
-                                     fw_info=fw_info,
                                      return_float_outputs=return_float_outputs).build_model()
 
     def run_model_inference(self,
@@ -232,63 +228,55 @@ class PytorchImplementation(FrameworkImplementation):
 
     def shift_negative_correction(self,
                                   graph: Graph,
-                                  core_config: CoreConfig,
-                                  fw_info: FrameworkInfo) -> Graph:
+                                  core_config: CoreConfig) -> Graph:
         """
         Apply shift negative correction (SNC) on a graph.
 
         Args:
             graph: Graph to apply SNC on.
             core_config: Quantization configuration.
-            fw_info: FrameworkInfo object with information about the specific framework's module.
 
         Returns:
             Graph after SNC.
         """
         return pytorch_apply_shift_negative_correction(graph,
-                                                       core_config,
-                                                       fw_info)
+                                                       core_config)
 
     def compute_activation_bias_correction(self,
                                            graph: Graph,
-                                           quant_config: QuantizationConfig,
-                                           fw_info: FrameworkInfo):
+                                           quant_config: QuantizationConfig):
         """
         Compute activation bias correction on a graph.
 
         Args:
             graph: Graph to apply activation bias correction on.
             quant_config: QuantizationConfig of how the model should be quantized.
-            fw_info: FrameworkInfo object with information about the specific framework's model.
 
         Returns:
             Graph after activation bias correction computing.
         """
         return pytorch_compute_activation_bias_correction_of_graph(graph=graph,
                                                                    quant_config=quant_config,
-                                                                   fw_info=fw_info,
                                                                    fw_impl=self)
 
     def get_substitutions_channel_equalization(self,
-                                               quant_config: QuantizationConfig,
-                                               fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
+                                               quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
         """
         Return a list of the framework substitutions used for channel equalization.
 
         Args:
             quant_config: QuantizationConfig to determine which substitutions to return.
-            fw_info: FrameworkInfo object with information about the specific framework's model.
 
         Returns:
             A list of the framework substitutions used after we collect statistics.
         """
         substitutions_list = []
         if quant_config.activation_channel_equalization:
-            substitutions_list.extend([ScaleEqualization(quant_config, fw_info),
-                                       ScaleEqualizationWithPad(quant_config, fw_info)])
+            substitutions_list.extend([ScaleEqualization(quant_config),
+                                       ScaleEqualizationWithPad(quant_config)])
         return substitutions_list
 
-    def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List[common.BaseSubstitution]:
+    def get_substitutions_prepare_graph(self) -> List[common.BaseSubstitution]:
         """
 
         Returns: A list of the framework substitutions used before we collect the prior information.
@@ -299,7 +287,7 @@ class PytorchImplementation(FrameworkImplementation):
                 ScaledDotProductDecomposition(),
                 MatMulDecomposition(),
                 TransformFunctionCallMethod(),
-                FunctionalConvSubstitution(fw_info),
+                FunctionalConvSubstitution(),
                 FunctionalBatchNorm(),
                 FunctionalLayerNorm(),
                 FunctionalLinear(),
@@ -401,20 +389,17 @@ class PytorchImplementation(FrameworkImplementation):
 
     def get_node_prior_info(self,
                             node: BaseNode,
-                            fw_info: FrameworkInfo,
                             graph: Graph) -> NodePriorInfo:
         """
         Get a NodePriorInfo object for a node that represents a Pytorch layer.
         Args:
             node: Node to get its prior info.
-            fw_info: Framework specific information needed to create the prior info of the node.
             graph: Graph to check the next node type.
         Returns:
             NodePriorInfo with information about the node.
         """
 
         return create_node_prior_info(node=node,
-                                      fw_info=fw_info,
                                       graph=graph)
 
     def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
@@ -476,23 +461,19 @@ class PytorchImplementation(FrameworkImplementation):
         return node.layer_class not in [argmax, softmax, Softmax]
 
     def get_node_mac_operations(self,
-                                node: BaseNode,
-                                fw_info: FrameworkInfo) -> float:
+                                node: BaseNode) -> float:
         """
         Gets the MAC operation count for a given operation.
 
         Args:
             node: A graph node that wraps the operation for which the MAC count is computed.
-            fw_info: FrameworkInfo object with information about the Pytorch model.
 
         Returns: The MAC count of the operation
         """
-        kernels = fw_info.get_kernel_op_attributes(node.type)
-        if not kernels or kernels[0] is None:
+        if node.kernel_attr is None:
             return 0
 
-        assert len(kernels) == 1
-        kernel_shape = node.get_weights_by_keys(kernels[0]).shape
+        kernel_shape = node.get_weights_by_keys(node.kernel_attr).shape
 
         if node.is_match_type(Conv2d) or node.is_match_type(ConvTranspose2d):
             h, w = node.get_output_shapes_list()[0][-2:]
@@ -500,8 +481,7 @@ class PytorchImplementation(FrameworkImplementation):
 
         if node.is_match_type(Linear):
             # IN * OUT * (all previous dims[:-1])
-            _, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
-            return node.get_total_output_params() * kernel_shape[input_channel_axis]
+            return node.get_total_output_params() * kernel_shape[node.channel_axis.input]
 
         return 0
 
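
In get_node_mac_operations, the Linear branch now multiplies the node's total output parameter count by the kernel's input-channel dimension, read from node.channel_axis.input instead of fw_info. A standalone numeric sketch of that rule (the Conv2d branch is cut off by the hunk above, so it is not reproduced; all shapes here are made up):

# Linear kernel laid out (out_features, in_features).
kernel_shape = (128, 64)

# Total number of elements in the layer's output, e.g. an output of shape (32, 128).
total_output_params = 32 * 128

# MACs = output elements * input channels, per the Linear branch above.
macs = total_output_params * kernel_shape[1]
assert macs == 262144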
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py
@@ -23,23 +23,19 @@ from model_compression_toolkit.core.pytorch.constants import MOVING_MEAN, MOVING
 
 
 def create_node_prior_info(node: BaseNode,
-                           fw_info: FrameworkInfo,
                            graph: Graph):
     """
     Create a NodePriorInfo object for a given node.
 
     Args:
         node: Node to create its prior info.
-        fw_info: Information about a specific framework the node was generated from.
         graph: Graph to check the next node type.
 
     Returns:
         NodePriorInfo object with info about the node.
     """
 
-    min_output, max_output = None, None
-    if fw_info.layers_has_min_max(node.type):
-        min_output, max_output = fw_info.layer_min_max_mapping[node.type]
+    min_output, max_output = node.minmax
     mean_output, std_output = _get_mean_std_outputs(node=node,
                                                     graph=graph)
     return NodePriorInfo(min_output=min_output,
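
Here a per-layer-type lookup (fw_info.layers_has_min_max plus fw_info.layer_min_max_mapping) becomes a property carried by the node. A tiny runnable sketch of the consuming side, with ToyNode as a hypothetical stand-in for BaseNode:

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class ToyNode:  # hypothetical stand-in for BaseNode
    # (min, max) output bounds; (None, None) when the layer type has no known bounds.
    minmax: Tuple[Optional[float], Optional[float]] = (None, None)

relu_like = ToyNode(minmax=(0.0, None))
min_output, max_output = relu_like.minmax  # replaces the two-step fw_info lookup
assert (min_output, max_output) == (0.0, None)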
model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py
@@ -27,7 +27,7 @@ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler impor
 from model_compression_toolkit.verify_packages import FOUND_TORCH
 
 if FOUND_TORCH:
-    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+    from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
     from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
     from torch.nn import Module
     from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
@@ -38,6 +38,7 @@ if FOUND_TORCH:
     PYTORCH_DEFAULT_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
 
 
+    @set_pytorch_info
     def pytorch_resource_utilization_data(in_model: Module,
                                           representative_data_gen: Callable,
                                           core_config: CoreConfig = CoreConfig(),
@@ -93,7 +94,6 @@ if FOUND_TORCH:
                                                representative_data_gen,
                                                core_config,
                                                target_platform_capabilities,
-                                               DEFAULT_PYTORCH_INFO,
                                                fw_impl)
 
 else:
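
Rather than passing DEFAULT_PYTORCH_INFO as an argument, the facade is now wrapped with @set_pytorch_info, which presumably installs the PyTorch framework info for the duration of the call. The decorator's body is not part of this diff; the following is only a plausible sketch of the pattern, and the module-level slot and payload below are assumptions:

import functools

_FW_INFO = None  # stand-in for MCT's ambient framework-info slot

def set_pytorch_info(func):
    """Sketch: install PyTorch framework info around a call, then restore."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global _FW_INFO
        previous, _FW_INFO = _FW_INFO, "pytorch-info"  # stand-in payload
        try:
            return func(*args, **kwargs)
        finally:
            _FW_INFO = previous
    return wrapper

@set_pytorch_info
def pytorch_resource_utilization_data():
    return _FW_INFO  # downstream code reads the ambient info

assert pytorch_resource_utilization_data() == "pytorch-info"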
model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py
@@ -33,7 +33,6 @@ def activation_bias_correction_node_matchers():
 
 def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
                                                         quant_config: QuantizationConfig,
-                                                        fw_info: FrameworkInfo,
                                                         fw_impl: FrameworkImplementation) -> Graph:
     """
     Compute the activation bias correction term for graph based on a PyTorch model.
@@ -41,7 +40,6 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
     Args:
         graph: Graph with nodes to compute the activation bias correction.
         quant_config: QuantizationConfig of how the model should be quantized.
-        fw_info: Framework info like lists of nodes their kernel should quantized.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -49,7 +47,6 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
     """
     graph = compute_activation_bias_correction_of_graph(graph=graph,
                                                         quant_config=quant_config,
-                                                        fw_info=fw_info,
                                                         fw_impl=fw_impl,
                                                         activation_bias_correction_node_matchers=
                                                         activation_bias_correction_node_matchers,
model_compression_toolkit/core/quantization_prep_runner.py
@@ -37,7 +37,6 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
 def quantization_preparation_runner(graph: Graph,
                                     representative_data_gen: Callable,
                                     core_config: CoreConfig,
-                                    fw_info: FrameworkInfo,
                                     fw_impl: FrameworkImplementation,
                                     tb_w: TensorboardWriter = None,
                                     hessian_info_service: HessianInfoService = None, ) -> Graph:
@@ -53,8 +52,6 @@
         graph: A graph representation of the model to be quantized.
         representative_data_gen: Dataset used for calibration.
         core_config: CoreConfig containing parameters of how the model should be quantized
-        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
-            groups of layers by how they should be quantized, etc.).
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         tb_w: TensorboardWriter object for logging
         hessian_info_service: HessianInfoService object for retrieving Hessian-based scores.
@@ -68,7 +65,6 @@
     ######################################
     mi = ModelCollector(graph,
                         fw_impl,
-                        fw_info,
                         hessian_info_service,
                         core_config.quantization_config)  # Mark points for statistics collection
 
@@ -85,7 +81,7 @@
     # Notice that not all actions affect at this stage (for example, actions that edit the final configuration as
     # there are no final configurations at this stage of the optimization). For this reason we edit the graph
     # again at the end of the optimization process.
-    edit_network_graph(graph, fw_info, core_config.debug_config.network_editor)
+    edit_network_graph(graph, core_config.debug_config.network_editor)
 
     ######################################
     # Calculate quantization params
@@ -109,8 +105,7 @@
     ######################################
     if core_config.quantization_config.shift_negative_activation_correction:
         transformed_graph = fw_impl.shift_negative_correction(transformed_graph,
-                                                              core_config,
-                                                              fw_info)
+                                                              core_config)
         if tb_w is not None:
             tb_w.add_graph(transformed_graph, 'after_shift_negative_correction')
             tb_w.add_all_statistics(transformed_graph, 'after_shift_negative_correction')
@@ -122,9 +117,9 @@
     ######################################
     # Statistics Correction
     ######################################
-    tg_with_bias = statistics_correction_runner(transformed_graph, core_config, fw_info, fw_impl, tb_w)
+    tg_with_bias = statistics_correction_runner(transformed_graph, core_config, fw_impl, tb_w)
 
     for n in tg_with_bias.nodes:
         assert n.final_weights_quantization_cfg is None
 
-    return tg_with_bias
\ No newline at end of file
+    return tg_with_bias
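
After this change, quantization_preparation_runner threads only graph, core_config and fw_impl through its stages. A schematic, runnable outline of the stage order visible in the hunks above; every function here is a stand-in and only the argument lists mirror the diff:

def collect_statistics(graph, fw_impl):            # ModelCollector + model inference
    return graph

def edit_network_graph(graph, network_editor):     # note: no fw_info parameter now
    return graph

def statistics_correction_runner(graph, core_config, fw_impl, tb_w):
    return graph

class FwImpl:                                      # hypothetical framework implementation
    def shift_negative_correction(self, graph, core_config):
        return graph

def quantization_preparation_runner(graph, core_config, fw_impl, tb_w=None,
                                    snc_enabled=True):
    graph = collect_statistics(graph, fw_impl)
    graph = edit_network_graph(graph, "network-editor")
    if snc_enabled:
        graph = fw_impl.shift_negative_correction(graph, core_config)
    return statistics_correction_runner(graph, core_config, fw_impl, tb_w)

assert quantization_preparation_runner("graph", "core-config", FwImpl()) == "graph"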
model_compression_toolkit/core/runner.py
@@ -16,7 +16,6 @@
 import copy
 from typing import Callable, Any, List, Optional
 
-from model_compression_toolkit.core.common import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser
 from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -46,7 +45,6 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
 def core_runner(in_model: Any,
                 representative_data_gen: Callable,
                 core_config: CoreConfig,
-                fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation,
                 fqc: FrameworkQuantizationCapabilities,
                 target_resource_utilization: ResourceUtilization = None,
@@ -65,7 +63,6 @@ def core_runner(in_model: Any,
         in_model: Model to quantize.
         representative_data_gen: Dataset used for calibration.
         core_config: CoreConfig containing parameters of how the model should be quantized
-        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
             groups of layers by how they should be quantized, etc.).
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
@@ -99,7 +96,6 @@ def core_runner(in_model: Any,
     graph = graph_preparation_runner(in_model,
                                      representative_data_gen,
                                      core_config.quantization_config,
-                                     fw_info,
                                      fw_impl,
                                      fqc,
                                      core_config.bit_width_config,
@@ -112,7 +108,6 @@ def core_runner(in_model: Any,
     tg = quantization_preparation_runner(graph=graph,
                                          representative_data_gen=representative_data_gen,
                                          core_config=core_config,
-                                         fw_info=fw_info,
                                          fw_impl=fw_impl,
                                          tb_w=tb_w,
                                          hessian_info_service=hessian_info_service)
@@ -123,9 +118,8 @@ def core_runner(in_model: Any,
     if core_config.is_mixed_precision_enabled:
         if core_config.mixed_precision_config.configuration_overwrite is None:
 
-            filter_candidates_for_mixed_precision(graph, target_resource_utilization, fw_info, fqc)
+            filter_candidates_for_mixed_precision(graph, target_resource_utilization, fqc)
             bit_widths_config = search_bit_width(tg,
-                                                 fw_info,
                                                  fw_impl,
                                                  target_resource_utilization,
                                                  core_config.mixed_precision_config,
@@ -153,22 +147,20 @@ def core_runner(in_model: Any,
     ######################################
     if core_config.quantization_config.activation_bias_correction:
         tg = fw_impl.compute_activation_bias_correction(graph=tg,
-                                                        quant_config=core_config.quantization_config,
-                                                        fw_info=fw_info)
+                                                        quant_config=core_config.quantization_config)
 
     # Edit the graph again after finalizing the configurations.
     # This is since some actions regard the final configuration and should be edited.
-    edit_network_graph(tg, fw_info, core_config.debug_config.network_editor)
+    edit_network_graph(tg, core_config.debug_config.network_editor)
 
     _set_final_resource_utilization(graph=tg,
                                     final_bit_widths_config=bit_widths_config,
                                     target_resource_utilization=target_resource_utilization,
-                                    fw_info=fw_info,
                                     fw_impl=fw_impl)
 
     if core_config.is_mixed_precision_enabled:
         # Retrieve lists of tuples (node, node's final weights/activation bitwidth)
-        weights_conf_nodes_bitwidth = tg.get_final_weights_config(fw_info)
+        weights_conf_nodes_bitwidth = tg.get_final_weights_config()
         activation_conf_nodes_bitwidth = tg.get_final_activation_config()
 
         if len(weights_conf_nodes_bitwidth) > 0:
@@ -200,7 +192,6 @@ def core_runner(in_model: Any,
 def _set_final_resource_utilization(graph: Graph,
                                     final_bit_widths_config: List[int],
                                     target_resource_utilization: Optional[ResourceUtilization],
-                                    fw_info: FrameworkInfo,
                                     fw_impl: FrameworkImplementation):
     """
     Computing the resource utilization of the model according to the final bit-width configuration,
@@ -210,14 +201,13 @@ def _set_final_resource_utilization(graph: Graph,
         graph: Graph to compute the resource utilization for.
         final_bit_widths_config: The final bit-width configuration to quantize the model accordingly.
         target_resource_utilization: Requested target resource utilization if relevant.
-        fw_info: A FrameworkInfo object.
         fw_impl: FrameworkImplementation object with specific framework methods implementation.
 
     """
     ru_targets = target_resource_utilization.get_restricted_targets() if target_resource_utilization else None
     final_ru = None
     if ru_targets:
-        ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
+        ru_calculator = ResourceUtilizationCalculator(graph, fw_impl)
         w_qcs = {n.name: n.final_weights_quantization_cfg for n in graph.nodes}
         a_qcs = {n.name: n.final_activation_quantization_cfg for n in graph.nodes}
         final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused,
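
ResourceUtilizationCalculator is now built from the graph and fw_impl alone, and the per-node final quantization configs are gathered into plain name-keyed dicts. A toy reproduction of those dict comprehensions, with ToyNode as a hypothetical stand-in for a graph node:

from dataclasses import dataclass

@dataclass
class ToyNode:  # hypothetical stand-in for a graph node
    name: str
    final_weights_quantization_cfg: object = None
    final_activation_quantization_cfg: object = None

nodes = [ToyNode("conv1", "w8-cfg"), ToyNode("fc", "w4-cfg", "a8-cfg")]

# Mirrors the comprehensions in the hunk above.
w_qcs = {n.name: n.final_weights_quantization_cfg for n in nodes}
a_qcs = {n.name: n.final_activation_quantization_cfg for n in nodes}
assert w_qcs == {"conv1": "w8-cfg", "fc": "w4-cfg"}
assert a_qcs == {"conv1": None, "fc": "a8-cfg"}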
model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py
@@ -56,7 +56,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         super(ReduceLROnPlateau, self).__init__()
 
         if factor >= 1.0:
-            Logger.critical('Factor should be < 1.0.') # pragma: no cover
+            Logger.critical('Factor should be < 1.0.')  # pragma: no cover
         self.factor = factor
 
         self.optimizer = optimizer
@@ -101,7 +101,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         else:
             self.num_bad_epochs += 1
 
-        if self.in_cooldown:
+        if self.in_cooldown:  # pragma: no cover
             self.cooldown_counter -= 1
             self.num_bad_epochs = 0  # Ignore any bad epochs in cooldown
 
@@ -122,7 +122,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         new_lr = max(old_lr * self.factor, self.min_lr)
         if old_lr - new_lr > self.eps:
             tf.keras.backend.set_value(self.optimizer.learning_rate, new_lr)
-            if self.verbose:
+            if self.verbose:  # pragma: no cover
                 print(f'Epoch {epoch:05d}: reducing learning rate to {new_lr:.4e}.')
 
     @property
@@ -152,13 +152,13 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         if self.mode == 'min' and self.threshold_mode == 'rel':
             rel_epsilon = 1. - self.threshold
             return a < best * rel_epsilon
-        elif self.mode == 'min' and self.threshold_mode == 'abs':
+        elif self.mode == 'min' and self.threshold_mode == 'abs':  # pragma: no cover
             return a < best - self.threshold
-        elif self.mode == 'max' and self.threshold_mode == 'rel':
+        elif self.mode == 'max' and self.threshold_mode == 'rel':  # pragma: no cover
             rel_epsilon = self.threshold + 1.
             return a > best * rel_epsilon
         else:  # mode == 'max' and threshold_mode == 'abs':
-            return a > best + self.threshold
+            return a > best + self.threshold  # pragma: no cover
 
     def _init_is_better(self, mode: str, threshold: float, threshold_mode: str) -> None:
         """
@@ -186,7 +186,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         self.threshold = threshold
         self.threshold_mode = threshold_mode
 
-    def get_config(self) -> Dict:
+    def get_config(self) -> Dict:  # pragma: no cover
         """
         Return the configuration of the scheduler as a dictionary.
 
@@ -207,7 +207,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         base_config = super(ReduceLROnPlateau, self).get_config()
         return {**base_config, **config}
 
-    def set_config(self, config: Dict) -> None:
+    def set_config(self, config: Dict) -> None:  # pragma: no cover
         """
         Set the configuration of the scheduler from a dictionary.
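
For context on the branches being excluded from coverage: the path that stays exercised is mode='min' with threshold_mode='rel'. A standalone sketch of that plateau test and of the learning-rate update from the hunks above (threshold, factor and the loss values are made up):

threshold, factor, min_lr, old_lr = 1e-4, 0.5, 1e-6, 1e-3

def is_better(a, best):
    # mode='min', threshold_mode='rel': demand a relative improvement.
    return a < best * (1.0 - threshold)

best, current = 0.50000, 0.49996
assert not is_better(current, best)  # improvement is within the threshold band

# On a plateau, the LR is multiplied by `factor` but floored at `min_lr`.
new_lr = max(old_lr * factor, min_lr)
assert new_lr == 5e-4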