mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250927.534__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -35,6 +35,8 @@ from typing import Any, Dict, List, Tuple, Callable
35
35
  from tensorflow.python.util.object_identity import Reference as TFReference
36
36
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
37
37
  from model_compression_toolkit.core import common
38
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
39
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
38
40
  from model_compression_toolkit.core.common import BaseNode
39
41
  from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
40
42
  from model_compression_toolkit.core.keras.back2framework.instance_builder import OperationHandler
@@ -55,6 +57,7 @@ class KerasModelBuilder(BaseModelBuilder):
55
57
  def __init__(self,
56
58
  graph: common.Graph,
57
59
  append2output=None,
60
+ fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
58
61
  return_float_outputs: bool = False,
59
62
  wrapper: Callable = None,
60
63
  get_activation_quantizer_holder_fn: Callable=None):
@@ -63,6 +66,7 @@ class KerasModelBuilder(BaseModelBuilder):
63
66
  Args:
64
67
  graph: Graph to build the model from.
65
68
  append2output: Nodes to append to model's output.
69
+ fw_info: Information about the specific framework of the model that is built.
66
70
  return_float_outputs: Whether the model returns float tensors or not.
67
71
  wrapper: A function wrapper keras Layers.
68
72
  get_activation_quantizer_holder_fn: Function to retrieve a quantization holder for a node.
@@ -71,6 +75,7 @@ class KerasModelBuilder(BaseModelBuilder):
71
75
 
72
76
  super().__init__(graph,
73
77
  append2output,
78
+ fw_info,
74
79
  return_float_outputs)
75
80
 
76
81
  # Build an OperationHandler to handle conversions from graph nodes to Keras operators.
@@ -36,6 +36,7 @@ from model_compression_toolkit.core.keras.mixed_precision.configurable_weights_q
36
36
  from model_compression_toolkit.logger import Logger
37
37
  from model_compression_toolkit.core import common
38
38
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
39
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
39
40
 
40
41
 
41
42
  class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
@@ -46,12 +47,14 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
46
47
  def __init__(self,
47
48
  graph: common.Graph,
48
49
  append2output=None,
50
+ fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
49
51
  return_float_outputs: bool = False):
50
52
  """
51
53
 
52
54
  Args:
53
55
  graph: Graph to build the model from.
54
56
  append2output: Nodes to append to model's output.
57
+ fw_info: Information about the specific framework of the model that is built.
55
58
  return_float_outputs: Whether the model returns float tensors or not.
56
59
  """
57
60
 
@@ -59,6 +62,7 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
59
62
 
60
63
  super().__init__(graph,
61
64
  append2output,
65
+ fw_info,
62
66
  return_float_outputs,
63
67
  wrapper=self.mixed_precision_wrapper,
64
68
  get_activation_quantizer_holder_fn=self.mixed_precision_activation_holder)
@@ -83,12 +87,13 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
83
87
  ValueError: if kernel attribute is quantized but not configurable.
84
88
  """
85
89
 
86
- if n.kernel_attr is None or not n.is_weights_quantization_enabled(n.kernel_attr):
90
+ kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
91
+ if kernel_attr is None or not n.is_weights_quantization_enabled(kernel_attr):
87
92
  return layer
88
- if not n.is_configurable_weight(n.kernel_attr): # pragma: no cover
93
+ if not n.is_configurable_weight(kernel_attr): # pragma: no cover
89
94
  raise ValueError(f'Weight wrapper is not expected to be created for non-configurable weight of node {n}.')
90
- wq = ConfigurableWeightsQuantizer(**self._get_weights_configurable_quantizer_kwargs(n, n.kernel_attr))
91
- return KerasQuantizationWrapper(layer, weights_quantizers={n.kernel_attr: wq})
95
+ wq = ConfigurableWeightsQuantizer(**self._get_weights_configurable_quantizer_kwargs(n, kernel_attr))
96
+ return KerasQuantizationWrapper(layer, weights_quantizers={kernel_attr: wq})
92
97
 
93
98
  def _get_weights_configurable_quantizer_kwargs(self, n: BaseNode, attr: str) -> Dict[str, Any]:
94
99
  """
@@ -142,12 +147,13 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
142
147
  # activation number of bits (in reversed order).
143
148
  # since only kernel attribute is quantized in weights mixed precision,
144
149
  # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
145
- n.sort_node_candidates()
150
+ n.sort_node_candidates(self.fw_info)
146
151
 
147
152
  max_candidate_idx = n.find_max_candidate_index()
153
+ kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
148
154
  activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
149
155
  'max_candidate_idx': max_candidate_idx,
150
- 'kernel_attr': n.kernel_attr})] \
156
+ 'kernel_attr': kernel_attr})] \
151
157
  * num_of_outputs
152
158
 
153
159
  # Holder by definition uses a single quantizer for the activation quantization
@@ -175,7 +181,7 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
175
181
 
176
182
  # creating a mapping between graph nodes and model's layers for mixed precision configurability
177
183
  conf_node2layers = {n.name: self._find_layers_in_model_by_node(n, model.layers)
178
- for n in self.graph.get_configurable_sorted_nodes()}
184
+ for n in self.graph.get_configurable_sorted_nodes(self.fw_info)}
179
185
 
180
186
  return model, user_info, conf_node2layers
181
187
 
@@ -225,7 +231,8 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
225
231
 
226
232
  """
227
233
  # Only layers with kernel op are considered weights configurable
228
- weights_quant = False if n.kernel_attr is None else n.is_weights_quantization_enabled(n.kernel_attr)
234
+ kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
235
+ weights_quant = False if kernel_attr is None else n.is_weights_quantization_enabled(kernel_attr)
229
236
  act_quant = n.is_activation_quantization_enabled()
230
237
 
231
238
  if weights_quant and not act_quant:
@@ -14,11 +14,11 @@
14
14
  # ==============================================================================
15
15
  from typing import List
16
16
 
17
+ from model_compression_toolkit.core import FrameworkInfo
17
18
  from model_compression_toolkit.core import common
18
19
  from model_compression_toolkit.core.common import BaseNode
19
- from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn
20
- from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
21
20
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
21
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
22
22
  from tensorflow.python.util.object_identity import Reference as TFReference
23
23
 
24
24
 
@@ -30,17 +30,20 @@ class QuantizedKerasModelBuilder(KerasModelBuilder):
30
30
  def __init__(self,
31
31
  graph: common.Graph,
32
32
  append2output=None,
33
+ fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
33
34
  return_float_outputs: bool = False):
34
35
  """
35
36
 
36
37
  Args:
37
38
  graph: Graph to build the model from.
38
39
  append2output: Nodes to append to model's output.
40
+ fw_info: Information about the specific framework of the model that is built.
39
41
  return_float_outputs: Whether the model returns float tensors or not.
40
42
  """
41
43
 
42
44
  super().__init__(graph,
43
45
  append2output,
46
+ fw_info,
44
47
  return_float_outputs)
45
48
 
46
49
  def _quantize_node_activations(self,
@@ -57,6 +60,4 @@ class QuantizedKerasModelBuilder(KerasModelBuilder):
57
60
  Output of the node.
58
61
 
59
62
  """
60
- activation_quantizer = get_activation_quantization_fn(node.final_activation_quantization_cfg,
61
- get_activation_quantization_fn_factory)
62
- return activation_quantizer(input_tensors)
63
+ return node.final_activation_quantization_cfg.quantize_node_output(input_tensors)
@@ -13,142 +13,102 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- import tensorflow as tf
17
16
 
18
- from typing import Tuple, Any, Dict
19
- from functools import wraps
17
+ import tensorflow as tf
20
18
 
19
+ from model_compression_toolkit.core.keras.quantizer.lut_fake_quant import activation_lut_kmean_quantizer
21
20
  from packaging import version
22
21
 
23
22
  if version.parse(tf.__version__) >= version.parse("2.13"):
24
- from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU, Activation
23
+ from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU
25
24
  else:
26
- from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU, Activation # pragma: no cover
27
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo, set_fw_info, ChannelAxisMapping
28
- from model_compression_toolkit.constants import SOFTMAX_THRESHOLD, ACTIVATION
25
+ from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU # pragma: no cover
26
+
27
+ from model_compression_toolkit.defaultdict import DefaultDict
28
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo, DEFAULT_KERNEL_ATTRIBUTES
29
+ from mct_quantizers import QuantizationMethod
30
+ from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
29
31
  from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
30
32
  KERNEL, DEPTHWISE_KERNEL, GELU
31
-
32
-
33
- class KerasInfo(FrameworkInfo):
34
- """
35
- Extra field defined to handle Activation layer functions:
36
-
37
- _activation_min_max_mapping (Dict[str, tuple]): Dictionary from an activation function to its min/max output values.
38
-
39
- """
40
-
41
- """
42
- Map each layer to it's weight attribute that should get quantized.
43
- If a layer that is not listed here is queried, None is returned.
44
- """
45
- kernel_ops_attribute_mapping = {Conv2D: KERNEL,
46
- DepthwiseConv2D: DEPTHWISE_KERNEL,
47
- Dense: KERNEL,
48
- Conv2DTranspose: KERNEL}
49
-
50
- """
51
- Map a layer to its kernel's output and input channels indices.
52
- Map's values are tuples of (output_channel_index, input_channel_index).
53
- Default value is returned for layers that are not included.
54
- """
55
- kernel_channels_mapping = {Conv2D: ChannelAxisMapping(3, 2),
56
- DepthwiseConv2D: ChannelAxisMapping(2, 2),
57
- Dense: ChannelAxisMapping(1, 0),
58
- Conv2DTranspose: ChannelAxisMapping(2, 3)}
59
-
60
- """
61
- Map a layer to its output channel axis.
62
- Where axis=-1 is the last axis
63
- """
64
- out_channel_axis_mapping = {Conv2D: -1,
65
- DepthwiseConv2D: -1,
66
- Dense: -1,
67
- Conv2DTranspose: -1}
68
-
69
- """
70
- Map from an activation function name to its min/max output values (if known).
71
- The values are used for tensor min/max values initialization.
72
- """
73
- _activation_min_max_mapping = {SOFTMAX: (0, SOFTMAX_THRESHOLD),
74
- SIGMOID: (0, 1),
75
- LINEAR: (None, None),
76
- IDENTITY: (None, None),
77
- TANH: (-1, 1),
78
- SWISH: (-0.279, None),
79
- RELU: (0, None),
80
- SELU: (-1.76, None),
81
- GELU: (-0.17, None),
82
- }
83
-
84
- """
85
- Map from an Keras module to its min/max output values (if known).
86
- The values are used for tensor min/max values initialization.
87
- """
88
- _layer_min_max_mapping = {Softmax: (0, SOFTMAX_THRESHOLD),
89
- ELU: (-1, None),
90
- tf.nn.silu: (-0.279, None),
91
- tf.nn.swish: (-0.279, None),
92
- tf.nn.sigmoid: (0, 1),
93
- tf.nn.tanh: (-1, 1),
94
- tf.nn.relu: (0, None),
95
- tf.nn.relu6: (0, None),
96
- tf.nn.gelu: (-0.17, None),
97
- tf.nn.elu: (-1, None),
98
- tf.nn.selu: (-1.76, None),
99
- tf.nn.softplus: (0, None),
100
- tf.nn.softmax: (0, SOFTMAX_THRESHOLD),
101
- }
102
-
103
- @classmethod
104
- def get_layer_min_max(cls, layer: Any, fw_attrs: Dict) -> Tuple[float, float]:
105
- """
106
- Return layer min/max mapping the FrameworkInfo holds.
107
- Args:
108
- layer: A layer to check if has a min/max known values.
109
- fw_attrs: framework attributes from framework layer.
110
-
111
- Returns:
112
- Layer's min/max known values.
113
- """
114
-
115
- if cls.layers_has_min_max(layer):
116
- return cls._layer_min_max_mapping[layer]
117
- elif isinstance(layer, Activation) and fw_attrs[ACTIVATION] in cls._activation_min_max_mapping:
118
- return cls._activation_min_max_mapping[fw_attrs[ACTIVATION]]
119
- else:
120
- return None, None
121
-
122
- @classmethod
123
- def get_kernel_channels(cls, node_type: Any) -> ChannelAxisMapping:
124
- """
125
- Returns node's channels mapping from kernel_channels_mapping or framework specific default value.
126
- Args:
127
- node_type: A node type
128
-
129
- Returns:
130
- Node's channels mapping.
131
-
132
- """
133
- return cls.kernel_channels_mapping.get(node_type, cls._default_channel_mapping)
134
-
135
- @classmethod
136
- def get_out_channel_axis(cls, node_type: Any):
137
- """
138
- Returns node's output channel mapping from out_channel_axis_mapping or framework specific default value.
139
- Args:
140
- node_type: A node type.
141
-
142
- Returns:
143
- Node's output channel axis.
144
-
145
- """
146
- return cls.out_channel_axis_mapping.get(node_type)
147
-
148
-
149
- def set_keras_info(func):
150
- @wraps(func)
151
- def wrapper(*args, **kwargs):
152
- set_fw_info(KerasInfo)
153
- return func(*args, **kwargs)
154
- return wrapper
33
+ from model_compression_toolkit.core.keras.quantizer.fake_quant_builder import power_of_two_quantization, symmetric_quantization, uniform_quantization
34
+
35
+ """
36
+ Map each layer to a list of its' weights attributes that should get quantized.
37
+ If a layer that is not listed here is queried, [None] is returned.
38
+ """
39
+ KERNEL_ATTRIBUTES = DefaultDict({Conv2D: [KERNEL],
40
+ DepthwiseConv2D: [DEPTHWISE_KERNEL],
41
+ Dense: [KERNEL],
42
+ Conv2DTranspose: [KERNEL]}, DEFAULT_KERNEL_ATTRIBUTES)
43
+
44
+
45
+ """
46
+ Map a layer to its kernel's output and input channels indices.
47
+ Map's values are tuples of (output_channel_index, input_channel_index).
48
+ Default value is returned for layers that are not included.
49
+ """
50
+ DEFAULT_CHANNEL_AXIS_DICT = DefaultDict({Conv2D: (3, 2),
51
+ DepthwiseConv2D: (2, 2),
52
+ Dense: (1, 0),
53
+ Conv2DTranspose: (2, 3)}, (None, None))
54
+
55
+
56
+ """
57
+ Map a layer to its output channel axis.
58
+ Where axis=-1 is the last axis
59
+ """
60
+ DEFAULT_OUT_CHANNEL_AXIS_DICT = DefaultDict({Conv2D: -1,
61
+ DepthwiseConv2D: -1,
62
+ Dense: -1,
63
+ Conv2DTranspose: -1},
64
+ -1)
65
+
66
+
67
+ """
68
+ Map from an activation function to its min/max output values (if known).
69
+ The values are used for tensor min/max values initialization.
70
+ """
71
+ ACTIVATION2MINMAX = {SOFTMAX: (0, SOFTMAX_THRESHOLD),
72
+ SIGMOID: (0, 1),
73
+ LINEAR: (None, None),
74
+ IDENTITY: (None, None),
75
+ TANH: (-1, 1),
76
+ SWISH: (-0.279, None),
77
+ RELU: (0, None),
78
+ SELU: (-1.76, None),
79
+ GELU: (-0.17, None),
80
+ }
81
+
82
+ """
83
+ Map from an Keras layer to its min/max output values (if known).
84
+ The values are used for tensor min/max values initialization.
85
+ """
86
+ LAYER2MINMAX = {Softmax: (0, SOFTMAX_THRESHOLD),
87
+ ELU: (-1, None),
88
+ tf.nn.silu: (-0.279, None),
89
+ tf.nn.swish: (-0.279, None),
90
+ tf.nn.sigmoid: (0, 1),
91
+ tf.nn.tanh: (-1, 1),
92
+ tf.nn.relu: (0, None),
93
+ tf.nn.relu6: (0, None),
94
+ tf.nn.gelu: (-0.17, None),
95
+ tf.nn.elu: (-1, None),
96
+ tf.nn.selu: (-1.76, None),
97
+ tf.nn.softplus: (0, None),
98
+ tf.nn.softmax: (0, SOFTMAX_THRESHOLD),
99
+ }
100
+ """
101
+ Mapping from a QuantizationMethod to an activation quantizer function.
102
+ """
103
+ ACTIVATION_QUANTIZER_MAPPING = {QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
104
+ QuantizationMethod.SYMMETRIC: symmetric_quantization,
105
+ QuantizationMethod.UNIFORM: uniform_quantization,
106
+ QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer}
107
+
108
+
109
+ DEFAULT_KERAS_INFO = FrameworkInfo(ACTIVATION_QUANTIZER_MAPPING,
110
+ DEFAULT_CHANNEL_AXIS_DICT,
111
+ ACTIVATION2MINMAX,
112
+ LAYER2MINMAX,
113
+ KERNEL_ATTRIBUTES,
114
+ DEFAULT_OUT_CHANNEL_AXIS_DICT)
@@ -21,6 +21,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
21
21
  from model_compression_toolkit.core.common.substitutions.batchnorm_folding import BatchNormalizationFolding, BatchNormalizationForwardFolding
22
22
  from model_compression_toolkit.core.keras.constants import KERNEL, LINEAR, ACTIVATION, DEPTHWISE_KERNEL, BIAS, GAMMA, BETA, \
23
23
  MOVING_MEAN, MOVING_VARIANCE, EPSILON, USE_BIAS, LAYER_NAME, GROUPS
24
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
24
25
 
25
26
 
26
27
  def batchnorm_folding_node_matchers() -> [BaseNode, BaseNode]:
@@ -76,7 +77,9 @@ def update_kernel_for_bn_folding_fn(conv_node: BaseNode,
76
77
  else:
77
78
  kernel = kernel * weights_scale.reshape((1, 1, 1, -1))
78
79
 
79
- return kernel, conv_node.kernel_attr
80
+ kernel_name = DEFAULT_KERAS_INFO.get_kernel_op_attributes(conv_node.type)[0]
81
+
82
+ return kernel, kernel_name
80
83
 
81
84
 
82
85
  def update_weights_for_bn_forward_folding_fn(conv_node: BaseNode,
@@ -105,7 +108,9 @@ def update_weights_for_bn_forward_folding_fn(conv_node: BaseNode,
105
108
  bias_update = (kernel * bias_factor.reshape((1, 1, -1, 1))).sum(2)
106
109
  kernel = kernel * weights_scale.reshape((1, 1, -1, 1))
107
110
 
108
- return kernel, bias + bias_update.flatten(), conv_node.kernel_attr
111
+ kernel_name = DEFAULT_KERAS_INFO.get_kernel_op_attributes(conv_node.type)[0]
112
+
113
+ return kernel, bias + bias_update.flatten(), kernel_name
109
114
 
110
115
 
111
116
  def get_kernel_hw_fn(kernel: np.ndarray) -> [int, int]:
@@ -27,6 +27,7 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNo
27
27
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
28
28
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
29
29
  from model_compression_toolkit.constants import REUSE, REUSE_GROUP
30
+ from model_compression_toolkit.core.keras.reader.node_builder import REUSED_IDENTIFIER
30
31
  from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, NUM_HEADS, KEY_DIM, VALUE_DIM, \
31
32
  QUERY_SHAPE, KEY_SHAPE, VALUE_SHAPE, OUTPUT_SHAPE, ATTENTION_AXES, ACTIVATION, GROUPS, LINEAR, FILTERS, PADDING, \
32
33
  FUNCTION, DIMS, TARGET_SHAPE, F_STRIDED_SLICE, F_STACK, Q_KERNEL, Q_BIAS, K_KERNEL, K_BIAS, V_KERNEL, V_BIAS, \
@@ -17,13 +17,14 @@
17
17
  from tensorflow.keras.layers import InputLayer, Dense, DepthwiseConv2D, Conv2D, Conv2DTranspose, ZeroPadding2D
18
18
  from typing import List
19
19
 
20
- from model_compression_toolkit.core import common, QuantizationConfig
20
+ from model_compression_toolkit.core import common
21
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
21
22
  from model_compression_toolkit.core.common.graph.base_graph import Graph
22
- from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, WalkMatcher
23
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, EdgeMatcher, WalkMatcher
23
24
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
25
+ from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
24
26
  from model_compression_toolkit.constants import THRESHOLD
25
- from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
26
- compute_weights_qparams
27
+ from model_compression_toolkit.core.keras.constants import KERNEL
27
28
  from model_compression_toolkit.logger import Logger
28
29
 
29
30
  input_node = NodeOperationMatcher(InputLayer)
@@ -47,8 +48,7 @@ class BaseInputScaling(common.BaseSubstitution):
47
48
  """
48
49
 
49
50
  def __init__(self,
50
- matcher_instance,
51
- quant_cfg: QuantizationConfig):
51
+ matcher_instance):
52
52
  """
53
53
  Matches: InputLayer -> (optional nodes) -> (Dense,Conv2D,DepthwiseConv2D,Conv2DTranspose)
54
54
  note: the optional nodes are nodes that don't affect the scaling (such as ZeroPadding)
@@ -56,11 +56,10 @@ class BaseInputScaling(common.BaseSubstitution):
56
56
  Create a substitution using different params which may affect the way this substitution is made.
57
57
  The substitution is looking for edges in the graph which are input layers connected to linear layers.
58
58
  Args:
59
- matcher_instance: matcher instance of type WalkMatcher.
60
- quant_cfg: quantization config.
59
+ matcher_instance: matcher instance of type WalkMatcher
60
+
61
61
  """
62
62
  super().__init__(matcher_instance=matcher_instance)
63
- self.quant_cfg = quant_cfg
64
63
 
65
64
  def substitute(self,
66
65
  graph: Graph,
@@ -98,21 +97,17 @@ class BaseInputScaling(common.BaseSubstitution):
98
97
  scale_factor = threshold_float / threshold
99
98
  graph.user_info.set_input_scale(1 / scale_factor)
100
99
 
101
- w1_fixed = linear_layer.get_weights_by_keys(linear_layer.kernel_attr) * scale_factor
102
- linear_layer.set_weights_by_keys(linear_layer.kernel_attr, w1_fixed)
100
+ kernel_attr = graph.fw_info.get_kernel_op_attributes(linear_layer.type)[0]
101
+
102
+ w1_fixed = linear_layer.get_weights_by_keys(kernel_attr) * scale_factor
103
+ linear_layer.set_weights_by_keys(kernel_attr, w1_fixed)
103
104
 
104
105
  graph.scale_stats_collector(input_layer, 1 / scale_factor)
105
106
 
106
107
  # After scaling, the weights may have different thresholds, so they need to be recalculated
107
108
  for nqc in linear_layer.candidates_quantization_cfg:
108
- attr_cfg = nqc.weights_quantization_cfg.get_attr_config(linear_layer.kernel_attr)
109
- assert attr_cfg.enable_weights_quantization
110
- w_params, _ = compute_weights_qparams(w1_fixed,
111
- attr_quant_config=attr_cfg,
112
- weights_error_method=self.quant_cfg.weights_error_method,
113
- l_p_value=self.quant_cfg.l_p_value,
114
- output_channels_axis=attr_cfg.weights_channels_axis.output)
115
- attr_cfg.set_weights_quantization_param(w_params)
109
+ nqc.weights_quantization_cfg.get_attr_config(kernel_attr).calculate_and_set_weights_params(w1_fixed,
110
+ nqc.weights_quantization_cfg.min_threshold)
116
111
 
117
112
  return graph
118
113
 
@@ -122,15 +117,12 @@ class InputScaling(BaseInputScaling):
122
117
  Substitution extends BaseInputScaling to the case of Input-->Linear
123
118
  """
124
119
 
125
- def __init__(self, quant_cfg: QuantizationConfig):
120
+ def __init__(self):
126
121
  """
127
122
  Initialize an InputScaling object.
128
-
129
- Args:
130
- quant_cfg: quantization config.
131
123
  """
132
124
 
133
- super().__init__(matcher_instance=INPUT_MATCHER, quant_cfg=quant_cfg)
125
+ super().__init__(matcher_instance=INPUT_MATCHER)
134
126
 
135
127
 
136
128
  class InputScalingWithPad(BaseInputScaling):
@@ -138,12 +130,9 @@ class InputScalingWithPad(BaseInputScaling):
138
130
  Substitution extends BaseInputScaling to the case of Input-->ZeroPadding-->Linear
139
131
  """
140
132
 
141
- def __init__(self, quant_cfg: QuantizationConfig):
133
+ def __init__(self):
142
134
  """
143
135
  Initialize an InputScalingWithPad object.
144
-
145
- Args:
146
- quant_cfg: quantization config.
147
136
  """
148
137
 
149
- super().__init__(matcher_instance=INPUT_MATCHER_WITH_PAD, quant_cfg=quant_cfg)
138
+ super().__init__(matcher_instance=INPUT_MATCHER_WITH_PAD)
@@ -63,15 +63,17 @@ class ScaleEqualization(BaseScaleEqualization):
63
63
  """
64
64
 
65
65
  def __init__(self,
66
- quant_config: QuantizationConfig):
66
+ quant_config: QuantizationConfig,
67
+ fw_info: FrameworkInfo):
67
68
  """
68
69
  Initialize a ScaleEqualization object.
69
70
  Args:
70
71
  quant_config: Quantization configuration.
72
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
71
73
  groups of layers by how they should be quantized, etc.)
72
74
  """
73
75
 
74
- super().__init__(quant_config=quant_config, matcher_instance=MATCHER,
76
+ super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER,
75
77
  kernel_str=KERNEL, bias_str=BIAS)
76
78
 
77
79
 
@@ -81,15 +83,17 @@ class ScaleEqualizationWithPad(BaseScaleEqualization):
81
83
  """
82
84
 
83
85
  def __init__(self,
84
- quant_config: QuantizationConfig):
86
+ quant_config: QuantizationConfig,
87
+ fw_info: FrameworkInfo):
85
88
  """
86
89
  Initialize a ScaleEqualizationWithPad object.
87
90
  Args:
88
91
  quant_config: Quantization configuration.
92
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
89
93
  groups of layers by how they should be quantized, etc.)
90
94
  """
91
95
 
92
- super().__init__(quant_config=quant_config, matcher_instance=MATCHER_WITH_PAD,
96
+ super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER_WITH_PAD,
93
97
  kernel_str=KERNEL, bias_str=BIAS)
94
98
 
95
99
 
@@ -100,15 +104,17 @@ class ScaleEqualizationMidActivation(BaseScaleEqualization):
100
104
  """
101
105
 
102
106
  def __init__(self,
103
- quant_config: QuantizationConfig):
107
+ quant_config: QuantizationConfig,
108
+ fw_info: FrameworkInfo):
104
109
  """
105
110
  Initialize a ScaleEqualizationMidActivation object.
106
111
  Args:
107
112
  quant_config: Quantization configuration.
113
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
108
114
  groups of layers by how they should be quantized, etc.)
109
115
  """
110
116
 
111
- super().__init__(quant_config=quant_config, matcher_instance=MATCHER_MID,
117
+ super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER_MID,
112
118
  kernel_str=KERNEL, bias_str=BIAS)
113
119
 
114
120
 
@@ -118,13 +124,15 @@ class ScaleEqualizationMidActivationWithPad(BaseScaleEqualization):
118
124
  """
119
125
 
120
126
  def __init__(self,
121
- quant_config: QuantizationConfig):
127
+ quant_config: QuantizationConfig,
128
+ fw_info: FrameworkInfo):
122
129
  """
123
130
  Initialize a ScaleEqualizationMidActivationWithPad object.
124
131
  Args:
125
132
  quant_config: Quantization configuration.
133
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
126
134
  groups of layers by how they should be quantized, etc.)
127
135
  """
128
136
 
129
- super().__init__(quant_config=quant_config, matcher_instance=MATCHER_MID_WITH_PAD,
137
+ super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER_MID_WITH_PAD,
130
138
  kernel_str=KERNEL, bias_str=BIAS)