mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -12,15 +12,23 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Any, List, Dict, TYPE_CHECKING
15
+
16
+
17
+ from typing import Callable, Any, List, Tuple, Union, Dict, TYPE_CHECKING
16
18
  from enum import Enum, auto
19
+ import numpy as np
17
20
 
18
- from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
21
+ from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn
19
22
  from model_compression_toolkit.logger import Logger
23
+ from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
24
+ get_activation_quantization_params_fn, get_weights_quantization_params_fn
20
25
 
21
- from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
26
+ from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
27
+ QuantizationErrorMethod
28
+ from model_compression_toolkit.target_platform_capabilities.constants import POS_ATTR
22
29
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import \
23
- AttributeQuantizationConfig, OpQuantizationConfig
30
+ AttributeQuantizationConfig, \
31
+ OpQuantizationConfig
24
32
 
25
33
  if TYPE_CHECKING:
26
34
  from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
@@ -38,7 +46,6 @@ class ActivationQuantizationMode(Enum):
38
46
  FLN_QUANT = auto()
39
47
  PRESERVE_QUANT = auto()
40
48
  NO_QUANT = auto()
41
- FLN_NO_QUANT = auto()
42
49
 
43
50
 
44
51
  class BaseNodeQuantizationConfig(object):
@@ -59,11 +66,12 @@ class BaseNodeQuantizationConfig(object):
59
66
  kwargs: A dictionary with additional key arguments.
60
67
 
61
68
  """
69
+
62
70
  if hasattr(self, config_parameter_name):
63
71
  setattr(self, config_parameter_name, config_parameter_value)
64
72
  else:
65
- raise AttributeError(
66
- f"Parameter {config_parameter_name} could not be found in the node quantization config.")
73
+ Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config and "
74
+ f"was not updated!")
67
75
 
68
76
  def __repr__(self) -> str:
69
77
  """
@@ -77,14 +85,29 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
77
85
  """
78
86
  Attributes for configuring the quantization of the activations of a node.
79
87
  """
80
- def __init__(self, op_cfg: OpQuantizationConfig):
88
+ def __init__(self,
89
+ qc: QuantizationConfig,
90
+ op_cfg: OpQuantizationConfig,
91
+ activation_quantization_fn: Callable,
92
+ activation_quantization_params_fn: Callable
93
+ ):
81
94
  """
82
95
 
83
96
  Args:
97
+ qc: QuantizationConfig to create the node's config from.
84
98
  op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
99
+ activation_quantization_fn: Function to use when quantizing the node's activations.
100
+ activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations.
85
101
  """
102
+
103
+ self.activation_quantization_fn = activation_quantization_fn
104
+ self.activation_quantization_params_fn = activation_quantization_params_fn
105
+ self.activation_quantization_params = {}
86
106
  self.activation_quantization_method = op_cfg.activation_quantization_method
107
+ self.activation_error_method = qc.activation_error_method
87
108
  self.activation_n_bits = op_cfg.activation_n_bits
109
+ self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
110
+ self.activation_bias_correction_term = None
88
111
  if op_cfg.enable_activation_quantization and op_cfg.quantization_preserving:
89
112
  raise ValueError("An OpQuantizationConfig can't have both enable_activation_quantization and quantization_preserving enabled.")
90
113
  if op_cfg.enable_activation_quantization:
@@ -94,13 +117,15 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
94
117
  else:
95
118
  self.quant_mode = ActivationQuantizationMode.NO_QUANT
96
119
  self.signedness = op_cfg.signedness
97
-
98
- self.activation_quantization_params = {}
99
- # TODO: computed by compute_activation_bias_correction. Probably shouldnt be here.
100
- self.activation_bias_correction_term = None
101
- # Z-threshold is a global param from QuantizationConfig, however it can be overridden per node by NetworkEditor.
102
- # Since activation qparams are re-computed in several places, it's easier to keep it here and update it once.
103
- self.z_threshold = None
120
+ self.activation_channel_equalization = qc.activation_channel_equalization
121
+ self.input_scaling = qc.input_scaling
122
+ self.min_threshold = qc.min_threshold
123
+ self.l_p_value = qc.l_p_value
124
+ self.shift_negative_activation_correction = qc.shift_negative_activation_correction
125
+ self.z_threshold = qc.z_threshold
126
+ self.shift_negative_ratio = qc.shift_negative_ratio
127
+ self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
128
+ self.concat_threshold_update = qc.concat_threshold_update
104
129
 
105
130
  @property
106
131
  def enable_activation_quantization(self):
@@ -113,6 +138,65 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
113
138
  def fln_quantization(self):
114
139
  return self.quant_mode == ActivationQuantizationMode.FLN_QUANT
115
140
 
141
+ def quantize_node_output(self,
142
+ tensors: Any) -> Any:
143
+ """
144
+
145
+ Args:
146
+ tensors: framework tensor/s
147
+
148
+ Returns:
149
+ Framework tensor/s after applying fake quantization.
150
+
151
+ """
152
+ fake_quant = self.activation_quantization_fn(self.activation_n_bits,
153
+ self.activation_quantization_params)
154
+
155
+ if fake_quant is None:
156
+ Logger.critical(
157
+ "Layer is intended to be quantized, but the fake_quant function is None.") # pragma: no cover
158
+
159
+ return fake_quant(tensors)
160
+
161
+ @property
162
+ def activation_error_method(self) -> QuantizationErrorMethod:
163
+ """
164
+ activation_error_method getter.
165
+ """
166
+ return self._activation_error_method
167
+
168
+ @activation_error_method.setter
169
+ def activation_error_method(self, value: QuantizationErrorMethod):
170
+ """
171
+ activation_error_method setter.
172
+
173
+ Args:
174
+ value: New activation_error_method to set to the node activation configuration.
175
+
176
+ """
177
+ self._activation_error_method = value
178
+ self.activation_quantization_params_fn = get_activation_quantization_params_fn(activation_quantization_method=self.activation_quantization_method)
179
+
180
+ def set_activation_quantization_fn(self, activation_quantization_fn: Callable):
181
+ """
182
+ Sets activation quantization function for the node.
183
+
184
+ Args:
185
+ activation_quantization_fn: Function for quantazing the activations.
186
+
187
+ """
188
+ self.activation_quantization_fn = activation_quantization_fn
189
+
190
+ def set_activation_quantization_params_fn(self, activation_quantization_params_fn:Callable):
191
+ """
192
+ Sets activation params function for the node.
193
+
194
+ Args:
195
+ activation_quantization_params_fn: Function for calculating activation params.
196
+
197
+ """
198
+ self.activation_quantization_params_fn = activation_quantization_params_fn
199
+
116
200
  def set_activation_quantization_param(self,
117
201
  activation_params: dict):
118
202
  """
@@ -122,7 +206,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
122
206
  activation_params: Dictionary that contains weight quantization params.
123
207
 
124
208
  """
125
- assert self.quant_mode == ActivationQuantizationMode.QUANT or self.quant_mode == ActivationQuantizationMode.FLN_QUANT
209
+ assert self.quant_mode == ActivationQuantizationMode.QUANT
126
210
  for param_name, param_value in activation_params.items():
127
211
  self.activation_quantization_params[param_name] = param_value
128
212
 
@@ -139,16 +223,36 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
139
223
  if not isinstance(other, NodeActivationQuantizationConfig):
140
224
  return False # pragma: no cover
141
225
 
142
- return self.activation_quantization_method == other.activation_quantization_method and \
226
+ return self.activation_quantization_fn == other.activation_quantization_fn and \
227
+ self.activation_quantization_params_fn == other.activation_quantization_params_fn and \
228
+ self.activation_error_method == other.activation_error_method and \
229
+ self.activation_quantization_method == other.activation_quantization_method and \
143
230
  self.activation_n_bits == other.activation_n_bits and \
144
231
  self.quant_mode == other.quant_mode and \
145
- self.signedness == other.signedness
232
+ self.activation_channel_equalization == other.activation_channel_equalization and \
233
+ self.input_scaling == other.input_scaling and \
234
+ self.min_threshold == other.min_threshold and \
235
+ self.l_p_value == other.l_p_value and \
236
+ self.shift_negative_activation_correction == other.shift_negative_activation_correction and \
237
+ self.z_threshold == other.z_threshold and \
238
+ self.shift_negative_ratio == other.shift_negative_ratio and \
239
+ self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
146
240
 
147
241
  def __hash__(self):
148
- return hash((self.activation_quantization_method,
242
+ return hash((self.activation_quantization_fn,
243
+ self.activation_quantization_params_fn,
244
+ self.activation_error_method,
245
+ self.activation_quantization_method,
149
246
  self.activation_n_bits,
150
247
  self.quant_mode,
151
- self.signedness))
248
+ self.activation_channel_equalization,
249
+ self.input_scaling,
250
+ self.min_threshold,
251
+ self.l_p_value,
252
+ self.shift_negative_activation_correction,
253
+ self.z_threshold,
254
+ self.shift_negative_ratio,
255
+ self.shift_negative_threshold_recalculation))
152
256
 
153
257
 
154
258
  class WeightsAttrQuantizationConfig:
@@ -156,21 +260,65 @@ class WeightsAttrQuantizationConfig:
156
260
  Configuration for quantizing a weights attribute of a node.
157
261
  """
158
262
  def __init__(self,
263
+ qc: QuantizationConfig,
159
264
  weights_attr_cfg: AttributeQuantizationConfig,
160
- weights_channels_axis: ChannelAxisMapping = None):
265
+ weights_channels_axis: Tuple[int, int] = None):
161
266
  """
162
267
 
163
268
  Args:
269
+ qc: QuantizationConfig to create the node's config from.
164
270
  weights_attr_cfg: AttributeQuantizationConfig with parameters to use when creating the node's attribute quantization config.
165
271
  weights_channels_axis: Axis to quantize a node's attribute when quantizing per-channel (if not quantizing per-channel than expecting None).
166
272
  """
273
+ self.weights_quantization_fn = get_weights_quantization_fn(weights_attr_cfg.weights_quantization_method)
274
+ self.weights_quantization_params_fn = get_weights_quantization_params_fn(weights_attr_cfg.weights_quantization_method)
167
275
  self.weights_channels_axis = weights_channels_axis
276
+ self.weights_quantization_params = {}
168
277
  self.weights_quantization_method = weights_attr_cfg.weights_quantization_method
278
+ self.weights_error_method = qc.weights_error_method
169
279
  self.weights_n_bits = weights_attr_cfg.weights_n_bits
170
280
  self.weights_per_channel_threshold = weights_attr_cfg.weights_per_channel_threshold
171
281
  self.enable_weights_quantization = weights_attr_cfg.enable_weights_quantization
282
+ self.l_p_value = qc.l_p_value
172
283
 
173
- self.weights_quantization_params = {}
284
+ @property
285
+ def weights_error_method(self) -> QuantizationErrorMethod:
286
+ """
287
+ weights_error_method getter.
288
+ """
289
+ return self._weights_error_method
290
+
291
+ @weights_error_method.setter
292
+ def weights_error_method(self, value: QuantizationErrorMethod):
293
+ """
294
+ weights_error_method setter.
295
+
296
+ Args:
297
+ value: New weights_error_method to set to the node weights configuration.
298
+
299
+ """
300
+ self._weights_error_method = value
301
+ self.weights_quantization_params_fn = get_weights_quantization_params_fn(weights_quantization_method=self.weights_quantization_method)
302
+
303
+ def set_weights_quantization_fn(self, weights_quantization_fn: Callable):
304
+ """
305
+ Sets weights quantization function for the node.
306
+
307
+ Args:
308
+ weights_quantization_fn: Function for quantazing the weights.
309
+
310
+ """
311
+ self.weights_quantization_fn = weights_quantization_fn
312
+
313
+ def set_weights_quantization_params_fn(self, weights_quantization_params_fn: Callable):
314
+ """
315
+ Sets weights params function for the node.
316
+
317
+ Args:
318
+ weights_quantization_params_fn: Function for calculating the weights params.
319
+
320
+ """
321
+ self.weights_quantization_params_fn = weights_quantization_params_fn
174
322
 
175
323
  def set_weights_quantization_param(self,
176
324
  weights_params: dict):
@@ -185,6 +333,31 @@ class WeightsAttrQuantizationConfig:
185
333
  for param_name, param_value in weights_params.items():
186
334
  self.weights_quantization_params[param_name] = param_value
187
335
 
336
+ def calculate_and_set_weights_params(self, tensor_data: np.ndarray, min_threshold: float):
337
+ """
338
+ Args:
339
+ tensor_data: Tensor content as Numpy array.
340
+ min_threshold: A minimal threshold to set as quantization parameter.
341
+
342
+ Returns:
343
+ Recalculated weights quantization params from the kernel and channel axis.
344
+
345
+ """
346
+ assert self.enable_weights_quantization
347
+ assert not (self.weights_per_channel_threshold and self.weights_channels_axis is None), \
348
+ "Trying to calculate threshold per channel, channel axis in None."
349
+ if self.weights_quantization_params_fn is not None:
350
+ self.set_weights_quantization_param(
351
+ self.weights_quantization_params_fn(tensor_data,
352
+ p=self.l_p_value,
353
+ n_bits=self.weights_n_bits,
354
+ per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
355
+ channel_axis=self.weights_channels_axis[0], # output channel axis
356
+ min_threshold=min_threshold)[0] # Take only first output, the q-params, as axis is already chosen.
357
+ )
358
+ else:
359
+ self.set_weights_quantization_param({})
360
+
188
361
  def __eq__(self, other: Any) -> bool:
189
362
  """
190
363
  Compares the object to another object to find if they are equal.
@@ -198,18 +371,26 @@ class WeightsAttrQuantizationConfig:
198
371
  if not isinstance(other, WeightsAttrQuantizationConfig):
199
372
  return False # pragma: no cover
200
373
 
201
- return self.weights_channels_axis == other.weights_channels_axis and \
374
+ return self.weights_quantization_fn == other.weights_quantization_fn and \
375
+ self.weights_quantization_params_fn == other.weights_quantization_params_fn and \
376
+ self.weights_channels_axis == other.weights_channels_axis and \
377
+ self.weights_error_method == other.weights_error_method and \
202
378
  self.weights_quantization_method == other.weights_quantization_method and \
203
379
  self.weights_n_bits == other.weights_n_bits and \
204
380
  self.weights_per_channel_threshold == other.weights_per_channel_threshold and \
205
- self.enable_weights_quantization == other.enable_weights_quantization
381
+ self.enable_weights_quantization == other.enable_weights_quantization and \
382
+ self.l_p_value == other.l_p_value
206
383
 
207
384
  def __hash__(self):
208
- return hash((self.weights_channels_axis,
385
+ return hash((self.weights_quantization_fn,
386
+ self.weights_quantization_params_fn,
387
+ self.weights_channels_axis,
388
+ self.weights_error_method,
209
389
  self.weights_quantization_method,
210
390
  self.weights_n_bits,
211
391
  self.weights_per_channel_threshold,
212
- self.enable_weights_quantization))
392
+ self.enable_weights_quantization,
393
+ self.l_p_value))
213
394
 
214
395
 
215
396
  class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
@@ -217,19 +398,23 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
217
398
  Holding a mapping between the node's weights attributes and their quantization configurations,
218
399
  in addition to quantization parameters that are global for all attributes of the represented node.
219
400
  """
220
- def __init__(self,
401
+ def __init__(self, qc: QuantizationConfig,
221
402
  op_cfg: OpQuantizationConfig,
222
- weights_channels_axis: ChannelAxisMapping,
403
+ weights_channels_axis: Tuple[int, int],
223
404
  node_attrs_list: List[str]):
224
405
  """
225
406
 
226
407
  Args:
408
+ qc: QuantizationConfig to create the node's config from.
227
409
  op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
228
410
  weights_channels_axis: Axis to quantize a node's weights attribute when quantizing per-channel.
229
411
  node_attrs_list: A list of the node's weights attributes names.
230
412
 
231
413
  """
414
+ self.min_threshold = qc.min_threshold
232
415
  self.simd_size = op_cfg.simd_size
416
+ self.weights_second_moment_correction = qc.weights_second_moment_correction
417
+ self.weights_bias_correction = qc.weights_bias_correction
233
418
 
234
419
  # Initialize a quantization configuration for each of the node's attributes
235
420
  self.attributes_config_mapping = {}
@@ -241,7 +426,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
241
426
  # POS_ATTR string. If none are found, it indicates that no specific quantization config is defined for
242
427
  # positional weights, so the default config will be used instead.
243
428
  attrs_included_in_name = {k: v for k, v in op_cfg.attr_weights_configs_mapping.items() if
244
- POSITIONAL_ATTR in k}
429
+ POS_ATTR in k}
245
430
 
246
431
  if len(attrs_included_in_name) > 1: # pragma: no cover
247
432
  raise ValueError(f"Found multiple attribute in FQC OpConfig that are contained "
@@ -257,7 +442,8 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
257
442
  attr_cfg = list(attrs_included_in_name.values())[0]
258
443
 
259
444
  # Register this attribute under the positional attributes config mapping.
260
- self.pos_attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(weights_attr_cfg=attr_cfg,
445
+ self.pos_attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(qc=qc,
446
+ weights_attr_cfg=attr_cfg,
261
447
  weights_channels_axis=
262
448
  weights_channels_axis)
263
449
  else:
@@ -274,16 +460,9 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
274
460
  else:
275
461
  attr_cfg = list(attrs_included_in_name.values())[0]
276
462
 
277
- self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(weights_attr_cfg=attr_cfg,
463
+ self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(qc=qc,
464
+ weights_attr_cfg=attr_cfg,
278
465
  weights_channels_axis=weights_channels_axis)
279
- # TODO this is set by batch norm reconstruction substitution when folded batch norms are added back, to mark
280
- # the nodes that the correction should be applied to (for some nodes it gets disabled) and BNs removed.
281
- # The actual correction is only computed when it's applied in ptq, so it seems that both substitutions could
282
- # be unified, and no info need to pass between.
283
- self.weights_second_moment_correction = None
284
- # TODO: computed corrected bias is injected to the node config. Probably shouldn't be here. Also it can be
285
- # computed on the final config, instead of all candidates and then there is no need to save it at all.
286
- self.bias_corrected = None
287
466
 
288
467
  def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationConfig:
289
468
  """
@@ -420,8 +599,8 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
420
599
  if hasattr(attr_cfg, config_parameter_name):
421
600
  setattr(attr_cfg, config_parameter_name, config_parameter_value)
422
601
  else:
423
- raise AttributeError(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
424
- f"weights attribute {attr_name}.")
602
+ Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
603
+ f"weights attribute {attr_name} and was not updated!")
425
604
  else: # pragma: no cover
426
605
  Logger.critical(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
427
606
 
@@ -438,7 +617,10 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
438
617
  if not isinstance(other, NodeWeightsQuantizationConfig):
439
618
  return False # pragma: no cover
440
619
 
441
- return self.simd_size == other.simd_size and \
620
+ return self.min_threshold == other.min_threshold and \
621
+ self.simd_size == other.simd_size and \
622
+ self.weights_second_moment_correction == other.weights_second_moment_correction and \
623
+ self.weights_bias_correction == other.weights_bias_correction and \
442
624
  self.attributes_config_mapping.keys() == other.attributes_config_mapping.keys() and \
443
625
  all([self.attributes_config_mapping[k] == other.attributes_config_mapping[k]
444
626
  for k in self.attributes_config_mapping.keys()]) and \
@@ -447,6 +629,9 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
447
629
  for k in self.pos_attributes_config_mapping.keys()])
448
630
 
449
631
  def __hash__(self):
450
- return hash((self.simd_size,
632
+ return hash((self.min_threshold,
633
+ self.simd_size,
634
+ self.weights_second_moment_correction,
635
+ self.weights_bias_correction,
451
636
  frozenset(self.attributes_config_mapping),
452
637
  frozenset(self.pos_attributes_config_mapping)))
@@ -90,6 +90,7 @@ class QuantizationConfig:
90
90
  shift_negative_activation_correction: bool = True
91
91
  activation_channel_equalization: bool = False
92
92
  z_threshold: float = math.inf
93
+ min_threshold: float = MIN_THRESHOLD
93
94
  l_p_value: int = 2
94
95
  linear_collapsing: bool = True
95
96
  residual_collapsing: bool = True
@@ -14,35 +14,15 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from collections.abc import Callable
17
+ from functools import partial
17
18
 
18
19
  from mct_quantizers import QuantizationMethod
19
-
20
- from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
21
20
  from model_compression_toolkit.logger import Logger
22
21
  from model_compression_toolkit.core.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
23
22
  from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import power_of_two_quantizer, \
24
23
  symmetric_quantizer, uniform_quantizer
25
24
 
26
25
 
27
- def get_activation_quantization_fn(activation_quantization_cfg: NodeActivationQuantizationConfig,
28
- get_activation_quantization_fn_factory: Callable) -> Callable:
29
- """
30
- Get activation quantizer based on activation quantization configuration.
31
-
32
- Args:
33
- activation_quantization_cfg: activation quantization configuration.
34
- get_activation_quantization_fn_factory: activation quantization functions factory.
35
-
36
- Returns:
37
- Activation quantizer that accepts a tensor and returns a quantized tensor.
38
- """
39
- quantizer_factory = get_activation_quantization_fn_factory(
40
- activation_quantization_cfg.activation_quantization_method)
41
- quantizer = quantizer_factory(activation_quantization_cfg.activation_n_bits,
42
- activation_quantization_cfg.activation_quantization_params)
43
- return quantizer
44
-
45
-
46
26
  def get_weights_quantization_fn(weights_quantization_method: QuantizationMethod) -> Callable:
47
27
  """
48
28
  Generate a function for weight quantization.
@@ -0,0 +1,78 @@
1
+ # Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from collections.abc import Callable
17
+ from functools import partial
18
+
19
+ from mct_quantizers import QuantizationMethod
20
+ from model_compression_toolkit.logger import Logger
21
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import \
22
+ lut_kmeans_tensor, lut_kmeans_histogram
23
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import \
24
+ symmetric_selection_tensor, symmetric_selection_histogram
25
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import \
26
+ uniform_selection_histogram, uniform_selection_tensor
27
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import \
28
+ power_of_two_selection_tensor, power_of_two_selection_histogram
29
+
30
+
31
+ def get_activation_quantization_params_fn(activation_quantization_method: QuantizationMethod) -> Callable:
32
+ """
33
+ Generate a function for finding activation quantization parameters.
34
+
35
+ Args:
36
+ activation_quantization_method: Which quantization method to use for activations.
37
+ Returns:
38
+ A function to find the quantization parameters.
39
+
40
+ """
41
+ if activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
42
+ params_fn = power_of_two_selection_histogram
43
+ elif activation_quantization_method == QuantizationMethod.SYMMETRIC:
44
+ params_fn = symmetric_selection_histogram
45
+ elif activation_quantization_method == QuantizationMethod.UNIFORM:
46
+ params_fn = uniform_selection_histogram
47
+ elif activation_quantization_method == QuantizationMethod.LUT_POT_QUANTIZER:
48
+ params_fn = lut_kmeans_histogram
49
+ else:
50
+ Logger.critical(
51
+ f"No parameter function found for the specified quantization method: {activation_quantization_method}") # pragma: no cover
52
+ return params_fn
53
+
54
+
55
+ def get_weights_quantization_params_fn(weights_quantization_method: QuantizationMethod) -> Callable:
56
+ """
57
+ Generate a function for finding weights quantization parameters.
58
+
59
+ Args:
60
+ weights_quantization_method: Which quantization method to use for weights.
61
+ Returns:
62
+ A function to find the quantization parameters.
63
+
64
+ """
65
+ if weights_quantization_method == QuantizationMethod.POWER_OF_TWO:
66
+ params_fn = power_of_two_selection_tensor
67
+ elif weights_quantization_method == QuantizationMethod.SYMMETRIC:
68
+ params_fn = symmetric_selection_tensor
69
+ elif weights_quantization_method == QuantizationMethod.UNIFORM:
70
+ params_fn = uniform_selection_tensor
71
+ elif weights_quantization_method == QuantizationMethod.LUT_POT_QUANTIZER:
72
+ params_fn = partial(lut_kmeans_tensor, is_symmetric=False)
73
+ elif weights_quantization_method == QuantizationMethod.LUT_SYM_QUANTIZER:
74
+ params_fn = partial(lut_kmeans_tensor, is_symmetric=True)
75
+ else:
76
+ Logger.critical(
77
+ f"No parameter function found for the specified quantization method: {weights_quantization_method}") # pragma: no cover
78
+ return params_fn
@@ -12,12 +12,9 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import (
16
- power_of_two_no_clipping_selection_min_max, power_of_two_selection_histogram, power_of_two_selection_tensor)
17
- from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import (
18
- lut_kmeans_tensor, lut_kmeans_histogram)
19
- from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import (
20
- symmetric_no_clipping_selection_min_max, symmetric_selection_histogram, symmetric_selection_tensor)
21
- from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import (
22
- uniform_no_clipping_selection_min_max, uniform_selection_histogram, uniform_selection_tensor)
15
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import power_of_two_no_clipping_selection_min_max, \
16
+ power_of_two_selection_histogram, power_of_two_selection_tensor
17
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import lut_kmeans_tensor
18
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import symmetric_no_clipping_selection_min_max
19
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import uniform_no_clipping_selection_min_max
23
20
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.outlier_filter import z_score_filter