mct-nightly 2.4.0.20250617.613__py3-none-any.whl → 2.4.0.20250619.621__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123):
  1. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/RECORD +123 -123
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +2 -5
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
  6. model_compression_toolkit/core/common/framework_implementation.py +10 -22
  7. model_compression_toolkit/core/common/framework_info.py +105 -68
  8. model_compression_toolkit/core/common/graph/base_graph.py +15 -42
  9. model_compression_toolkit/core/common/graph/base_node.py +103 -42
  10. model_compression_toolkit/core/common/graph/functional_node.py +18 -1
  11. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
  12. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
  16. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
  17. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
  18. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
  19. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
  20. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
  21. model_compression_toolkit/core/common/model_collector.py +10 -20
  22. model_compression_toolkit/core/common/model_validation.py +1 -4
  23. model_compression_toolkit/core/common/network_editors/actions.py +14 -38
  24. model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
  25. model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
  26. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
  27. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
  28. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
  29. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
  30. model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
  31. model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
  32. model_compression_toolkit/core/common/pruning/pruner.py +1 -6
  33. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
  34. model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
  35. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
  36. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
  37. model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
  38. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
  39. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
  40. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
  41. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
  42. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
  43. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
  44. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
  45. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
  46. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
  47. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
  48. model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
  49. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
  50. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
  51. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  52. model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
  53. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
  54. model_compression_toolkit/core/graph_prep_runner.py +2 -16
  55. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
  56. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
  57. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
  58. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
  59. model_compression_toolkit/core/keras/default_framework_info.py +138 -87
  60. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
  61. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
  62. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
  63. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
  64. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  65. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
  66. model_compression_toolkit/core/keras/keras_implementation.py +15 -35
  67. model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
  68. model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
  69. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
  70. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  71. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
  72. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
  73. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
  74. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
  75. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
  76. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
  77. model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
  78. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
  79. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
  80. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  81. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
  82. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
  83. model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
  84. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
  85. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  86. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
  87. model_compression_toolkit/core/quantization_prep_runner.py +4 -9
  88. model_compression_toolkit/core/runner.py +5 -15
  89. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  90. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  91. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -0
  92. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +19 -17
  93. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -0
  94. model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
  95. model_compression_toolkit/gptq/common/gptq_training.py +1 -8
  96. model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
  97. model_compression_toolkit/gptq/keras/graph_info.py +4 -6
  98. model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  100. model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
  101. model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
  102. model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
  103. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  104. model_compression_toolkit/gptq/runner.py +1 -7
  105. model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
  106. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
  107. model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
  108. model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
  109. model_compression_toolkit/ptq/runner.py +1 -4
  110. model_compression_toolkit/qat/common/qat_config.py +2 -6
  111. model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
  112. model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
  113. model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
  114. model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
  115. model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
  116. model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
  117. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
  118. model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
  119. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
  120. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
  121. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/WHEEL +0 -0
  122. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/licenses/LICENSE.md +0 -0
  123. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/top_level.txt +0 -0
model_compression_toolkit/core/analyzer.py

@@ -32,8 +32,7 @@ def analyzer_model_quantization(representative_data_gen: Callable,
                                 tb_w: TensorboardWriter,
                                 float_graph: Graph,
                                 quantized_graph: Graph,
-                                fw_impl: FrameworkImplementation,
-                                fw_info: FrameworkInfo):
+                                fw_impl: FrameworkImplementation):
     """
     Plot the cosine similarity of different points on the graph between the float and quantized
     graphs. Add them to the passed TensorboardWriter object and close all tensorboard writer open
@@ -45,14 +44,12 @@ def analyzer_model_quantization(representative_data_gen: Callable,
         float_graph: Graph of float model.
         quantized_graph: Graph of quantized model.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
-        fw_info: Information needed for quantization about the specific framework.
 
     """
     if tb_w is not None:
         visual = NNVisualizer(float_graph,
                               quantized_graph,
-                              fw_impl=fw_impl,
-                              fw_info=fw_info)
+                              fw_impl=fw_impl)
         if not visual.has_compare_points():
             Logger.error(f'No comparing points were found to plot analyze similarity.')
         else:
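In practice, callers of analyzer_model_quantization simply drop the final argument. A minimal caller-side sketch (not part of the package; all arguments are assumed to be produced by MCT's core runner):

from model_compression_toolkit.core.analyzer import analyzer_model_quantization

def run_similarity_analysis(representative_data_gen, tb_w, float_graph, quantized_graph, fw_impl):
    # The fw_info argument was removed in this nightly; framework data is resolved internally.
    analyzer_model_quantization(representative_data_gen=representative_data_gen,
                                tb_w=tb_w,
                                float_graph=float_graph,
                                quantized_graph=quantized_graph,
                                fw_impl=fw_impl)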
model_compression_toolkit/core/common/back2framework/base_model_builder.py

@@ -28,20 +28,17 @@ class BaseModelBuilder(ABC):
     def __init__(self,
                  graph: common.Graph,
                  append2output=None,
-                 fw_info: FrameworkInfo = None,
                  return_float_outputs: bool = False):
         """
 
         Args:
             graph: Graph to build the model from.
             append2output: Nodes of graph to append to model's output.
-            fw_info: Information about the specific framework of the model that is built.
             return_float_outputs: Whether the model returns float tensors or not.
         """
 
         self.graph = graph
         self.append2output = append2output
-        self.fw_info = fw_info
         self.return_float_outputs = return_float_outputs
 
     @abstractmethod
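Subclasses of BaseModelBuilder no longer accept or forward fw_info. A hedged sketch of a conforming constructor (the subclass name is hypothetical and the remaining abstract methods are omitted):

from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder

class SketchModelBuilder(BaseModelBuilder):  # hypothetical subclass for illustration
    def __init__(self, graph, append2output=None, return_float_outputs=False):
        # fw_info is gone from the base constructor and is no longer stored on self.
        super().__init__(graph,
                         append2output=append2output,
                         return_float_outputs=return_float_outputs)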
model_compression_toolkit/core/common/framework_implementation.py

@@ -125,18 +125,16 @@ class FrameworkImplementation(ABC):
                           graph: Graph,
                           mode: ModelBuilderMode,
                           append2output: List[Any],
-                          fw_info: FrameworkInfo,
                           return_float_outputs: bool = False) -> Tuple:
         """
         Build a framework model from a graph.
-        The mode determines how the model should be build. append2output is a list of Nodes
+        The mode determines how the model should be built. append2output is a list of Nodes
         to set as the model outputs.
 
         Args:
             graph: Graph to build the model from it.
             mode: Mode for how to build the model.
             append2output: List of Nodes to set as the model's outputs.
-            fw_info: FrameworkInfo object with information about the specific framework's model
             return_float_outputs (bool): whether to return outputs before or after quantization nodes (default)
 
         Returns:
@@ -170,15 +168,13 @@ class FrameworkImplementation(ABC):
     @abstractmethod
     def shift_negative_correction(self,
                                   graph: Graph,
-                                  core_config: CoreConfig,
-                                  fw_info: FrameworkInfo) -> Graph:
+                                  core_config: CoreConfig) -> Graph:
         """
         Apply shift negative correction (SNC) on a graph.
 
         Args:
             graph: Graph to apply SNC on.
             core_config: Quantization configuration.
-            fw_info: FrameworkInfo object with information about the specific framework's model.
 
         Returns:
             Graph after SNC.
@@ -189,15 +185,13 @@ class FrameworkImplementation(ABC):
     @abstractmethod
     def compute_activation_bias_correction(self,
                                            graph: Graph,
-                                           quant_config: QuantizationConfig,
-                                           fw_info: FrameworkInfo) -> Graph:
+                                           quant_config: QuantizationConfig) -> Graph:
         """
         Compute activation bias correction on a graph.
 
         Args:
             graph: Graph to apply activation bias correction on.
             quant_config: QuantizationConfig of how the model should be quantized.
-            fw_info: FrameworkInfo object with information about the specific framework's model.
 
         Returns:
             Graph after activation bias correction computing.
@@ -207,30 +201,28 @@ class FrameworkImplementation(ABC):
 
     @abstractmethod
     def get_substitutions_channel_equalization(self,
-                                               quant_config: QuantizationConfig,
-                                               fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
+                                               quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
         """
         Return a list of the framework substitutions used for channel equalization.
 
         Args:
             quant_config: QuantizationConfig to determine which substitutions to return.
-            fw_info: FrameworkInfo object with information about the specific framework's model.
 
         Returns:
             A list of the framework substitutions used after we collect statistics.
         """
         raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
-                                  f'framework\'s get_substitutions_channel_equalization method.')  # pragma: no cover
+                                  f'framework\'s get_substitutions_channel_equalization method.')  # pragma: no cover
 
     @abstractmethod
-    def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List[common.BaseSubstitution]:
+    def get_substitutions_prepare_graph(self) -> List[common.BaseSubstitution]:
         """
 
         Returns: A list of the framework substitutions used to prepare the graph.
 
         """
         raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
-                                  f'framework\'s get_substitutions_prepare_graph method.')  # pragma: no cover
+                                  f'framework\'s get_substitutions_prepare_graph method.')  # pragma: no cover
 
     @abstractmethod
     def get_substitutions_pre_statistics_collection(self, quant_config: QuantizationConfig) -> \
@@ -328,14 +320,12 @@ class FrameworkImplementation(ABC):
                                   f'method.')  # pragma: no cover
 
     def get_node_prior_info(self, node: BaseNode,
-                            fw_info: FrameworkInfo,
                             graph: Graph) -> NodePriorInfo:
         """
         Get a NodePriorInfo object for a node.
 
         Args:
             node: Node to get its prior info.
-            fw_info: Framework specific information needed to create the prior info of the node.
             graph: Graph to check the next node type.
 
         Returns:
@@ -343,7 +333,7 @@ class FrameworkImplementation(ABC):
         """
 
         raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
-                                  f'framework\'s get_node_prior_info method.')  # pragma: no cover
+                                  f'framework\'s get_node_prior_info method.')  # pragma: no cover
 
     def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
         """
@@ -394,20 +384,18 @@ class FrameworkImplementation(ABC):
 
     @abstractmethod
     def get_node_mac_operations(self,
-                                node: BaseNode,
-                                fw_info: FrameworkInfo) -> float:
+                                node: BaseNode) -> float:
         """
         Gets the MAC operation count for a given operation.
 
         Args:
             node: A graph node that wraps the operation for which the MAC count is computed.
-            fw_info: FrameworkInfo object with information about the specific framework's model.
 
         Returns: The MAC count of the operation
         """
 
         raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
-                                  f'framework\'s get_node_mac_operations method.')  # pragma: no cover
+                                  f'framework\'s get_node_mac_operations method.')  # pragma: no cover
 
     @abstractmethod
     def apply_second_moment_correction(self,
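Concrete implementations follow the same pattern: instead of receiving fw_info, they can consult the registered FrameworkInfo class (see the framework_info.py hunks below). A sketch under that assumption; get_weights_by_keys is assumed to be available on BaseNode, and the MAC computation is a placeholder, not the real Keras/PyTorch logic:

import numpy as np

from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.framework_info import get_fw_info

class SketchImplementation(FrameworkImplementation):  # remaining abstract methods omitted
    def get_node_mac_operations(self, node) -> float:
        fw_info = get_fw_info()  # fetched from the registry instead of a parameter
        kernel_attr = fw_info.get_kernel_op_attribute(node.type)
        if kernel_attr is None:  # op without a kernel: nothing to count here
            return 0.0
        kernel = node.get_weights_by_keys(kernel_attr)
        # Placeholder: real implementations derive MACs from kernel shape and output size.
        return float(np.prod(kernel.shape)) if kernel is not None else 0.0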
model_compression_toolkit/core/common/framework_info.py

@@ -16,16 +16,16 @@
 
 from collections.abc import Callable
 from enum import Enum
-from typing import Dict, Any, List
+from typing import Dict, Any, Tuple, NamedTuple
+from abc import ABC, abstractmethod
 
 from mct_quantizers import QuantizationMethod
-from model_compression_toolkit.defaultdict import DefaultDict
 
 
 # Default value to use for ops without kernel.
 # This is a weird default, but it's used all over the place, so for now only extract it to const so that it can be
 # referenced by variable instead of hard-coded.
-DEFAULT_KERNEL_ATTRIBUTES = [None]
+DEFAULT_KERNEL_ATTRIBUTE = None
 
 
 class ChannelAxis(Enum):
@@ -42,89 +42,83 @@ class ChannelAxis(Enum):
     NCHW = 1
 
 
-class FrameworkInfo:
+class ChannelAxisMapping(NamedTuple):
+    output: int
+    input: int
 
-    def __init__(self,
-                 activation_quantizer_mapping: Dict[QuantizationMethod, Callable],
-                 kernel_channels_mapping: DefaultDict,
-                 activation_min_max_mapping: Dict[str, tuple],
-                 layer_min_max_mapping: Dict[Any, tuple],
-                 kernel_ops_attributes_mapping: DefaultDict,
-                 out_channel_axis_mapping: DefaultDict):
-        """
-        A class to wrap all information about a specific framework the library needs to quantize a model.
-        Specifically, FrameworkInfo holds lists of layers by how they should be quantized, and multiple mappings such as
-        layer to it kernel channels indices, and a layer to its min/max values, etc.
-        The layers lists are divided into three groups:
-        kernel_ops: Layers that have coefficients and need to get quantized (e.g., Conv2D, Dense, etc.)
-        activation_ops: Layers that their outputs should get quantized (e.g., Add, ReLU, etc.)
-        no_quantization_ops:Layers that should not get quantized (e.g., Reshape, Transpose, etc.)
-
-        Args:
-            activation_quantizer_mapping (Dict[QuantizationMethod, Callable]): A dictionary mapping from QuantizationMethod to a quantization function.
-            kernel_channels_mapping (DefaultDict): Dictionary from a layer to a tuple of its kernel in/out channels indices.
-            activation_min_max_mapping (Dict[str, tuple]): Dictionary from an activation function to its min/max output values.
-            layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
-            kernel_ops_attributes_mapping (DefaultDict): Dictionary from a framework operator to a list of its weights attirbutes to quantize.
-            out_channel_axis_mapping (DefaultDict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
-
-        Examples:
-            When quantizing a Keras model, if we want to quantize the kernels of Conv2D layers only, we can
-            set, and we know it's kernel out/in channel indices are (3, 2) respectivly:
-
-            >>> import tensorflow as tf
-            >>> kernel_ops = [tf.keras.layers.Conv2D]
-            >>> kernel_channels_mapping = DefaultDict({tf.keras.layers.Conv2D: (3,2)})
 
-            Then, we can create a FrameworkInfo object:
+class FrameworkInfo(ABC):
+    """
+    A class to wrap all information about a specific framework the library needs to quantize a model.
+    Specifically, FrameworkInfo holds lists of layers by how they should be quantized, and multiple mappings such as
+    layer to it kernel channels indices, and a layer to its min/max values, etc.
+    The layers lists are divided into three groups:
+    kernel_ops: Layers that have coefficients and need to get quantized (e.g., Conv2D, Dense, etc.)
+    activation_ops: Layers that their outputs should get quantized (e.g., Add, ReLU, etc.)
+    no_quantization_ops:Layers that should not get quantized (e.g., Reshape, Transpose, etc.)
+
+    Fields:
+        activation_quantizer_mapping (Dict[QuantizationMethod, Callable]): A dictionary mapping from QuantizationMethod to a quantization function.
+        kernel_channels_mapping (Dict): Dictionary from a layer to a tuple of its kernel in/out channels indices.
+        kernel_ops_attribute_mapping (Dict): Dictionary from a framework operator to its weight attribute to quantize.
+        out_channel_axis_mapping (Dict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
+        _layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
 
-            >>> FrameworkInfo(kernel_channels_mapping, {}, {})
+    """
 
-            If an activation layer (tf.keras.layers.Activation) should be quantized and we know it's min/max outputs range in advanced, we can add it to activation_min_max_mapping for saving the statistics collection time. For example:
+    activation_quantizer_mapping: Dict[QuantizationMethod, Callable]
+    kernel_channels_mapping: Dict[Any, ChannelAxisMapping]
+    kernel_ops_attribute_mapping: Dict[Any, str]
+    out_channel_axis_mapping: Dict[Any, int]
+    _layer_min_max_mapping: Dict[Any, tuple]
 
-            >>> activation_min_max_mapping = {'softmax': (0, 1)}
-            >>> FrameworkInfo(kernel_channels_mapping, activation_min_max_mapping, {})
+    _default_channel_mapping = ChannelAxisMapping(None, None)
 
-            If a layer's activations should be quantized and we know it's min/max outputs range in advanced, we can add it to layer_min_max_mapping for saving the statistics collection time. For example:
+    @classmethod
+    def get_kernel_op_attribute(cls, node_type: Any) -> str:
+        """
+        Get attribute of a layer's weight to quantize.
 
-            >>> layer_min_max_mapping = {tf.keras.layers.Softmax: (0, 1)}
-            >>> FrameworkInfo(kernel_channels_mapping, activation_min_max_mapping, layer_min_max_mapping)
+        Args:
+            node_type: Layer to get its attribute.
 
+        Returns:
+            Attribute the layer has and should be quantized.
         """
+        return cls.kernel_ops_attribute_mapping.get(node_type, DEFAULT_KERNEL_ATTRIBUTE)
 
-        self.activation_quantizer_mapping = activation_quantizer_mapping
-        self.kernel_channels_mapping = kernel_channels_mapping
-        self.activation_min_max_mapping = activation_min_max_mapping
-        self.layer_min_max_mapping = layer_min_max_mapping
-        self.kernel_ops_attributes_mapping = kernel_ops_attributes_mapping
-        self.out_channel_axis_mapping = out_channel_axis_mapping
-
-    def get_kernel_op_attributes(self, node_type: Any) -> List[str]:
+    @classmethod
+    def is_kernel_op(cls, node_type: Any) -> bool:
         """
-        Get a list of attributes of a layer's weights to quantize.
+        Check is the node is a kernel operation.
 
         Args:
             node_type: Layer to get its attributes.
 
         Returns:
-            A list of attributes the layer has and should be quantized.
+            True if node type is a kernel operation, else False.
         """
-        attr_list = self.kernel_ops_attributes_mapping.get(node_type)
-        return attr_list
+        return node_type in cls.kernel_ops_attribute_mapping
 
-    def is_kernel_op(self, node_type: Any) -> bool:
+    @classmethod
+    def get_layer_min_max(cls, layer: Any, fw_attrs: Dict) -> Tuple[float, float]:
         """
-        Check is the node is a kernel operation.
-
+        Return layer min/max mapping the FrameworkInfo holds.
         Args:
-            node_type: Layer to get its attributes.
+            layer: A layer to check if has a min/max known values.
+            fw_attrs: framework attributes from framework layer.
 
         Returns:
-            True if node type is a kernel operation, else False.
+            Layer's min/max known values.
         """
-        return node_type in self.kernel_ops_attributes_mapping.keys()
 
-    def layers_has_min_max(self, layer: Any) -> bool:
+        if cls.layers_has_min_max(layer):
+            return cls._layer_min_max_mapping[layer]
+        else:
+            return None, None
+
+    @classmethod
+    def layers_has_min_max(cls, layer: Any) -> bool:
         """
         Check if a layer is in a layer to min/max mapping the FrameworkInfo holds.
         Args:
@@ -134,17 +128,60 @@ class FrameworkInfo:
         Whether a layer has a min/max known values or not.
         """
 
-        return layer in self.layer_min_max_mapping
+        return layer in cls._layer_min_max_mapping
 
-    def activation_has_min_max(self, activation_name: str) -> bool:
+    @classmethod
+    @abstractmethod
+    def get_kernel_channels(cls, node_type: Any) -> ChannelAxisMapping:
         """
-        Check if an activation layer has a min/max mapping.
+        Returns node's channels mapping from kernel_channels_mapping or framework specific default value.
+        Args:
+            node_type: A node type
 
+        Returns:
+            Node's channels mapping.
+        """
+        pass
+
+    @classmethod
+    @abstractmethod
+    def get_out_channel_axis(cls, node_type: Any):
+        """
+        Returns node's output channel mapping from out_channel_axis_mapping or framework specific default value.
         Args:
-            activation_name: String of the activation function to check for its min/max values.
+            node_type: A node type.
 
         Returns:
-            Whether an activation layer has a min/max known values or not.
+            Node's output channel axis.
+
         """
+        pass
+
+
+# Pointer to current FrameworkInfo class.
+_current_framework_info: type[FrameworkInfo] = None
+
+
+def get_fw_info():
+    """
+    A common function to get the current FrameworkInfo class. Raises an error if the pointer wasn't initialized.
+
+    Returns: FrameworkInfo class.
+    """
+    assert _current_framework_info is not None, "fw_info isn't initialized."
+    assert issubclass(_current_framework_info, FrameworkInfo), "fw_info isn't initialized to a FrameworkInfo class."
+    return _current_framework_info
+
+
+def set_fw_info(fw_info: type[FrameworkInfo]):
+    """
+    A common function to set the current FrameworkInfo class. Raises an error if fw_info doesn't inherit from FrameworkInfo.
+
+    Args:
+        fw_info: Framework specific object implementing the FrameworkInfo.
+    """
+    global _current_framework_info
+    assert _current_framework_info in [None, _current_framework_info], "FrameworkInfo already initialized."
    assert issubclass(fw_info, FrameworkInfo), "fw_info must inherit from FrameworkInfo."
 
-        return activation_name in self.activation_min_max_mapping
+    _current_framework_info = fw_info
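Taken together, framework_info.py replaces the old instance-based FrameworkInfo with a class-level registry. A minimal sketch of defining and registering a framework-specific subclass, assuming the names above are importable from model_compression_toolkit.core.common.framework_info as the hunks suggest (DummyConv is a hypothetical stand-in for a real layer type):

from typing import Any

from model_compression_toolkit.core.common.framework_info import (
    ChannelAxisMapping, FrameworkInfo, get_fw_info, set_fw_info)

class DummyConv:  # hypothetical layer type
    pass

class DummyFrameworkInfo(FrameworkInfo):
    activation_quantizer_mapping = {}
    kernel_channels_mapping = {DummyConv: ChannelAxisMapping(output=3, input=2)}
    kernel_ops_attribute_mapping = {DummyConv: 'kernel'}
    out_channel_axis_mapping = {DummyConv: 3}
    _layer_min_max_mapping = {}

    @classmethod
    def get_kernel_channels(cls, node_type: Any) -> ChannelAxisMapping:
        # Fall back to the shared default when the type is unknown.
        return cls.kernel_channels_mapping.get(node_type, cls._default_channel_mapping)

    @classmethod
    def get_out_channel_axis(cls, node_type: Any):
        return cls.out_channel_axis_mapping.get(node_type)

set_fw_info(DummyFrameworkInfo)   # registered once per framework
fw_info = get_fw_info()           # consumers fetch the class instead of passing it around
assert fw_info.is_kernel_op(DummyConv)
assert fw_info.get_kernel_op_attribute(DummyConv) == 'kernel'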
model_compression_toolkit/core/common/graph/base_graph.py

@@ -23,7 +23,6 @@ import numpy as np
 
 from networkx.algorithms.dag import topological_sort
 
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo
 from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX, EDGE_SOURCE_INDEX
 from model_compression_toolkit.core.common.graph.edge import Edge, convert_to_edge
@@ -74,7 +73,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
                  input_nodes: List[BaseNode],
                  output_nodes: List[OutTensor],
                  edge_list: List[Edge],
-                 fw_info: FrameworkInfo = None,
                  **attr):
         """
         Args:
@@ -82,7 +80,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             input_nodes: List of input nodes the model
             output_nodes: List of output nodes of the model to a list of their output indices.
             edge_list: List of edges the graph has between nodes.
-            fw_info: FrameworkInfo object (needed for computing the graph's weights memory).
             **attr: Attributes to add to graph as key=value pairs.
         """
 
@@ -103,7 +100,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
                           e.sink_node,
                           **e.get_attributes())
         self.user_info = UserInformation()
-        self.fw_info = fw_info
 
     @property
     def skip_validation_check(self) -> bool:
@@ -124,16 +120,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
     def fusing_info(self, fusing_info: FusingInfo):
         self._fusing_info = fusing_info
 
-    def set_fw_info(self,
-                    fw_info: FrameworkInfo):
-        """
-        Set the graph's framework info.
-        Args:
-            fw_info: FrameworkInfo object.
-        """
-
-        self.fw_info = fw_info
-
     def set_fqc(self,
                 fqc: FrameworkQuantizationCapabilities):
         """
@@ -563,7 +549,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         return output_edges
 
     def get_configurable_sorted_nodes_names(self,
-                                            fw_info: FrameworkInfo,
                                             include_reused_nodes: bool = False) -> List[str]:
         """
         Get a list of nodes' names that can be configured (namely, has one or
@@ -571,56 +556,49 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         order of the graph.
 
         Args:
-            fw_info: FrameworkInfo object with information about the specific framework's model.
             include_reused_nodes: Whether or not to include reused nodes (False by default).
 
         Returns: List of nodes' names that can be configured (namely, has one or
         more weight qc candidate) sorted topology.
 
         """
-        sorted_names = [n.name for n in self.get_configurable_sorted_nodes(fw_info=fw_info,
-                                                                           include_reused_nodes=include_reused_nodes)]
+        sorted_names = [n.name for n in self.get_configurable_sorted_nodes(include_reused_nodes=include_reused_nodes)]
         return sorted_names
 
     def get_weights_configurable_nodes(self,
-                                       fw_info: FrameworkInfo,
                                        include_reused_nodes: bool = False) -> List[BaseNode]:
         """
         Get a list of nodes that their weights can be configured (namely, has one or
         more weight qc candidate and their weights should be quantized).
 
         Args:
-            fw_info: FrameworkInfo object with information about the specific framework's model.
             include_reused_nodes: Whether to include reused nodes (False by default).
 
         Returns:
             A list of nodes that their weights can be configured (namely, has one or more weight qc candidate).
         """
         # configurability is only relevant for kernel attribute quantization
-        potential_conf_nodes = [n for n in list(self) if fw_info.is_kernel_op(n.type)]
+        potential_conf_nodes = [n for n in list(self) if n.is_kernel_op]
 
         def is_configurable(n):
-            kernel_attrs = fw_info.get_kernel_op_attributes(n.type)
-            return any(n.is_configurable_weight(attr) for attr in kernel_attrs) and (not n.reuse or include_reused_nodes)
+            return n.is_configurable_weight(n.kernel_attr) and (not n.reuse or include_reused_nodes)
 
         return [n for n in potential_conf_nodes if is_configurable(n)]
 
     def get_sorted_weights_configurable_nodes(self,
-                                              fw_info: FrameworkInfo,
                                               include_reused_nodes: bool = False) -> List[BaseNode]:
         """
         Get a list of sorted nodes that their weights can be configured (namely, has one or
         more weight qc candidate and their weights should be quantized).
 
         Args:
-            fw_info: FrameworkInfo object with information about the specific framework's model.
             include_reused_nodes: Whether to include reused nodes (False by default).
 
         Returns:
             A list of nodes that their weights can be configured (namely, has one or more weight qc candidate)
             sorted topologically.
         """
-        return self._sort_nodes_in_list(self.get_weights_configurable_nodes(fw_info, include_reused_nodes))
+        return self._sort_nodes_in_list(self.get_weights_configurable_nodes(include_reused_nodes))
 
     def get_activation_configurable_nodes(self) -> List[BaseNode]:
         """
@@ -644,7 +622,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         return self._sort_nodes_in_list(self.get_activation_configurable_nodes())
 
     def get_configurable_sorted_nodes(self,
-                                      fw_info: FrameworkInfo,
                                       include_reused_nodes: bool = False) -> List[BaseNode]:
         """
         Get a list of nodes that can be configured (namely, has one or
@@ -652,14 +629,13 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         The nodes are sorted according to the topological order of the graph.
 
         Args:
-            fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
             include_reused_nodes: Whether or not to include reused nodes (False by default).
 
         Returns:
             A list of nodes that can be configured (namely, has one or more qc candidate) sorted topology.
 
         """
-        weights_configurable_nodes = self.get_weights_configurable_nodes(fw_info, include_reused_nodes)
+        weights_configurable_nodes = self.get_weights_configurable_nodes(include_reused_nodes)
         activation_configurable_nodes = self.get_activation_configurable_nodes()
 
         # combine and remove duplications
@@ -684,7 +660,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             sorted_configurable_nodes.append(n)
         return sorted_configurable_nodes
 
-    def get_min_candidates_config(self, fw_info: FrameworkInfo) -> Dict[BaseNode, int]:
+    def get_min_candidates_config(self) -> Dict[BaseNode, int]:
         """
         Builds a minimal configuration.
         Note: we assume that a minimal configuration exists, i.e., each configurable node has exactly one candidate
@@ -697,26 +673,23 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         Returns:
             A dict from layer to an index of its minimal candidate.
         """
-        conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
+        conf_sorted_nodes = self.get_configurable_sorted_nodes()
         return {n: n.find_min_candidate_index() for n in conf_sorted_nodes}
 
-    def get_max_candidates_config(self, fw_info: FrameworkInfo) -> Dict[BaseNode, int]:
+    def get_max_candidates_config(self) -> Dict[BaseNode, int]:
         """
         Builds a maximal configuration.
         Note: we assume that a maximal configuration exists, i.e., each configurable node has exactly one candidate
         with maximal n_bits (in both weight and activation if both are quantized, or in the relevant one if only
         one of them is quantized)
 
-        Args:
-            fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
-
         Returns:
             A dict from layer to an index of its maximal candidate.
         """
-        conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
+        conf_sorted_nodes = self.get_configurable_sorted_nodes()
         return {n: n.find_max_candidate_index() for n in conf_sorted_nodes}
 
-    def get_final_weights_config(self, fw_info: FrameworkInfo) -> List[Tuple[BaseNode, int]]:
+    def get_final_weights_config(self) -> List[Tuple[BaseNode, int]]:
         """
         Gets the final number of bits for quantization of each weights' configurable layer.
 
@@ -726,9 +699,9 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         Returns: A list of pairs of (node type, node's weights quantization bitwidth).
 
         """
-        sorted_conf_weights = self.get_sorted_weights_configurable_nodes(fw_info)
+        sorted_conf_weights = self.get_sorted_weights_configurable_nodes()
         # a configurable node by definition has a kernel op
-        return [(n, n.final_weights_quantization_cfg.get_attr_config(self.fw_info.get_kernel_op_attributes(n.type)[0]).weights_n_bits)
+        return [(n, n.final_weights_quantization_cfg.get_attr_config(n.kernel_attr).weights_n_bits)
                 for n in sorted_conf_weights]
 
     def get_final_activation_config(self) -> List[Tuple[BaseNode, int]]:
@@ -846,7 +819,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             next_node = self.out_edges(next_node)[0].sink_node
 
         # If next_node is an exit node and has only one incoming edge, the topology is prunable.
-        if fw_impl.is_node_exit_node(next_node, entry_node, self.fw_info) and len(self.in_edges(next_node)) == 1:
+        if fw_impl.is_node_exit_node(next_node, entry_node) and len(self.in_edges(next_node)) == 1:
            return True
 
        # If the next node is not an intermediate node or has more than one incoming/outgoing edge,
@@ -876,7 +849,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
 
         intermediate_nodes, exit_node = self._find_intermediate_and_exit_nodes(entry_node, fw_impl)
 
-        if not fw_impl.is_node_exit_node(exit_node, entry_node, self.fw_info):
+        if not fw_impl.is_node_exit_node(exit_node, entry_node):
             Logger.critical(f"Node {exit_node} is not a valid exit node for the pruning section starting with {entry_node}.")  # pragma: no cover
 
         return PruningSection(entry_node=entry_node,
@@ -897,7 +870,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         """
         intermediate_nodes = []
         next_node = self.out_edges(entry_node)[0].sink_node
-        while not fw_impl.is_node_exit_node(next_node, entry_node, self.fw_info):
+        while not fw_impl.is_node_exit_node(next_node, entry_node):
             intermediate_nodes.append(next_node)
             next_node = self.out_edges(next_node)[0].sink_node
876