mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py
@@ -13,98 +13,47 @@
  # limitations under the License.
  # ==============================================================================
  import numpy as np
- from typing import Dict, Union, Optional, Tuple, Callable
+ from typing import Dict, Union, Optional, Tuple

  from mct_quantizers import QuantizationMethod
-
- import model_compression_toolkit.core.common.quantization.quantization_params_generation as qpg
- from model_compression_toolkit.constants import MIN_THRESHOLD
+ from model_compression_toolkit.core import QuantizationErrorMethod
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
  from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
+ from model_compression_toolkit.core.common.quantization import quantization_params_generation
  from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
  from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
- from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationErrorMethod, \
-     QuantizationConfig
-
-
- def compute_activation_qparams(quant_cfg: QuantizationConfig,
-                                node_activation_quant_cfg: NodeActivationQuantizationConfig,
-                                node_prior_info: NodePriorInfo,
-                                out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
-     """
-     Compute the activations params for a given node in a graph according to a params function.
-
-     Args:
-         quant_cfg: quantization config.
-         node_activation_quant_cfg: node's activation quantization configuration.
-         node_prior_info: Prior info collected for the node that is being quantized.
-         out_stats_container: Tensor containing output statistics of the node.
-
-     Returns:
-         The computed activation quantization params.
-     """
-     activation_quantization_params_fn = _get_activation_quantization_params_fn(
-         node_activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded())
-
-     # Extract and filter histogram data from the statistics container.
-     z_threshold = quant_cfg.z_threshold
-     if node_activation_quant_cfg.z_threshold is not None:
-         z_threshold = node_activation_quant_cfg.z_threshold
-     bins_values, bins_counts = _get_histogram_data(out_stats_container,
-                                                    activation_error_method=quant_cfg.activation_error_method,
-                                                    z_threshold=z_threshold)
-
-     # Retrieve the minimum and maximum values from the statistics container.
-     min_value, max_value = out_stats_container.get_min_max_values()
-
-     # Determine if the activations should be considered signed.
-     signed = _determine_signedness(node_activation_quant_cfg, node_prior_info, min_value, bins_values, bins_counts)

-     # Compute and return the activation quantization parameters.
-     return activation_quantization_params_fn(
-         bins_values,
-         bins_counts,
-         quant_cfg.l_p_value,
-         node_activation_quant_cfg.activation_n_bits,
-         min_value,
-         max_value,
-         min_threshold=MIN_THRESHOLD,
-         quant_error_method=quant_cfg.activation_error_method,
-         is_signed=signed
-     )
-
-
- def _get_histogram_data(out_stats_container: BaseStatsCollector,
-                         activation_error_method: QuantizationErrorMethod,
-                         z_threshold: float) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+ def get_histogram_data(
+         activation_quant_cfg: NodeActivationQuantizationConfig,
+         out_stats_container: BaseStatsCollector
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
      """
      Extract and filter the histogram data from the statistics container.

      Args:
+         activation_quant_cfg: Node's activation quantization configuration.
          out_stats_container: Statistics container with histogram data.
-         activation_error_method: activation quantization error method.
-         z_threshold: z threshold for z-score filtering.

      Returns:
          A tuple containing the filtered bins_values and bins_counts.
      """
      bins_values, bins_counts = None, None
+
      # If the statistics container collected the histogram, we start by filtering outliers using z threshold
      # filtering, and then computing the threshold based on the filtered histogram.
      if out_stats_container.require_collection():
-         if activation_error_method == QuantizationErrorMethod.HMSE:
+         if activation_quant_cfg.activation_error_method == QuantizationErrorMethod.HMSE:
              bins_values, bins_counts = out_stats_container.weighted_hc.get_histogram()
          else:
              bins_values, bins_counts = out_stats_container.hc.get_histogram()
-         bins_counts = qpg.z_score_filter(
-             z_threshold,
+         bins_counts = quantization_params_generation.z_score_filter(
+             activation_quant_cfg.z_threshold,
              bins_values,
              bins_counts
          )
      return bins_values, bins_counts

-
- def _determine_signedness(
+ def determine_signedness(
          activation_quant_cfg: NodeActivationQuantizationConfig,
          nodes_prior_info: NodePriorInfo,
          min_value: float,
@@ -134,37 +83,73 @@ def _determine_signedness(
      return np.any(bins_values[:-1][bins_counts > 0] < 0)


- _activation_quant_params_fns = {
-     QuantizationMethod.POWER_OF_TWO: qpg.power_of_two_selection_histogram,
-     QuantizationMethod.SYMMETRIC: qpg.symmetric_selection_histogram,
-     QuantizationMethod.UNIFORM: qpg.uniform_selection_histogram,
-     QuantizationMethod.LUT_POT_QUANTIZER: qpg.lut_kmeans_histogram
- }
- _activation_no_clipping_quant_params_fns = {
-     QuantizationMethod.POWER_OF_TWO: qpg.power_of_two_no_clipping_selection_min_max,
-     QuantizationMethod.SYMMETRIC: qpg.symmetric_no_clipping_selection_min_max,
-     QuantizationMethod.UNIFORM: qpg.uniform_no_clipping_selection_min_max,
-     QuantizationMethod.LUT_POT_QUANTIZER: qpg.lut_kmeans_histogram
- }
-
+ def update_activation_quantization_params_fn(
+         activation_quant_cfg: NodeActivationQuantizationConfig,
+         nodes_prior_info: NodePriorInfo):
+     """
+     Update the activation quantization parameters function based on the quantization method
+     and whether the node's output is bounded.

- def _get_activation_quantization_params_fn(activation_quantization_method: QuantizationMethod,
-                                            no_clipping: bool) -> Callable:
+     Args:
+         activation_quant_cfg: Node's activation quantization configuration.
+         nodes_prior_info: Prior info collected for the node that is being quantized.
      """
-     Generate a function for finding activation quantization parameters.
+     if nodes_prior_info.is_output_bounded():
+         if activation_quant_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
+             activation_quant_cfg.set_activation_quantization_params_fn(
+                 quantization_params_generation.power_of_two_no_clipping_selection_min_max
+             )
+         elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.SYMMETRIC:
+             activation_quant_cfg.set_activation_quantization_params_fn(
+                 quantization_params_generation.symmetric_no_clipping_selection_min_max
+             )
+         elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.UNIFORM:
+             activation_quant_cfg.set_activation_quantization_params_fn(
+                 quantization_params_generation.uniform_no_clipping_selection_min_max
+             )
+
+
+ def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
+                             nodes_prior_info: NodePriorInfo,
+                             out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
+     """
+     Compute the activations params for a given node in a graph according to a params function.

      Args:
-         activation_quantization_method: Which quantization method to use for activations.
-         no_clipping: Whether to use the no-clipping version of the quantizer (if available).
+         activation_quant_cfg: node's activation quantization configuration.
+         nodes_prior_info: Prior info collected for the node that is being quantized.
+         out_stats_container: Tensor containing output statistics of the node.

      Returns:
-         A function to find the quantization parameters.
+         The computed activation quantization params.
      """
-     if no_clipping:
-         params_fn = _activation_no_clipping_quant_params_fns.get(activation_quantization_method)
-     else:
-         params_fn = _activation_quant_params_fns.get(activation_quantization_method)
-     if params_fn is None:
-         raise ValueError(f"No parameter function found for the specified quantization method: "
-                          "{activation_quantization_method}") # pragma: no cover
-     return params_fn
+     # Update quantization parameters function based on output bounds and quantization method.
+     update_activation_quantization_params_fn(activation_quant_cfg, nodes_prior_info)
+
+     # Extract and filter histogram data from the statistics container.
+     bins_values, bins_counts = get_histogram_data(activation_quant_cfg, out_stats_container)
+
+     # Retrieve the minimum and maximum values from the statistics container.
+     min_value, max_value = out_stats_container.get_min_max_values()
+
+     # Determine if the activations should be considered signed.
+     signed = determine_signedness(
+         activation_quant_cfg,
+         nodes_prior_info,
+         min_value,
+         bins_values,
+         bins_counts
+     )
+
+     # Compute and return the activation quantization parameters.
+     return activation_quant_cfg.activation_quantization_params_fn(
+         bins_values,
+         bins_counts,
+         activation_quant_cfg.l_p_value,
+         activation_quant_cfg.activation_n_bits,
+         min_value,
+         max_value,
+         min_threshold=activation_quant_cfg.min_threshold,
+         quant_error_method=activation_quant_cfg.activation_error_method,
+         is_signed=signed
+     )
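
Both sides of this file solve the same dispatch problem: picking an activation parameter-selection routine per QuantizationMethod, with a "no clipping" min/max variant when the node's output is known to be bounded. The removed side looks the routine up in module-level tables; the added side records the choice on the node's NodeActivationQuantizationConfig via set_activation_quantization_params_fn. Below is a minimal, self-contained sketch of the table-based dispatch; the Method enum and both selection functions are toy stand-ins, not MCT's implementations.

# Minimal sketch of the lookup-table dispatch pattern on the removed ('-') side.
# Method and both selection functions are toy stand-ins, not MCT's implementations.
from enum import Enum, auto
from typing import Callable, Dict

import numpy as np


class Method(Enum):  # stand-in for mct_quantizers.QuantizationMethod
    POWER_OF_TWO = auto()
    SYMMETRIC = auto()


def pot_from_histogram(bins: np.ndarray, counts: np.ndarray) -> Dict[str, float]:
    # Toy rule: power-of-two threshold covering the largest populated bin edge.
    max_abs = float(np.max(np.abs(bins[:-1][counts > 0])))
    return {"threshold": float(2.0 ** np.ceil(np.log2(max(max_abs, 1e-8))))}


def pot_no_clipping_from_min_max(min_v: float, max_v: float) -> Dict[str, float]:
    # Toy rule: bounded outputs (e.g. sigmoid, ReLU6) can use the known min/max directly.
    max_abs = max(abs(min_v), abs(max_v))
    return {"threshold": float(2.0 ** np.ceil(np.log2(max(max_abs, 1e-8))))}


_histogram_fns: Dict[Method, Callable] = {Method.POWER_OF_TWO: pot_from_histogram}
_no_clipping_fns: Dict[Method, Callable] = {Method.POWER_OF_TWO: pot_no_clipping_from_min_max}


def get_activation_params_fn(method: Method, output_is_bounded: bool) -> Callable:
    # Pick the table by boundedness, then the function by quantization method;
    # fail loudly on methods that have no registered selection function.
    table = _no_clipping_fns if output_is_bounded else _histogram_fns
    fn = table.get(method)
    if fn is None:
        raise ValueError(f"No parameter function for method {method}")
    return fn


fn = get_activation_params_fn(Method.POWER_OF_TWO, output_is_bounded=True)
print(fn(0.0, 6.0))  # ReLU6-like bounded output -> {'threshold': 8.0}

The added side ends up calling the same kind of selection function, but later stages read it from activation_quantization_params_fn on the config instead of re-dispatching by method at the call site.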
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py
@@ -18,21 +18,44 @@ from tqdm import tqdm
  from typing import List, Callable, Generator

  from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
- from model_compression_toolkit.core import QuantizationErrorMethod, QuantizationConfig
+ from model_compression_toolkit.core import QuantizationErrorMethod
  from model_compression_toolkit.core.common import Graph, BaseNode
- from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
  from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
      HessianScoresGranularity
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
-     import compute_activation_qparams
+     import get_activations_qparams
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
-     compute_weights_qparams
+     get_weights_qparams
  from model_compression_toolkit.logger import Logger


+ def _collect_nodes_for_hmse(nodes_list: List[BaseNode], graph: Graph) -> List[BaseNode]:
+     """
+     Collects nodes that are compatiable for parameters selection search using HMSE,
+     that is, have a kernel attribute that is configured for HMSE error method.
+
+     Args:
+         nodes_list: A list of nodes to search quantization parameters for.
+         graph: Graph to compute its nodes' quantization parameters..
+
+     Returns: A (possibly empty) list of nodes.
+
+     """
+     hmse_nodes = []
+     for n in nodes_list:
+         kernel_attr_name = graph.fw_info.get_kernel_op_attributes(n.type)
+         kernel_attr_name = None if kernel_attr_name is None or len(kernel_attr_name) == 0 else kernel_attr_name[0]
+
+         if kernel_attr_name is not None and n.is_weights_quantization_enabled(kernel_attr_name) and \
+                 all([c.weights_quantization_cfg.get_attr_config(kernel_attr_name).weights_error_method ==
+                      QuantizationErrorMethod.HMSE for c in n.candidates_quantization_cfg]):
+             hmse_nodes.append(n)
+
+     return hmse_nodes
+
+
  def calculate_quantization_params(graph: Graph,
-                                   quant_cfg: QuantizationConfig,
                                    fw_impl: FrameworkImplementation,
                                    repr_data_gen_fn: Callable[[], Generator],
                                    nodes: List[BaseNode] = None,
@@ -47,7 +70,6 @@ def calculate_quantization_params(graph: Graph,

      Args:
          graph: Graph to compute its nodes' thresholds.
-         quant_cfg: quantization config.
          fw_impl: FrameworkImplementation object.
          repr_data_gen_fn: callable returning representative dataset generator.
          nodes: List of nodes to compute their thresholds instead of computing it for all nodes in the graph.
@@ -65,16 +87,15 @@ def calculate_quantization_params(graph: Graph,
      # Collecting nodes that are configured to search weights quantization parameters using HMSE optimization
      # and computing required Hessian information to be used for HMSE parameters selection.
      # The Hessian scores are computed and stored in the hessian_info_service object.
-     if quant_cfg.weights_error_method == QuantizationErrorMethod.HMSE:
-         nodes_for_hmse = [n for n in nodes_list if n.kernel_attr and n.is_weights_quantization_enabled(n.kernel_attr)]
-         if nodes_for_hmse:
-             dataloader = fw_impl.convert_data_gen_to_dataloader(repr_data_gen_fn, batch_size=1)
-             request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
-                                            granularity=HessianScoresGranularity.PER_ELEMENT,
-                                            data_loader=dataloader,
-                                            n_samples=num_hessian_samples,
-                                            target_nodes=nodes_for_hmse)
-             hessian_info_service.fetch_hessian(request)
+     nodes_for_hmse = _collect_nodes_for_hmse(nodes_list, graph)
+     if len(nodes_for_hmse) > 0:
+         dataloader = fw_impl.convert_data_gen_to_dataloader(repr_data_gen_fn, batch_size=1)
+         request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
+                                        granularity=HessianScoresGranularity.PER_ELEMENT,
+                                        data_loader=dataloader,
+                                        n_samples=num_hessian_samples,
+                                        target_nodes=nodes_for_hmse)
+         hessian_info_service.fetch_hessian(request)

      for n in tqdm(nodes_list, "Calculating quantization parameters"): # iterate only nodes that we should compute their thresholds
          for candidate_qc in n.candidates_quantization_cfg:
@@ -82,34 +103,43 @@
                  if n.is_weights_quantization_enabled(attr):
                      # If the node's weights attribute should be quantized, we compute its quantization parameters
                      attr_cfg = candidate_qc.weights_quantization_cfg.get_attr_config(attr)
-                     output_channels_axis = attr_cfg.weights_channels_axis.output
+                     channels_axis = attr_cfg.weights_channels_axis
+                     if channels_axis is not None:
+                         output_channels_axis = channels_axis[0]
+                     else:
+                         output_channels_axis = None

-                     weights_error_method = quant_cfg.weights_error_method
-                     if weights_error_method == QuantizationErrorMethod.HMSE:
+                     mod_attr_cfg = attr_cfg
+
+                     if attr_cfg.weights_error_method == QuantizationErrorMethod.HMSE:
                          # Although we collected nodes for HMSE before running the loop, we keep this verification to
                          # notify the user in case of HMSE configured for node that is not compatible for this method
-                         if n.kernel_attr is None or n.kernel_attr not in attr:
+                         kernel_attr_name = graph.fw_info.get_kernel_op_attributes(n.type)
+                         if len(kernel_attr_name) > 0:
+                             kernel_attr_name = kernel_attr_name[0]
+
+                         if kernel_attr_name is None or kernel_attr_name not in attr:
                              Logger.warning(f"The HMSE error method for parameters selection is only supported for "
                                             f"kernel weights attributes. Running parameters selection for attribute "
                                             f"'{attr}' in node '{n.name}' with the default MSE error method instead.")
-                             weights_error_method = QuantizationErrorMethod.MSE
-
-                     weights_params, output_channels_axis = compute_weights_qparams(n.get_weights_by_keys(attr),
-                                                                                    attr_cfg,
-                                                                                    weights_error_method,
-                                                                                    quant_cfg.l_p_value,
-                                                                                    output_channels_axis,
-                                                                                    node=n,
-                                                                                    hessian_info_service=hessian_info_service,
-                                                                                    num_hessian_samples=num_hessian_samples)
-                     attr_cfg.weights_channels_axis = ChannelAxisMapping(output_channels_axis, attr_cfg.weights_channels_axis.input)
+                             mod_attr_cfg = copy.deepcopy(attr_cfg)
+                             mod_attr_cfg.weights_error_method = QuantizationErrorMethod.MSE
+
+                     weights_params, output_channels_axis = get_weights_qparams(n.get_weights_by_keys(attr),
+                                                                                candidate_qc.weights_quantization_cfg,
+                                                                                mod_attr_cfg,
+                                                                                output_channels_axis,
+                                                                                node=n,
+                                                                                hessian_info_service=hessian_info_service,
+                                                                                num_hessian_samples=num_hessian_samples)
+                     attr_cfg.weights_channels_axis = (output_channels_axis, attr_cfg.weights_channels_axis[1])
                      attr_cfg.set_weights_quantization_param(weights_params)

-             if n.is_activation_quantization_enabled() or n.is_fln_quantization():
+             if n.is_activation_quantization_enabled():
                  # If node's activations should be quantized as well, we compute its activation quantization parameters
-                 activation_params = compute_activation_qparams(quant_cfg=quant_cfg,
-                                                                node_activation_quant_cfg=candidate_qc.activation_quantization_cfg,
-                                                                node_prior_info=n.prior_info,
-                                                                out_stats_container=graph.get_out_stats_collector(n))
+                 activation_params = get_activations_qparams(
+                     activation_quant_cfg=candidate_qc.activation_quantization_cfg,
+                     nodes_prior_info=n.prior_info,
+                     out_stats_container=graph.get_out_stats_collector(n))
                  # Create a NodeQuantizationConfig containing all quantization params and attach it to the node
                  candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params)
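
The main behavioral change in this file is where HMSE eligibility is decided: the removed side gates Hessian collection on a single global quant_cfg.weights_error_method plus the node's kernel_attr, while the added _collect_nodes_for_hmse asks each node's candidate configurations (via graph.fw_info) whether the kernel attribute is quantized with HMSE. A self-contained sketch of that per-node filter follows, using simplified stand-in classes rather than MCT's BaseNode and candidate configs.

# Self-contained sketch of the per-node HMSE filter; Node and AttrCfg below are
# simplified stand-ins for MCT's BaseNode and candidate quantization configs.
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Dict, List, Optional


class ErrorMethod(Enum):  # stand-in for QuantizationErrorMethod
    MSE = auto()
    HMSE = auto()


@dataclass
class AttrCfg:
    weights_error_method: ErrorMethod
    enabled: bool = True


@dataclass
class Node:
    name: str
    kernel_attr: Optional[str]  # None for kernel-less ops (add, relu, ...)
    candidates: List[Dict[str, AttrCfg]] = field(default_factory=list)

    def is_weights_quantization_enabled(self, attr: str) -> bool:
        return all(c[attr].enabled for c in self.candidates if attr in c)


def collect_nodes_for_hmse(nodes: List[Node]) -> List[Node]:
    # A node qualifies only if it has a kernel attribute, that attribute is being
    # quantized, and every candidate config requests HMSE for it.
    hmse_nodes = []
    for n in nodes:
        if n.kernel_attr is None or not n.is_weights_quantization_enabled(n.kernel_attr):
            continue
        if all(c[n.kernel_attr].weights_error_method == ErrorMethod.HMSE for c in n.candidates):
            hmse_nodes.append(n)
    return hmse_nodes


conv = Node("conv1", "kernel", [{"kernel": AttrCfg(ErrorMethod.HMSE)}])
add = Node("add1", None)
assert [n.name for n in collect_nodes_for_hmse([conv, add])] == ["conv1"]

Only nodes that pass this kind of filter are listed as target_nodes in the HessianScoresRequest, so Hessian scores are fetched once, up front, for exactly the nodes that need them.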
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py
@@ -12,43 +12,35 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from functools import partial
- from typing import Dict, Any, Tuple, Callable, TYPE_CHECKING
+ from typing import Dict, Any, Tuple

  import numpy as np
- from mct_quantizers import QuantizationMethod

- from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES, MIN_THRESHOLD
- from model_compression_toolkit.core import QuantizationErrorMethod
+ from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
  from model_compression_toolkit.core.common.hessian import HessianInfoService
- from model_compression_toolkit.core.common.quantization.quantization_params_generation import \
-     power_of_two_selection_tensor, lut_kmeans_tensor, symmetric_selection_tensor, uniform_selection_tensor
+ from model_compression_toolkit.defaultdict import DefaultDict
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
+     WeightsAttrQuantizationConfig
  from model_compression_toolkit.logger import Logger

- if TYPE_CHECKING:
-     from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig

-
- def compute_weights_qparams(weights_attr_data: np.ndarray,
-                             attr_quant_config: 'WeightsAttrQuantizationConfig',
-                             weights_error_method: QuantizationErrorMethod,
-                             l_p_value: int,
-                             output_channels_axis: int,
-                             min_threshold: float = MIN_THRESHOLD,
-                             node=None,
-                             hessian_info_service: HessianInfoService = None,
-                             num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Tuple[Dict[Any, Any], int]:
+ def get_weights_qparams(weights_attr_values: np.ndarray,
+                         weights_quant_config: NodeWeightsQuantizationConfig,
+                         attr_quant_config: WeightsAttrQuantizationConfig,
+                         output_channels_axis: int,
+                         node=None,
+                         hessian_info_service: HessianInfoService = None,
+                         num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Tuple[Dict[Any, Any], int]:
      """
      Compute thresholds to quantize a kernel according to a NodeWeightsQuantizationConfig
      instance.

      Args:
-         weights_attr_data: Weights attribute parameter to compute the quantization thresholds for.
+         weights_attr_values: Weights attribute parameter to compute the quantization thresholds for.
+         weights_quant_config: Weights quantization configuration to define how the thresholds are computed.
          attr_quant_config: A specific weights attribute quantization configuration to get its params.
-         weights_error_method: quantization error method.
-         l_p_value: p-norm to use for the Lp-norm distance.
          output_channels_axis: Index of the kernel output channels dimension.
-         min_threshold: Minimal threshold to use if threshold is too small.
          node: The node for which the quantization error is computed (used only with HMSE error method).
          hessian_info_service: HessianInfoService object for retrieving Hessian-based scores (used only with HMSE error method).
          num_hessian_samples: Number of samples to approximate Hessian-based scores on (used only with HMSE error method).
@@ -57,43 +49,22 @@ def compute_weights_qparams(weights_attr_data: np.ndarray,
          A dictionary with the quantization threshold of the kernel.
          Selected quantization channel axis.
      """
-     params_fn = _get_weights_quantization_params_fn(attr_quant_config.weights_quantization_method)
-     weights_params, output_channels_axis = params_fn(
-         weights_attr_data,
-         p=l_p_value,
-         n_bits=attr_quant_config.weights_n_bits,
-         per_channel=attr_quant_config.weights_per_channel_threshold,
-         channel_axis=output_channels_axis,
-         min_threshold=min_threshold,
-         quant_error_method=weights_error_method,
-         node=node,
-         hessian_info_service=hessian_info_service,
-         num_hessian_samples=num_hessian_samples)
+     if attr_quant_config.weights_quantization_params_fn is not None:
+         weights_params, output_channels_axis = attr_quant_config.weights_quantization_params_fn(
+             weights_attr_values,
+             p=attr_quant_config.l_p_value,
+             n_bits=attr_quant_config.weights_n_bits,
+             per_channel=attr_quant_config.weights_per_channel_threshold,
+             channel_axis=output_channels_axis,
+             min_threshold=weights_quant_config.min_threshold,
+             quant_error_method=attr_quant_config.weights_error_method,
+             node=node,
+             hessian_info_service=hessian_info_service,
+             num_hessian_samples=num_hessian_samples)
+     else: # pragma: no cover
+         Logger.error(f"Requested weights quantization parameters computation for node {node.name} without providing a "
+                      f"weights_quantization_params_fn."
+                      f"Returning an empty dictionary since no quantization parameters were computed.")
+         weights_params = {}

      return weights_params, output_channels_axis
-
-
- _weights_quant_params_fns = {
-     QuantizationMethod.POWER_OF_TWO: power_of_two_selection_tensor,
-     QuantizationMethod.SYMMETRIC: symmetric_selection_tensor,
-     QuantizationMethod.UNIFORM: uniform_selection_tensor,
-     QuantizationMethod.LUT_POT_QUANTIZER: partial(lut_kmeans_tensor, is_symmetric=False),
-     QuantizationMethod.LUT_SYM_QUANTIZER: partial(lut_kmeans_tensor, is_symmetric=True)
- }
-
-
- def _get_weights_quantization_params_fn(weights_quantization_method: QuantizationMethod) -> Callable:
-     """
-     Generate a function for finding weights quantization parameters.
-
-     Args:
-         weights_quantization_method: Which quantization method to use for weights.
-     Returns:
-         A function to find the quantization parameters.
-
-     """
-     params_fn = _weights_quant_params_fns.get(weights_quantization_method)
-     if not params_fn:
-         Logger.critical(
-             f"No parameter function found for the specified quantization method: {weights_quantization_method}") # pragma: no cover
-     return params_fn
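
In both versions the heavy lifting is delegated to a tensor-level selection routine (power_of_two_selection_tensor, symmetric_selection_tensor, etc.) called with p, n_bits, per_channel, channel_axis and min_threshold; the diff only changes how that routine and its arguments are obtained. The sketch below is a rough illustration of the shape of such a routine's inputs and outputs; MCT's real functions search for the threshold that minimizes an Lp or Hessian-weighted reconstruction error rather than taking the plain maximum used here.

# Naive per-channel symmetric threshold selection, for illustration only.
# MCT's selection functions (e.g. symmetric_selection_tensor) search for the
# threshold minimizing an Lp or Hessian-weighted error; this toy just takes the max.
from typing import Dict, Tuple

import numpy as np


def naive_symmetric_selection_tensor(weights: np.ndarray,
                                     n_bits: int,
                                     per_channel: bool,
                                     channel_axis: int,
                                     min_threshold: float = 1e-8) -> Tuple[Dict[str, np.ndarray], int]:
    # n_bits is unused by this naive pick; the real search uses it to score candidate thresholds.
    if per_channel:
        # Reduce over every axis except the output-channel axis.
        reduce_axes = tuple(a for a in range(weights.ndim) if a != channel_axis)
        max_abs = np.max(np.abs(weights), axis=reduce_axes)
    else:
        max_abs = np.max(np.abs(weights))
    threshold = np.maximum(max_abs, min_threshold)
    return {"threshold": threshold}, channel_axis


w = np.random.randn(3, 3, 8, 16)  # toy 4-D kernel, 16 output channels on axis 3
params, axis = naive_symmetric_selection_tensor(w, n_bits=8, per_channel=True, channel_axis=3)
print(params["threshold"].shape, axis)  # (16,) 3

The returned dictionary plays the role of weights_params above, and the channel axis is echoed back so the caller can record which axis the per-channel thresholds refer to.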
model_compression_toolkit/core/common/quantization/quantize_node.py
@@ -12,7 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn
+
+
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
  from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
@@ -46,12 +47,11 @@ def get_quantized_weights_attr_by_qc(attr_name: str,
          output_channels_axis = None

      Logger.debug(f'quantizing layer {n.name} attribute {attr_name} with {weights_qc.weights_n_bits} bits')
-     weights_quantization_fn = get_weights_quantization_fn(weights_qc.weights_quantization_method)
-     quantized_kernel = weights_quantization_fn(n.get_weights_by_keys(attr_name),
-                                                n_bits=weights_qc.weights_n_bits,
-                                                signed=True,
-                                                quantization_params=weights_qc.weights_quantization_params,
-                                                per_channel=weights_qc.weights_per_channel_threshold,
-                                                output_channels_axis=output_channels_axis)
+     quantized_kernel = weights_qc.weights_quantization_fn(n.get_weights_by_keys(attr_name),
+                                                           n_bits=weights_qc.weights_n_bits,
+                                                           signed=True,
+                                                           quantization_params=weights_qc.weights_quantization_params,
+                                                           per_channel=weights_qc.weights_per_channel_threshold,
+                                                           output_channels_axis=output_channels_axis)

      return quantized_kernel, channels_axis
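
Finally, quantize_node.py consumes those parameters: the removed side looks the weights quantization function up by method through get_weights_quantization_fn, while the added side reads it directly from weights_qc.weights_quantization_fn, but both end up fake-quantizing the kernel with the stored quantization_params. Below is a generic sketch of a symmetric, signed, per-channel quantize/dequantize step of that kind; it is an illustrative stand-in, not the code in MCT's fake_quant_builder modules.

# Generic symmetric fake-quantization (quantize + dequantize) of a kernel, as an
# illustrative stand-in for the weights_quantization_fn applied in quantize_node.
import numpy as np


def fake_quant_symmetric(kernel: np.ndarray,
                         threshold: np.ndarray,
                         n_bits: int,
                         per_channel: bool,
                         output_channels_axis: int) -> np.ndarray:
    if per_channel:
        # Broadcast one threshold per output channel along the channel axis.
        shape = [1] * kernel.ndim
        shape[output_channels_axis] = -1
        threshold = np.reshape(threshold, shape)
    scale = threshold / (2 ** (n_bits - 1))  # signed grid: [-2^(b-1), 2^(b-1)-1]
    q = np.clip(np.round(kernel / scale), -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1)
    return q * scale  # dequantize back to float


k = np.random.randn(3, 3, 8, 16)
t = np.max(np.abs(k), axis=(0, 1, 2))  # toy per-channel thresholds
qk = fake_quant_symmetric(k, t, n_bits=8, per_channel=True, output_channels_axis=3)
assert qk.shape == k.shape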