mct-nightly 2.4.0.20250617.613__py3-none-any.whl → 2.4.0.20250618.606__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/RECORD +120 -120
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +2 -5
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
  6. model_compression_toolkit/core/common/framework_implementation.py +10 -22
  7. model_compression_toolkit/core/common/framework_info.py +105 -68
  8. model_compression_toolkit/core/common/graph/base_graph.py +15 -42
  9. model_compression_toolkit/core/common/graph/base_node.py +103 -42
  10. model_compression_toolkit/core/common/graph/functional_node.py +18 -1
  11. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
  12. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
  16. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
  17. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
  18. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
  19. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
  20. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
  21. model_compression_toolkit/core/common/model_collector.py +10 -20
  22. model_compression_toolkit/core/common/model_validation.py +1 -4
  23. model_compression_toolkit/core/common/network_editors/actions.py +14 -38
  24. model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
  25. model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
  26. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
  27. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
  28. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
  29. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
  30. model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
  31. model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
  32. model_compression_toolkit/core/common/pruning/pruner.py +1 -6
  33. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
  34. model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
  35. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
  36. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
  37. model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
  38. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
  39. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
  40. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
  41. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
  42. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
  43. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
  44. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
  45. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
  46. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
  47. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
  48. model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
  49. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
  50. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
  51. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  52. model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
  53. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
  54. model_compression_toolkit/core/graph_prep_runner.py +2 -16
  55. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
  56. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
  57. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
  58. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
  59. model_compression_toolkit/core/keras/default_framework_info.py +138 -87
  60. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
  61. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
  62. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
  63. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
  64. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  65. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
  66. model_compression_toolkit/core/keras/keras_implementation.py +15 -35
  67. model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
  68. model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
  69. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
  70. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +0 -2
  71. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
  72. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
  73. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
  74. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
  75. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
  76. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
  77. model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
  78. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
  79. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
  80. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  81. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
  82. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
  83. model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
  84. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
  85. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  86. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
  87. model_compression_toolkit/core/quantization_prep_runner.py +4 -9
  88. model_compression_toolkit/core/runner.py +5 -15
  89. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  90. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  91. model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
  92. model_compression_toolkit/gptq/common/gptq_training.py +1 -8
  93. model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
  94. model_compression_toolkit/gptq/keras/graph_info.py +4 -6
  95. model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
  96. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  97. model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
  98. model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
  99. model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
  100. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  101. model_compression_toolkit/gptq/runner.py +1 -7
  102. model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
  103. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
  104. model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
  105. model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
  106. model_compression_toolkit/ptq/runner.py +1 -4
  107. model_compression_toolkit/qat/common/qat_config.py +2 -6
  108. model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
  109. model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
  110. model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
  111. model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
  112. model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
  113. model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
  114. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
  115. model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
  116. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
  117. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
  118. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/WHEEL +0 -0
  119. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/licenses/LICENSE.md +0 -0
  120. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/top_level.txt +0 -0
@@ -37,7 +37,6 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
 def graph_preparation_runner(in_model: Any,
                              representative_data_gen: Callable,
                              quantization_config: QuantizationConfig,
-                             fw_info: FrameworkInfo,
                              fw_impl: FrameworkImplementation,
                              fqc: FrameworkQuantizationCapabilities,
                              bit_width_config: BitWidthConfig = None,
@@ -56,8 +55,6 @@ def graph_preparation_runner(in_model: Any,
         in_model (Any): Model to quantize.
         representative_data_gen (Callable): Dataset used for calibration.
         quantization_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized.
-        fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices,
-            groups of layers by how they should be quantized, etc.).
         fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
         fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities object that models the inference target platform and
             the attached framework operator's information.
@@ -73,7 +70,6 @@ def graph_preparation_runner(in_model: Any,
     graph = read_model_to_graph(in_model,
                                 representative_data_gen,
                                 fqc,
-                                fw_info,
                                 fw_impl)
 
     if tb_w is not None:
@@ -83,7 +79,6 @@ def graph_preparation_runner(in_model: Any,
                                 fqc,
                                 quantization_config,
                                 bit_width_config,
-                                fw_info,
                                 tb_w,
                                 fw_impl,
                                 mixed_precision_enable=mixed_precision_enable,
@@ -96,7 +91,6 @@ def get_finalized_graph(initial_graph: Graph,
                         fqc: FrameworkQuantizationCapabilities,
                         quant_config: QuantizationConfig = DEFAULTCONFIG,
                         bit_width_config: BitWidthConfig = None,
-                        fw_info: FrameworkInfo = None,
                         tb_w: TensorboardWriter = None,
                         fw_impl: FrameworkImplementation = None,
                         mixed_precision_enable: bool = False,
@@ -111,8 +105,6 @@ def get_finalized_graph(initial_graph: Graph,
         quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
             quantized.
         bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
-        fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g.,
-            kernel channels indices, groups of layers by how they should be quantized, etc.)
         tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
         fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
         mixed_precision_enable: is mixed precision enabled.
@@ -124,7 +116,7 @@ def get_finalized_graph(initial_graph: Graph,
     ######################################
     # Graph substitution (prepare graph)
     ######################################
-    graph = substitute(initial_graph, fw_impl.get_substitutions_prepare_graph(fw_info))
+    graph = substitute(initial_graph, fw_impl.get_substitutions_prepare_graph())
 
     if tb_w is not None:
         tb_w.add_graph(graph, 'after_graph_preparation')
@@ -134,7 +126,6 @@ def get_finalized_graph(initial_graph: Graph,
     ##########################################
     for node in graph.nodes:
         node.prior_info = fw_impl.get_node_prior_info(node=node,
-                                                      fw_info=fw_info,
                                                       graph=graph)
 
     ##################################################
@@ -170,8 +161,7 @@ def get_finalized_graph(initial_graph: Graph,
     # Channel equalization
     ######################################
     transformed_graph = substitute(transformed_graph,
-                                   fw_impl.get_substitutions_channel_equalization(quant_config,
-                                                                                  fw_info))
+                                   fw_impl.get_substitutions_channel_equalization(quant_config))
 
     if tb_w is not None:
         tb_w.add_graph(transformed_graph, 'after_graph_marking')
@@ -190,7 +180,6 @@ def get_finalized_graph(initial_graph: Graph,
 def read_model_to_graph(in_model: Any,
                         representative_data_gen: Callable,
                         fqc: FrameworkQuantizationCapabilities,
-                        fw_info: FrameworkInfo = None,
                         fw_impl: FrameworkImplementation = None) -> Graph:
 
     """
@@ -201,8 +190,6 @@ def read_model_to_graph(in_model: Any,
         representative_data_gen: Dataset used for calibration.
         fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
             the attached framework operator's information.
-        fw_info: Information needed for quantization about the specific framework (e.g.,
-            kernel channels indices, groups of layers by how they should be quantized, etc.)
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -210,6 +197,5 @@ def read_model_to_graph(in_model: Any,
     """
     graph = fw_impl.model_reader(in_model,
                                  representative_data_gen)
-    graph.set_fw_info(fw_info)
    graph.set_fqc(fqc)
    return graph
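
Taken together, the hunks above remove the explicit fw_info argument from the graph-preparation pipeline: it is no longer threaded through call sites nor attached to the graph. A condensed before/after sketch of a call site, using only names that appear in the diff (a restatement of the hunks, not standalone runnable code):

    # Before: framework info was passed explicitly and stored on the graph.
    graph = read_model_to_graph(in_model, representative_data_gen, fqc, fw_info, fw_impl)  # internally called graph.set_fw_info(fw_info)

    # After: the parameter is gone; only the framework quantization capabilities are attached.
    graph = read_model_to_graph(in_model, representative_data_gen, fqc, fw_impl)  # internally calls graph.set_fqc(fqc)
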
@@ -17,7 +17,6 @@ from typing import List
 from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.core import common
 from tensorflow.python.util.object_identity import Reference as TFReference
 
@@ -29,20 +28,17 @@ class FloatKerasModelBuilder(KerasModelBuilder):
     def __init__(self,
                  graph: common.Graph,
                  append2output=None,
-                 fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
                  return_float_outputs: bool = False):
         """
 
         Args:
             graph: Graph to build the model from.
             append2output: Nodes to append to model's output.
-            fw_info: Information about the specific framework of the model that is built.
             return_float_outputs: Whether the model returns float tensors or not.
         """
 
         super().__init__(graph,
                          append2output,
-                         fw_info,
                          return_float_outputs)
 
     def _quantize_node_activations(self,
@@ -35,8 +35,6 @@ from typing import Any, Dict, List, Tuple, Callable
 from tensorflow.python.util.object_identity import Reference as TFReference
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
 from model_compression_toolkit.core.keras.back2framework.instance_builder import OperationHandler
@@ -57,7 +55,6 @@ class KerasModelBuilder(BaseModelBuilder):
     def __init__(self,
                  graph: common.Graph,
                  append2output=None,
-                 fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
                  return_float_outputs: bool = False,
                  wrapper: Callable = None,
                  get_activation_quantizer_holder_fn: Callable=None):
@@ -66,7 +63,6 @@ class KerasModelBuilder(BaseModelBuilder):
         Args:
             graph: Graph to build the model from.
             append2output: Nodes to append to model's output.
-            fw_info: Information about the specific framework of the model that is built.
             return_float_outputs: Whether the model returns float tensors or not.
             wrapper: A function wrapper keras Layers.
             get_activation_quantizer_holder_fn: Function to retrieve a quantization holder for a node.
@@ -75,7 +71,6 @@ class KerasModelBuilder(BaseModelBuilder):
 
         super().__init__(graph,
                          append2output,
-                         fw_info,
                          return_float_outputs)
 
         # Build an OperationHandler to handle conversions from graph nodes to Keras operators.
@@ -36,7 +36,6 @@ from model_compression_toolkit.core.keras.mixed_precision.configurable_weights_q
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 
 
 class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
@@ -47,14 +46,12 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
     def __init__(self,
                  graph: common.Graph,
                  append2output=None,
-                 fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
                  return_float_outputs: bool = False):
         """
 
         Args:
             graph: Graph to build the model from.
             append2output: Nodes to append to model's output.
-            fw_info: Information about the specific framework of the model that is built.
             return_float_outputs: Whether the model returns float tensors or not.
         """
 
@@ -62,7 +59,6 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
 
         super().__init__(graph,
                          append2output,
-                         fw_info,
                          return_float_outputs,
                          wrapper=self.mixed_precision_wrapper,
                          get_activation_quantizer_holder_fn=self.mixed_precision_activation_holder)
@@ -87,13 +83,12 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
             ValueError: if kernel attribute is quantized but not configurable.
         """
 
-        kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
-        if kernel_attr is None or not n.is_weights_quantization_enabled(kernel_attr):
+        if n.kernel_attr is None or not n.is_weights_quantization_enabled(n.kernel_attr):
             return layer
-        if not n.is_configurable_weight(kernel_attr):  # pragma: no cover
+        if not n.is_configurable_weight(n.kernel_attr):  # pragma: no cover
             raise ValueError(f'Weight wrapper is not expected to be created for non-configurable weight of node {n}.')
-        wq = ConfigurableWeightsQuantizer(**self._get_weights_configurable_quantizer_kwargs(n, kernel_attr))
-        return KerasQuantizationWrapper(layer, weights_quantizers={kernel_attr: wq})
+        wq = ConfigurableWeightsQuantizer(**self._get_weights_configurable_quantizer_kwargs(n, n.kernel_attr))
+        return KerasQuantizationWrapper(layer, weights_quantizers={n.kernel_attr: wq})
 
     def _get_weights_configurable_quantizer_kwargs(self, n: BaseNode, attr: str) -> Dict[str, Any]:
         """
@@ -147,13 +142,12 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
         # activation number of bits (in reversed order).
         # since only kernel attribute is quantized in weights mixed precision,
         # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
-        n.sort_node_candidates(self.fw_info)
+        n.sort_node_candidates()
 
         max_candidate_idx = n.find_max_candidate_index()
-        kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
         activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
                                                                     'max_candidate_idx': max_candidate_idx,
-                                                                    'kernel_attr': kernel_attr})] \
+                                                                    'kernel_attr': n.kernel_attr})] \
                                 * num_of_outputs
 
         # Holder by definition uses a single quantizer for the activation quantization
@@ -181,7 +175,7 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
 
         # creating a mapping between graph nodes and model's layers for mixed precision configurability
         conf_node2layers = {n.name: self._find_layers_in_model_by_node(n, model.layers)
-                            for n in self.graph.get_configurable_sorted_nodes(self.fw_info)}
+                            for n in self.graph.get_configurable_sorted_nodes()}
 
         return model, user_info, conf_node2layers
 
@@ -231,8 +225,7 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
 
         """
         # Only layers with kernel op are considered weights configurable
-        kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
-        weights_quant = False if kernel_attr is None else n.is_weights_quantization_enabled(kernel_attr)
+        weights_quant = False if n.kernel_attr is None else n.is_weights_quantization_enabled(n.kernel_attr)
         act_quant = n.is_activation_quantization_enabled()
 
         if weights_quant and not act_quant:
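
The recurring edit in this hunk (and in several hunks further down) replaces per-call FrameworkInfo lookups with a property exposed by the node itself. A minimal before/after sketch built only from expressions shown in the diff (fragment, not standalone code):

    # Before: resolved through the framework-info object held by the builder.
    kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]

    # After: the node exposes the kernel attribute directly.
    kernel_attr = n.kernel_attr
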
@@ -18,7 +18,6 @@ from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from tensorflow.python.util.object_identity import Reference as TFReference
 
 
@@ -30,20 +29,17 @@ class QuantizedKerasModelBuilder(KerasModelBuilder):
     def __init__(self,
                  graph: common.Graph,
                  append2output=None,
-                 fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
                  return_float_outputs: bool = False):
         """
 
         Args:
             graph: Graph to build the model from.
             append2output: Nodes to append to model's output.
-            fw_info: Information about the specific framework of the model that is built.
             return_float_outputs: Whether the model returns float tensors or not.
         """
 
         super().__init__(graph,
                          append2output,
-                         fw_info,
                          return_float_outputs)
 
     def _quantize_node_activations(self,
@@ -13,102 +13,153 @@
 # limitations under the License.
 # ==============================================================================
 
-
 import tensorflow as tf
 
+from typing import Tuple, Any, Dict
+from functools import wraps
+
 from model_compression_toolkit.core.keras.quantizer.lut_fake_quant import activation_lut_kmean_quantizer
 from packaging import version
 
 if version.parse(tf.__version__) >= version.parse("2.13"):
-    from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU
+    from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU, Activation
 else:
-    from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU  # pragma: no cover
-
-from model_compression_toolkit.defaultdict import DefaultDict
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo, DEFAULT_KERNEL_ATTRIBUTES
+    from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU, Activation  # pragma: no cover
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo, set_fw_info, ChannelAxisMapping
 from mct_quantizers import QuantizationMethod
-from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
+from model_compression_toolkit.constants import SOFTMAX_THRESHOLD, ACTIVATION
 from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
     KERNEL, DEPTHWISE_KERNEL, GELU
 from model_compression_toolkit.core.keras.quantizer.fake_quant_builder import power_of_two_quantization, symmetric_quantization, uniform_quantization
 
-"""
-Map each layer to a list of its' weights attributes that should get quantized.
-If a layer that is not listed here is queried, [None] is returned.
-"""
-KERNEL_ATTRIBUTES = DefaultDict({Conv2D: [KERNEL],
-                                 DepthwiseConv2D: [DEPTHWISE_KERNEL],
-                                 Dense: [KERNEL],
-                                 Conv2DTranspose: [KERNEL]}, DEFAULT_KERNEL_ATTRIBUTES)
-
-
-"""
-Map a layer to its kernel's output and input channels indices.
-Map's values are tuples of (output_channel_index, input_channel_index).
-Default value is returned for layers that are not included.
-"""
-DEFAULT_CHANNEL_AXIS_DICT = DefaultDict({Conv2D: (3, 2),
-                                         DepthwiseConv2D: (2, 2),
-                                         Dense: (1, 0),
-                                         Conv2DTranspose: (2, 3)}, (None, None))
-
-
-"""
-Map a layer to its output channel axis.
-Where axis=-1 is the last axis
-"""
-DEFAULT_OUT_CHANNEL_AXIS_DICT = DefaultDict({Conv2D: -1,
-                                             DepthwiseConv2D: -1,
-                                             Dense: -1,
-                                             Conv2DTranspose: -1},
-                                            -1)
-
-
-"""
-Map from an activation function to its min/max output values (if known).
-The values are used for tensor min/max values initialization.
-"""
-ACTIVATION2MINMAX = {SOFTMAX: (0, SOFTMAX_THRESHOLD),
-                     SIGMOID: (0, 1),
-                     LINEAR: (None, None),
-                     IDENTITY: (None, None),
-                     TANH: (-1, 1),
-                     SWISH: (-0.279, None),
-                     RELU: (0, None),
-                     SELU: (-1.76, None),
-                     GELU: (-0.17, None),
-                     }
-
-"""
-Map from an Keras layer to its min/max output values (if known).
-The values are used for tensor min/max values initialization.
-"""
-LAYER2MINMAX = {Softmax: (0, SOFTMAX_THRESHOLD),
-                ELU: (-1, None),
-                tf.nn.silu: (-0.279, None),
-                tf.nn.swish: (-0.279, None),
-                tf.nn.sigmoid: (0, 1),
-                tf.nn.tanh: (-1, 1),
-                tf.nn.relu: (0, None),
-                tf.nn.relu6: (0, None),
-                tf.nn.gelu: (-0.17, None),
-                tf.nn.elu: (-1, None),
-                tf.nn.selu: (-1.76, None),
-                tf.nn.softplus: (0, None),
-                tf.nn.softmax: (0, SOFTMAX_THRESHOLD),
-                }
-"""
-Mapping from a QuantizationMethod to an activation quantizer function.
-"""
-ACTIVATION_QUANTIZER_MAPPING = {QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
-                                QuantizationMethod.SYMMETRIC: symmetric_quantization,
-                                QuantizationMethod.UNIFORM: uniform_quantization,
-                                QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer}
-
-
-DEFAULT_KERAS_INFO = FrameworkInfo(ACTIVATION_QUANTIZER_MAPPING,
-                                   DEFAULT_CHANNEL_AXIS_DICT,
-                                   ACTIVATION2MINMAX,
-                                   LAYER2MINMAX,
-                                   KERNEL_ATTRIBUTES,
-                                   DEFAULT_OUT_CHANNEL_AXIS_DICT)
+
+class KerasInfo(FrameworkInfo):
+    """
+    Extra field defined to handle Activation layer functions:
+
+    _activation_min_max_mapping (Dict[str, tuple]): Dictionary from an activation function to its min/max output values.
+
+    """
+
+    """
+    Map each layer to it's weight attribute that should get quantized.
+    If a layer that is not listed here is queried, None is returned.
+    """
+    kernel_ops_attribute_mapping = {Conv2D: KERNEL,
+                                    DepthwiseConv2D: DEPTHWISE_KERNEL,
+                                    Dense: KERNEL,
+                                    Conv2DTranspose: KERNEL}
+
+    """
+    Map a layer to its kernel's output and input channels indices.
+    Map's values are tuples of (output_channel_index, input_channel_index).
+    Default value is returned for layers that are not included.
+    """
+    kernel_channels_mapping = {Conv2D: ChannelAxisMapping(3, 2),
+                               DepthwiseConv2D: ChannelAxisMapping(2, 2),
+                               Dense: ChannelAxisMapping(1, 0),
+                               Conv2DTranspose: ChannelAxisMapping(2, 3)}
+
+    """
+    Map a layer to its output channel axis.
+    Where axis=-1 is the last axis
+    """
+    out_channel_axis_mapping = {Conv2D: -1,
+                                DepthwiseConv2D: -1,
+                                Dense: -1,
+                                Conv2DTranspose: -1}
+
+    """
+    Map from an activation function name to its min/max output values (if known).
+    The values are used for tensor min/max values initialization.
+    """
+    _activation_min_max_mapping = {SOFTMAX: (0, SOFTMAX_THRESHOLD),
+                                   SIGMOID: (0, 1),
+                                   LINEAR: (None, None),
+                                   IDENTITY: (None, None),
+                                   TANH: (-1, 1),
+                                   SWISH: (-0.279, None),
+                                   RELU: (0, None),
+                                   SELU: (-1.76, None),
+                                   GELU: (-0.17, None),
+                                   }
+
+    """
+    Map from an Keras module to its min/max output values (if known).
+    The values are used for tensor min/max values initialization.
+    """
+    _layer_min_max_mapping = {Softmax: (0, SOFTMAX_THRESHOLD),
+                              ELU: (-1, None),
+                              tf.nn.silu: (-0.279, None),
+                              tf.nn.swish: (-0.279, None),
+                              tf.nn.sigmoid: (0, 1),
+                              tf.nn.tanh: (-1, 1),
+                              tf.nn.relu: (0, None),
+                              tf.nn.relu6: (0, None),
+                              tf.nn.gelu: (-0.17, None),
+                              tf.nn.elu: (-1, None),
+                              tf.nn.selu: (-1.76, None),
+                              tf.nn.softplus: (0, None),
+                              tf.nn.softmax: (0, SOFTMAX_THRESHOLD),
+                              }
+
+    """
+    Mapping from a QuantizationMethod to an activation quantizer function.
+    """
+    activation_quantizer_mapping = {QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
+                                    QuantizationMethod.SYMMETRIC: symmetric_quantization,
+                                    QuantizationMethod.UNIFORM: uniform_quantization,
+                                    QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer}
+
+    @classmethod
+    def get_layer_min_max(cls, layer: Any, fw_attrs: Dict) -> Tuple[float, float]:
+        """
+        Return layer min/max mapping the FrameworkInfo holds.
+        Args:
+            layer: A layer to check if has a min/max known values.
+            fw_attrs: framework attributes from framework layer.
+
+        Returns:
+            Layer's min/max known values.
+        """
+
+        if cls.layers_has_min_max(layer):
+            return cls._layer_min_max_mapping[layer]
+        elif isinstance(layer, Activation) and fw_attrs[ACTIVATION] in cls._activation_min_max_mapping:
+            return cls._activation_min_max_mapping[fw_attrs[ACTIVATION]]
+        else:
+            return None, None
+
+    @classmethod
+    def get_kernel_channels(cls, node_type: Any) -> ChannelAxisMapping:
+        """
+        Returns node's channels mapping from kernel_channels_mapping or framework specific default value.
+        Args:
+            node_type: A node type
+
+        Returns:
+            Node's channels mapping.
+
+        """
+        return cls.kernel_channels_mapping.get(node_type, cls._default_channel_mapping)
+
+    @classmethod
+    def get_out_channel_axis(cls, node_type: Any):
+        """
+        Returns node's output channel mapping from out_channel_axis_mapping or framework specific default value.
+        Args:
+            node_type: A node type.
+
+        Returns:
+            Node's output channel axis.
+
+        """
+        return cls.out_channel_axis_mapping.get(node_type, -1)
+
+
+def set_keras_info(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        set_fw_info(KerasInfo)
+        return func(*args, **kwargs)
+    return wrapper
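
This hunk replaces the module-level DEFAULT_KERAS_INFO instance with a KerasInfo class whose mappings live on the class and which is registered globally. The diff does not show set_fw_info itself (it comes from framework_info.py), but the set_keras_info decorator implies a simple registration pattern. Below is a self-contained sketch of that pattern with hypothetical stand-in names (_fw_info, KerasInfoSketch, quantize_entry_point are not MCT symbols); it is illustrative only, not the library's implementation.

from functools import wraps

_fw_info = None  # stand-in for the module-level registry that set_fw_info would populate


def set_fw_info(fw_info_cls):
    """Register the active framework-info class for the current run."""
    global _fw_info
    _fw_info = fw_info_cls


class KerasInfoSketch:
    # stand-in for KerasInfo: class-level mappings instead of a constructed instance
    out_channel_axis_mapping = {'Conv2D': -1, 'Dense': -1}

    @classmethod
    def get_out_channel_axis(cls, node_type):
        return cls.out_channel_axis_mapping.get(node_type, -1)


def set_keras_info(func):
    # mirrors the decorator added in the diff: register the Keras info, then call through
    @wraps(func)
    def wrapper(*args, **kwargs):
        set_fw_info(KerasInfoSketch)
        return func(*args, **kwargs)
    return wrapper


@set_keras_info
def quantize_entry_point(node_type):
    # code running under the entry point reads the registry instead of a fw_info argument
    return _fw_info.get_out_channel_axis(node_type)


print(quantize_entry_point('Conv2D'))  # -> -1
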
@@ -21,7 +21,6 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
 from model_compression_toolkit.core.common.substitutions.batchnorm_folding import BatchNormalizationFolding, BatchNormalizationForwardFolding
 from model_compression_toolkit.core.keras.constants import KERNEL, LINEAR, ACTIVATION, DEPTHWISE_KERNEL, BIAS, GAMMA, BETA, \
     MOVING_MEAN, MOVING_VARIANCE, EPSILON, USE_BIAS, LAYER_NAME, GROUPS
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 
 
 def batchnorm_folding_node_matchers() -> [BaseNode, BaseNode]:
@@ -77,9 +76,7 @@ def update_kernel_for_bn_folding_fn(conv_node: BaseNode,
     else:
         kernel = kernel * weights_scale.reshape((1, 1, 1, -1))
 
-    kernel_name = DEFAULT_KERAS_INFO.get_kernel_op_attributes(conv_node.type)[0]
-
-    return kernel, kernel_name
+    return kernel, conv_node.kernel_attr
 
 
 def update_weights_for_bn_forward_folding_fn(conv_node: BaseNode,
@@ -108,9 +105,7 @@ def update_weights_for_bn_forward_folding_fn(conv_node: BaseNode,
         bias_update = (kernel * bias_factor.reshape((1, 1, -1, 1))).sum(2)
         kernel = kernel * weights_scale.reshape((1, 1, -1, 1))
 
-    kernel_name = DEFAULT_KERAS_INFO.get_kernel_op_attributes(conv_node.type)[0]
-
-    return kernel, bias + bias_update.flatten(), kernel_name
+    return kernel, bias + bias_update.flatten(), conv_node.kernel_attr
 
 
 def get_kernel_hw_fn(kernel: np.ndarray) -> [int, int]:
@@ -27,7 +27,6 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNo
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
 from model_compression_toolkit.constants import REUSE, REUSE_GROUP
-from model_compression_toolkit.core.keras.reader.node_builder import REUSED_IDENTIFIER
 from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, NUM_HEADS, KEY_DIM, VALUE_DIM, \
     QUERY_SHAPE, KEY_SHAPE, VALUE_SHAPE, OUTPUT_SHAPE, ATTENTION_AXES, ACTIVATION, GROUPS, LINEAR, FILTERS, PADDING, \
     FUNCTION, DIMS, TARGET_SHAPE, F_STRIDED_SLICE, F_STACK, Q_KERNEL, Q_BIAS, K_KERNEL, K_BIAS, V_KERNEL, V_BIAS, \
@@ -97,16 +97,14 @@ class BaseInputScaling(common.BaseSubstitution):
         scale_factor = threshold_float / threshold
         graph.user_info.set_input_scale(1 / scale_factor)
 
-        kernel_attr = graph.fw_info.get_kernel_op_attributes(linear_layer.type)[0]
-
-        w1_fixed = linear_layer.get_weights_by_keys(kernel_attr) * scale_factor
-        linear_layer.set_weights_by_keys(kernel_attr, w1_fixed)
+        w1_fixed = linear_layer.get_weights_by_keys(linear_layer.kernel_attr) * scale_factor
+        linear_layer.set_weights_by_keys(linear_layer.kernel_attr, w1_fixed)
 
         graph.scale_stats_collector(input_layer, 1 / scale_factor)
 
         # After scaling weights may have different thresholds so it needs to be recalculated
         for nqc in linear_layer.candidates_quantization_cfg:
-            nqc.weights_quantization_cfg.get_attr_config(kernel_attr).calculate_and_set_weights_params(w1_fixed,
+            nqc.weights_quantization_cfg.get_attr_config(linear_layer.kernel_attr).calculate_and_set_weights_params(w1_fixed,
                 nqc.weights_quantization_cfg.min_threshold)
 
         return graph
@@ -63,17 +63,15 @@ class ScaleEqualization(BaseScaleEqualization):
     """
 
     def __init__(self,
-                 quant_config: QuantizationConfig,
-                 fw_info: FrameworkInfo):
+                 quant_config: QuantizationConfig):
         """
         Initialize a ScaleEqualization object.
         Args:
             quant_config: Quantization configuration.
-            fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
             groups of layers by how they should be quantized, etc.)
         """
 
-        super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER,
+        super().__init__(quant_config=quant_config, matcher_instance=MATCHER,
                          kernel_str=KERNEL, bias_str=BIAS)
 
 
@@ -83,17 +81,15 @@ class ScaleEqualizationWithPad(BaseScaleEqualization):
     """
 
    def __init__(self,
-                 quant_config: QuantizationConfig,
-                 fw_info: FrameworkInfo):
+                 quant_config: QuantizationConfig):
        """
        Initialize a ScaleEqualizationWithPad object.
        Args:
            quant_config: Quantization configuration.
-           fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
            groups of layers by how they should be quantized, etc.)
        """
 
-        super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER_WITH_PAD,
+        super().__init__(quant_config=quant_config, matcher_instance=MATCHER_WITH_PAD,
                          kernel_str=KERNEL, bias_str=BIAS)
 
 
@@ -104,17 +100,15 @@ class ScaleEqualizationMidActivation(BaseScaleEqualization):
     """
 
    def __init__(self,
-                 quant_config: QuantizationConfig,
-                 fw_info: FrameworkInfo):
+                 quant_config: QuantizationConfig):
        """
        Initialize a ScaleEqualizationMidActivation object.
        Args:
            quant_config: Quantization configuration.
-           fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
            groups of layers by how they should be quantized, etc.)
        """
 
-        super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER_MID,
+        super().__init__(quant_config=quant_config, matcher_instance=MATCHER_MID,
                          kernel_str=KERNEL, bias_str=BIAS)
 
 
@@ -124,15 +118,13 @@ class ScaleEqualizationMidActivationWithPad(BaseScaleEqualization):
     """
 
    def __init__(self,
-                 quant_config: QuantizationConfig,
-                 fw_info: FrameworkInfo):
+                 quant_config: QuantizationConfig):
        """
        Initialize a ScaleEqualizationMidActivationWithPad object.
        Args:
            quant_config: Quantization configuration.
-           fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
            groups of layers by how they should be quantized, etc.)
        """
 
-        super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER_MID_WITH_PAD,
+        super().__init__(quant_config=quant_config, matcher_instance=MATCHER_MID_WITH_PAD,
                          kernel_str=KERNEL, bias_str=BIAS)