mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250927.534__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.4.0.20250925.000543"
30
+ __version__ = "2.4.2.20250927.000534"
@@ -32,7 +32,8 @@ def analyzer_model_quantization(representative_data_gen: Callable,
32
32
  tb_w: TensorboardWriter,
33
33
  float_graph: Graph,
34
34
  quantized_graph: Graph,
35
- fw_impl: FrameworkImplementation):
35
+ fw_impl: FrameworkImplementation,
36
+ fw_info: FrameworkInfo):
36
37
  """
37
38
  Plot the cosine similarity of different points on the graph between the float and quantized
38
39
  graphs. Add them to the passed TensorboardWriter object and close all tensorboard writer open
@@ -44,12 +45,14 @@ def analyzer_model_quantization(representative_data_gen: Callable,
44
45
  float_graph: Graph of float model.
45
46
  quantized_graph: Graph of quantized model.
46
47
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
48
+ fw_info: Information needed for quantization about the specific framework.
47
49
 
48
50
  """
49
51
  if tb_w is not None:
50
52
  visual = NNVisualizer(float_graph,
51
53
  quantized_graph,
52
- fw_impl=fw_impl)
54
+ fw_impl=fw_impl,
55
+ fw_info=fw_info)
53
56
  if not visual.has_compare_points():
54
57
  Logger.error(f'No comparing points were found to plot analyze similarity.')
55
58
  else:
@@ -15,6 +15,7 @@
15
15
  from abc import ABC, abstractmethod
16
16
  from typing import Any, Tuple
17
17
 
18
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
18
19
  from model_compression_toolkit.core import common
19
20
  from model_compression_toolkit.core.common.user_info import UserInformation
20
21
 
@@ -27,17 +28,20 @@ class BaseModelBuilder(ABC):
27
28
  def __init__(self,
28
29
  graph: common.Graph,
29
30
  append2output=None,
31
+ fw_info: FrameworkInfo = None,
30
32
  return_float_outputs: bool = False):
31
33
  """
32
34
 
33
35
  Args:
34
36
  graph: Graph to build the model from.
35
37
  append2output: Nodes of graph to append to model's output.
38
+ fw_info: Information about the specific framework of the model that is built.
36
39
  return_float_outputs: Whether the model returns float tensors or not.
37
40
  """
38
41
 
39
42
  self.graph = graph
40
43
  self.append2output = append2output
44
+ self.fw_info = fw_info
41
45
  self.return_float_outputs = return_float_outputs
42
46
 
43
47
  @abstractmethod
@@ -13,12 +13,11 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from abc import ABC, abstractmethod
17
16
  import numpy as np
18
17
  from model_compression_toolkit.logger import Logger
19
18
 
20
19
 
21
- class BaseCollector(ABC):
20
+ class BaseCollector(object):
22
21
  """
23
22
  Base class for statistics collection object.
24
23
  """
@@ -27,7 +26,6 @@ class BaseCollector(ABC):
27
26
  # When manipulation statistics in a granularity they were not collected by, the data is invalid.
28
27
  self.is_legal = True
29
28
 
30
- @abstractmethod
31
29
  def scale(self, scale_factor: np.ndarray):
32
30
  """
33
31
  Scale all statistics in collector by some factor.
@@ -39,7 +37,6 @@ class BaseCollector(ABC):
39
37
  raise NotImplemented(
40
38
  f'{self.__class__.__name__} needs to implement scale operation for its state.') # pragma: no cover
41
39
 
42
- @abstractmethod
43
40
  def shift(self, shift_value: np.ndarray):
44
41
  """
45
42
  Shift all statistics in collector by some value.
@@ -87,13 +87,10 @@ class MeanCollector(BaseCollector):
87
87
  x: Tensor that goes through the mean collector and needs to be considered in the mean computation.
88
88
  """
89
89
  self.i += 1 # Update the iteration index
90
- if self.axis is None:
91
- mu = np.mean(np.reshape(x, [1, -1]), axis=-1) # mean per channel for a batch
92
- else:
93
- axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
94
- n = x.shape[axis]
95
- transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
96
- mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch
90
+ axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
91
+ n = x.shape[axis]
92
+ transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
93
+ mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch
97
94
  self.current_sum += mu # sum of all batches
98
95
  self.current_mean = self.current_sum / self.i # mean of all batches
99
96
 
@@ -130,13 +130,10 @@ class MinMaxPerChannelCollector(BaseCollector):
130
130
  x: Tensor that goes through the collector and needs to be considered in the min/max computation.
131
131
  """
132
132
 
133
- if self.axis is None:
134
- x_reshape = np.reshape(x, [1, -1])
135
- else:
136
- axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
137
- n = x.shape[axis]
138
- transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
139
- x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])
133
+ axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
134
+ n = x.shape[axis]
135
+ transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
136
+ x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])
140
137
  if self.state is None:
141
138
  x_max = np.max(x_reshape, axis=-1)
142
139
  x_min = np.min(x_reshape, axis=-1)
@@ -125,16 +125,18 @@ class FrameworkImplementation(ABC):
125
125
  graph: Graph,
126
126
  mode: ModelBuilderMode,
127
127
  append2output: List[Any],
128
+ fw_info: FrameworkInfo,
128
129
  return_float_outputs: bool = False) -> Tuple:
129
130
  """
130
131
  Build a framework model from a graph.
131
- The mode determines how the model should be built. append2output is a list of Nodes
132
+ The mode determines how the model should be build. append2output is a list of Nodes
132
133
  to set as the model outputs.
133
134
 
134
135
  Args:
135
136
  graph: Graph to build the model from it.
136
137
  mode: Mode for how to build the model.
137
138
  append2output: List of Nodes to set as the model's outputs.
139
+ fw_info: FrameworkInfo object with information about the specific framework's model
138
140
  return_float_outputs (bool): whether to return outputs before or after quantization nodes (default)
139
141
 
140
142
  Returns:
@@ -168,13 +170,15 @@ class FrameworkImplementation(ABC):
168
170
  @abstractmethod
169
171
  def shift_negative_correction(self,
170
172
  graph: Graph,
171
- core_config: CoreConfig) -> Graph:
173
+ core_config: CoreConfig,
174
+ fw_info: FrameworkInfo) -> Graph:
172
175
  """
173
176
  Apply shift negative correction (SNC) on a graph.
174
177
 
175
178
  Args:
176
179
  graph: Graph to apply SNC on.
177
180
  core_config: Quantization configuration.
181
+ fw_info: FrameworkInfo object with information about the specific framework's model.
178
182
 
179
183
  Returns:
180
184
  Graph after SNC.
@@ -185,13 +189,15 @@ class FrameworkImplementation(ABC):
185
189
  @abstractmethod
186
190
  def compute_activation_bias_correction(self,
187
191
  graph: Graph,
188
- quant_config: QuantizationConfig) -> Graph:
192
+ quant_config: QuantizationConfig,
193
+ fw_info: FrameworkInfo) -> Graph:
189
194
  """
190
195
  Compute activation bias correction on a graph.
191
196
 
192
197
  Args:
193
198
  graph: Graph to apply activation bias correction on.
194
199
  quant_config: QuantizationConfig of how the model should be quantized.
200
+ fw_info: FrameworkInfo object with information about the specific framework's model.
195
201
 
196
202
  Returns:
197
203
  Graph after activation bias correction computing.
@@ -201,28 +207,30 @@ class FrameworkImplementation(ABC):
201
207
 
202
208
  @abstractmethod
203
209
  def get_substitutions_channel_equalization(self,
204
- quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
210
+ quant_config: QuantizationConfig,
211
+ fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
205
212
  """
206
213
  Return a list of the framework substitutions used for channel equalization.
207
214
 
208
215
  Args:
209
216
  quant_config: QuantizationConfig to determine which substitutions to return.
217
+ fw_info: FrameworkInfo object with information about the specific framework's model.
210
218
 
211
219
  Returns:
212
220
  A list of the framework substitutions used after we collect statistics.
213
221
  """
214
222
  raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
215
- f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover
223
+ f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover
216
224
 
217
225
  @abstractmethod
218
- def get_substitutions_prepare_graph(self) -> List[common.BaseSubstitution]:
226
+ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List[common.BaseSubstitution]:
219
227
  """
220
228
 
221
229
  Returns: A list of the framework substitutions used to prepare the graph.
222
230
 
223
231
  """
224
232
  raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
225
- f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover
233
+ f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover
226
234
 
227
235
  @abstractmethod
228
236
  def get_substitutions_pre_statistics_collection(self, quant_config: QuantizationConfig) -> \
@@ -320,12 +328,14 @@ class FrameworkImplementation(ABC):
320
328
  f'method.') # pragma: no cover
321
329
 
322
330
  def get_node_prior_info(self, node: BaseNode,
331
+ fw_info: FrameworkInfo,
323
332
  graph: Graph) -> NodePriorInfo:
324
333
  """
325
334
  Get a NodePriorInfo object for a node.
326
335
 
327
336
  Args:
328
337
  node: Node to get its prior info.
338
+ fw_info: Framework specific information needed to create the prior info of the node.
329
339
  graph: Graph to check the next node type.
330
340
 
331
341
  Returns:
@@ -333,7 +343,7 @@ class FrameworkImplementation(ABC):
333
343
  """
334
344
 
335
345
  raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
336
- f'framework\'s get_node_prior_info method.') # pragma: no cover
346
+ f'framework\'s get_node_prior_info method.') # pragma: no cover
337
347
 
338
348
  def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
339
349
  """
@@ -384,18 +394,20 @@ class FrameworkImplementation(ABC):
384
394
 
385
395
  @abstractmethod
386
396
  def get_node_mac_operations(self,
387
- node: BaseNode) -> float:
397
+ node: BaseNode,
398
+ fw_info: FrameworkInfo) -> float:
388
399
  """
389
400
  Gets the MAC operation count for a given operation.
390
401
 
391
402
  Args:
392
403
  node: A graph node that wraps the operation for which the MAC count is computed.
404
+ fw_info: FrameworkInfo object with information about the specific framework's model.
393
405
 
394
406
  Returns: The MAC count of the operation
395
407
  """
396
408
 
397
409
  raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
398
- f'framework\'s get_node_mac_operations method.') # pragma: no cover
410
+ f'framework\'s get_node_mac_operations method.') # pragma: no cover
399
411
 
400
412
  @abstractmethod
401
413
  def apply_second_moment_correction(self,
@@ -13,9 +13,19 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+
17
+ from collections.abc import Callable
16
18
  from enum import Enum
17
- from typing import Dict, Any, Tuple, NamedTuple, Optional
18
- from abc import ABC, abstractmethod
19
+ from typing import Dict, Any, List
20
+
21
+ from mct_quantizers import QuantizationMethod
22
+ from model_compression_toolkit.defaultdict import DefaultDict
23
+
24
+
25
+ # Default value to use for ops without kernel.
26
+ # This is a weird default, but it's used all over the place, so for now only extract it to const so that it can be
27
+ # referenced by variable instead of hard-coded.
28
+ DEFAULT_KERNEL_ATTRIBUTES = [None]
19
29
 
20
30
 
21
31
  class ChannelAxis(Enum):
@@ -32,67 +42,89 @@ class ChannelAxis(Enum):
32
42
  NCHW = 1
33
43
 
34
44
 
35
- class ChannelAxisMapping(NamedTuple):
36
- output: int
37
- input: int
45
+ class FrameworkInfo:
46
+
47
+ def __init__(self,
48
+ activation_quantizer_mapping: Dict[QuantizationMethod, Callable],
49
+ kernel_channels_mapping: DefaultDict,
50
+ activation_min_max_mapping: Dict[str, tuple],
51
+ layer_min_max_mapping: Dict[Any, tuple],
52
+ kernel_ops_attributes_mapping: DefaultDict,
53
+ out_channel_axis_mapping: DefaultDict):
54
+ """
55
+ A class to wrap all information about a specific framework the library needs to quantize a model.
56
+ Specifically, FrameworkInfo holds lists of layers by how they should be quantized, and multiple mappings such as
57
+ layer to it kernel channels indices, and a layer to its min/max values, etc.
58
+ The layers lists are divided into three groups:
59
+ kernel_ops: Layers that have coefficients and need to get quantized (e.g., Conv2D, Dense, etc.)
60
+ activation_ops: Layers that their outputs should get quantized (e.g., Add, ReLU, etc.)
61
+ no_quantization_ops:Layers that should not get quantized (e.g., Reshape, Transpose, etc.)
38
62
 
63
+ Args:
64
+ activation_quantizer_mapping (Dict[QuantizationMethod, Callable]): A dictionary mapping from QuantizationMethod to a quantization function.
65
+ kernel_channels_mapping (DefaultDict): Dictionary from a layer to a tuple of its kernel in/out channels indices.
66
+ activation_min_max_mapping (Dict[str, tuple]): Dictionary from an activation function to its min/max output values.
67
+ layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
68
+ kernel_ops_attributes_mapping (DefaultDict): Dictionary from a framework operator to a list of its weights attirbutes to quantize.
69
+ out_channel_axis_mapping (DefaultDict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
39
70
 
40
- class FrameworkInfo(ABC):
41
- """
42
- A class to wrap all information about a specific framework the library needs to quantize a model.
43
- Specifically, FrameworkInfo holds lists of layers by how they should be quantized, and multiple mappings such as
44
- layer to it kernel channels indices, and a layer to its min/max values, etc.
45
- The layers lists are divided into three groups:
46
- kernel_ops: Layers that have coefficients and need to get quantized (e.g., Conv2D, Dense, etc.)
47
- activation_ops: Layers that their outputs should get quantized (e.g., Add, ReLU, etc.)
48
- no_quantization_ops:Layers that should not get quantized (e.g., Reshape, Transpose, etc.)
49
-
50
- Fields:
51
- kernel_channels_mapping (Dict): Dictionary from a layer to a tuple of its kernel in/out channels indices.
52
- kernel_ops_attribute_mapping (Dict): Dictionary from a framework operator to its weight attribute to quantize.
53
- out_channel_axis_mapping (Dict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
54
- _layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
55
- """
71
+ Examples:
72
+ When quantizing a Keras model, if we want to quantize the kernels of Conv2D layers only, we can
73
+ set, and we know it's kernel out/in channel indices are (3, 2) respectivly:
74
+
75
+ >>> import tensorflow as tf
76
+ >>> kernel_ops = [tf.keras.layers.Conv2D]
77
+ >>> kernel_channels_mapping = DefaultDict({tf.keras.layers.Conv2D: (3,2)})
56
78
 
57
- kernel_ops_attribute_mapping: Dict[Any, str]
58
- kernel_channels_mapping: Dict[Any, ChannelAxisMapping]
59
- out_channel_axis_mapping: Dict[Any, int]
79
+ Then, we can create a FrameworkInfo object:
60
80
 
61
- _layer_min_max_mapping: Dict[Any, tuple]
62
- _default_channel_mapping = ChannelAxisMapping(None, None)
81
+ >>> FrameworkInfo(kernel_channels_mapping, {}, {})
82
+
83
+ If an activation layer (tf.keras.layers.Activation) should be quantized and we know it's min/max outputs range in advanced, we can add it to activation_min_max_mapping for saving the statistics collection time. For example:
84
+
85
+ >>> activation_min_max_mapping = {'softmax': (0, 1)}
86
+ >>> FrameworkInfo(kernel_channels_mapping, activation_min_max_mapping, {})
87
+
88
+ If a layer's activations should be quantized and we know it's min/max outputs range in advanced, we can add it to layer_min_max_mapping for saving the statistics collection time. For example:
89
+
90
+ >>> layer_min_max_mapping = {tf.keras.layers.Softmax: (0, 1)}
91
+ >>> FrameworkInfo(kernel_channels_mapping, activation_min_max_mapping, layer_min_max_mapping)
63
92
 
64
- @classmethod
65
- def get_kernel_op_attribute(cls, node_type: Any) -> Optional[str]:
66
93
  """
67
- Get attribute of a layer's weight to quantize.
94
+
95
+ self.activation_quantizer_mapping = activation_quantizer_mapping
96
+ self.kernel_channels_mapping = kernel_channels_mapping
97
+ self.activation_min_max_mapping = activation_min_max_mapping
98
+ self.layer_min_max_mapping = layer_min_max_mapping
99
+ self.kernel_ops_attributes_mapping = kernel_ops_attributes_mapping
100
+ self.out_channel_axis_mapping = out_channel_axis_mapping
101
+
102
+ def get_kernel_op_attributes(self, node_type: Any) -> List[str]:
103
+ """
104
+ Get a list of attributes of a layer's weights to quantize.
68
105
 
69
106
  Args:
70
- node_type: Layer to get its attribute.
107
+ node_type: Layer to get its attributes.
71
108
 
72
109
  Returns:
73
- Attribute the layer has and should be quantized.
110
+ A list of attributes the layer has and should be quantized.
74
111
  """
75
- return cls.kernel_ops_attribute_mapping.get(node_type)
112
+ attr_list = self.kernel_ops_attributes_mapping.get(node_type)
113
+ return attr_list
76
114
 
77
- @classmethod
78
- def get_layer_min_max(cls, layer: Any, fw_attrs: Dict) -> Tuple[float, float]:
115
+ def is_kernel_op(self, node_type: Any) -> bool:
79
116
  """
80
- Return layer min/max mapping the FrameworkInfo holds.
117
+ Check is the node is a kernel operation.
118
+
81
119
  Args:
82
- layer: A layer to check if has a min/max known values.
83
- fw_attrs: framework attributes from framework layer.
120
+ node_type: Layer to get its attributes.
84
121
 
85
122
  Returns:
86
- Layer's min/max known values.
123
+ True if node type is a kernel operation, else False.
87
124
  """
125
+ return node_type in self.kernel_ops_attributes_mapping.keys()
88
126
 
89
- if cls.layers_has_min_max(layer):
90
- return cls._layer_min_max_mapping[layer]
91
- else:
92
- return None, None
93
-
94
- @classmethod
95
- def layers_has_min_max(cls, layer: Any) -> bool:
127
+ def layers_has_min_max(self, layer: Any) -> bool:
96
128
  """
97
129
  Check if a layer is in a layer to min/max mapping the FrameworkInfo holds.
98
130
  Args:
@@ -102,59 +134,17 @@ class FrameworkInfo(ABC):
102
134
  Whether a layer has a min/max known values or not.
103
135
  """
104
136
 
105
- return layer in cls._layer_min_max_mapping
137
+ return layer in self.layer_min_max_mapping
106
138
 
107
- @classmethod
108
- @abstractmethod
109
- def get_kernel_channels(cls, node_type: Any) -> ChannelAxisMapping:
110
- """
111
- Returns node's channels mapping from kernel_channels_mapping or framework specific default value.
112
- Args:
113
- node_type: A node type
114
-
115
- Returns:
116
- Node's channels mapping.
139
+ def activation_has_min_max(self, activation_name: str) -> bool:
117
140
  """
118
- pass
141
+ Check if an activation layer has a min/max mapping.
119
142
 
120
- @classmethod
121
- @abstractmethod
122
- def get_out_channel_axis(cls, node_type: Any):
123
- """
124
- Returns node's output channel mapping from out_channel_axis_mapping or framework specific default value.
125
143
  Args:
126
- node_type: A node type.
144
+ activation_name: String of the activation function to check for its min/max values.
127
145
 
128
146
  Returns:
129
- Node's output channel axis.
130
-
147
+ Whether an activation layer has a min/max known values or not.
131
148
  """
132
- pass
133
-
134
-
135
- # Pointer to current FrameworkInfo class.
136
- _current_framework_info: type[FrameworkInfo] = None
137
-
138
-
139
- def get_fw_info():
140
- """
141
- A common function to get the current FrameworkInfo class. Raises an error if the pointer wasn't initialized.
142
-
143
- Returns: FrameworkInfo class.
144
- """
145
- assert _current_framework_info is not None, "fw_info isn't initialized."
146
- return _current_framework_info
147
-
148
-
149
- def set_fw_info(fw_info: type[FrameworkInfo]):
150
- """
151
- A common function to set the current FrameworkInfo class. Raises an error if fw_info doesn't inherit from FrameworkInfo.
152
-
153
- Args:
154
- fw_info: Framework specific object implementing the FrameworkInfo.
155
- """
156
- global _current_framework_info
157
- assert _current_framework_info in [None, _current_framework_info], "FrameworkInfo already initialized."
158
- assert issubclass(fw_info, FrameworkInfo), "fw_info must inherit from FrameworkInfo."
159
149
 
160
- _current_framework_info = fw_info
150
+ return activation_name in self.activation_min_max_mapping
@@ -14,12 +14,12 @@
14
14
  # ==============================================================================
15
15
 
16
16
  import copy
17
- from typing import Tuple
17
+ from typing import List, Tuple
18
18
 
19
19
  from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
20
20
  from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
21
- from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
22
- CandidateNodeQuantizationConfig, NodeQuantizationConfig
21
+ from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
22
+ from itertools import product
23
23
 
24
24
 
25
25
  class FusedLayerType:
@@ -30,7 +30,6 @@ class FusedLayerType:
30
30
  def __init__(self):
31
31
  self.__name__ = 'FusedLayer'
32
32
 
33
-
34
33
  class GraphFuser:
35
34
  def apply_node_fusion(self, graph: Graph) -> Graph:
36
35
  """
@@ -65,6 +64,7 @@ class GraphFuser:
65
64
 
66
65
  return graph_copy
67
66
 
67
+
68
68
  @staticmethod
69
69
  def _create_fused_node(fused_node_id: str, nodes: Tuple[BaseNode]) -> BaseNode:
70
70
  """
@@ -86,15 +86,10 @@ class GraphFuser:
86
86
  weights={},
87
87
  layer_class=FusedLayerType)
88
88
 
89
- base_cfg = CandidateNodeQuantizationConfig(
90
- activation_quantization_cfg=nodes[-1].quantization_cfg.base_quantization_cfg.activation_quantization_cfg,
91
- weights_quantization_cfg=None
92
- )
93
89
  activation_cfgs = [c.activation_quantization_cfg for c in nodes[-1].candidates_quantization_cfg]
94
- candidates = [CandidateNodeQuantizationConfig(weights_quantization_cfg=None, activation_quantization_cfg=a)
95
- for a in activation_cfgs]
96
- fused_node.quantization_cfg = NodeQuantizationConfig(base_quantization_cfg=base_cfg,
97
- candidates_quantization_cfg=candidates)
90
+ fused_node.candidates_quantization_cfg = [
91
+ CandidateNodeQuantizationConfig(weights_quantization_cfg=None, activation_quantization_cfg=a) for a in
92
+ activation_cfgs]
98
93
 
99
94
  # Keep the final configurations if they were set already.
100
95
  fused_node.final_weights_quantization_cfg = nodes[0].final_weights_quantization_cfg
@@ -163,3 +158,5 @@ class GraphFuser:
163
158
 
164
159
  # Finally, add the new fused node to the graph
165
160
  graph.add_node(fused_node)
161
+
162
+