mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -13,98 +13,47 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
import numpy as np
|
16
|
-
from typing import Dict, Union, Optional, Tuple
|
16
|
+
from typing import Dict, Union, Optional, Tuple
|
17
17
|
|
18
18
|
from mct_quantizers import QuantizationMethod
|
19
|
-
|
20
|
-
import model_compression_toolkit.core.common.quantization.quantization_params_generation as qpg
|
21
|
-
from model_compression_toolkit.constants import MIN_THRESHOLD
|
19
|
+
from model_compression_toolkit.core import QuantizationErrorMethod
|
22
20
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
|
23
21
|
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
|
22
|
+
from model_compression_toolkit.core.common.quantization import quantization_params_generation
|
24
23
|
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
|
25
24
|
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
|
26
|
-
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationErrorMethod, \
|
27
|
-
QuantizationConfig
|
28
|
-
|
29
|
-
|
30
|
-
def compute_activation_qparams(quant_cfg: QuantizationConfig,
|
31
|
-
node_activation_quant_cfg: NodeActivationQuantizationConfig,
|
32
|
-
node_prior_info: NodePriorInfo,
|
33
|
-
out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
|
34
|
-
"""
|
35
|
-
Compute the activations params for a given node in a graph according to a params function.
|
36
|
-
|
37
|
-
Args:
|
38
|
-
quant_cfg: quantization config.
|
39
|
-
node_activation_quant_cfg: node's activation quantization configuration.
|
40
|
-
node_prior_info: Prior info collected for the node that is being quantized.
|
41
|
-
out_stats_container: Tensor containing output statistics of the node.
|
42
|
-
|
43
|
-
Returns:
|
44
|
-
The computed activation quantization params.
|
45
|
-
"""
|
46
|
-
activation_quantization_params_fn = _get_activation_quantization_params_fn(
|
47
|
-
node_activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded())
|
48
|
-
|
49
|
-
# Extract and filter histogram data from the statistics container.
|
50
|
-
z_threshold = quant_cfg.z_threshold
|
51
|
-
if node_activation_quant_cfg.z_threshold is not None:
|
52
|
-
z_threshold = node_activation_quant_cfg.z_threshold
|
53
|
-
bins_values, bins_counts = _get_histogram_data(out_stats_container,
|
54
|
-
activation_error_method=quant_cfg.activation_error_method,
|
55
|
-
z_threshold=z_threshold)
|
56
|
-
|
57
|
-
# Retrieve the minimum and maximum values from the statistics container.
|
58
|
-
min_value, max_value = out_stats_container.get_min_max_values()
|
59
|
-
|
60
|
-
# Determine if the activations should be considered signed.
|
61
|
-
signed = _determine_signedness(node_activation_quant_cfg, node_prior_info, min_value, bins_values, bins_counts)
|
62
25
|
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
quant_cfg.l_p_value,
|
68
|
-
node_activation_quant_cfg.activation_n_bits,
|
69
|
-
min_value,
|
70
|
-
max_value,
|
71
|
-
min_threshold=MIN_THRESHOLD,
|
72
|
-
quant_error_method=quant_cfg.activation_error_method,
|
73
|
-
is_signed=signed
|
74
|
-
)
|
75
|
-
|
76
|
-
|
77
|
-
def _get_histogram_data(out_stats_container: BaseStatsCollector,
|
78
|
-
activation_error_method: QuantizationErrorMethod,
|
79
|
-
z_threshold: float) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
|
26
|
+
def get_histogram_data(
|
27
|
+
activation_quant_cfg: NodeActivationQuantizationConfig,
|
28
|
+
out_stats_container: BaseStatsCollector
|
29
|
+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
|
80
30
|
"""
|
81
31
|
Extract and filter the histogram data from the statistics container.
|
82
32
|
|
83
33
|
Args:
|
34
|
+
activation_quant_cfg: Node's activation quantization configuration.
|
84
35
|
out_stats_container: Statistics container with histogram data.
|
85
|
-
activation_error_method: activation quantization error method.
|
86
|
-
z_threshold: z threshold for z-score filtering.
|
87
36
|
|
88
37
|
Returns:
|
89
38
|
A tuple containing the filtered bins_values and bins_counts.
|
90
39
|
"""
|
91
40
|
bins_values, bins_counts = None, None
|
41
|
+
|
92
42
|
# If the statistics container collected the histogram, we start by filtering outliers using z threshold
|
93
43
|
# filtering, and then computing the threshold based on the filtered histogram.
|
94
44
|
if out_stats_container.require_collection():
|
95
|
-
if activation_error_method == QuantizationErrorMethod.HMSE:
|
45
|
+
if activation_quant_cfg.activation_error_method == QuantizationErrorMethod.HMSE:
|
96
46
|
bins_values, bins_counts = out_stats_container.weighted_hc.get_histogram()
|
97
47
|
else:
|
98
48
|
bins_values, bins_counts = out_stats_container.hc.get_histogram()
|
99
|
-
bins_counts =
|
100
|
-
z_threshold,
|
49
|
+
bins_counts = quantization_params_generation.z_score_filter(
|
50
|
+
activation_quant_cfg.z_threshold,
|
101
51
|
bins_values,
|
102
52
|
bins_counts
|
103
53
|
)
|
104
54
|
return bins_values, bins_counts
|
105
55
|
|
106
|
-
|
107
|
-
def _determine_signedness(
|
56
|
+
def determine_signedness(
|
108
57
|
activation_quant_cfg: NodeActivationQuantizationConfig,
|
109
58
|
nodes_prior_info: NodePriorInfo,
|
110
59
|
min_value: float,
|
@@ -134,37 +83,73 @@ def _determine_signedness(
|
|
134
83
|
return np.any(bins_values[:-1][bins_counts > 0] < 0)
|
135
84
|
|
136
85
|
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
_activation_no_clipping_quant_params_fns = {
|
144
|
-
QuantizationMethod.POWER_OF_TWO: qpg.power_of_two_no_clipping_selection_min_max,
|
145
|
-
QuantizationMethod.SYMMETRIC: qpg.symmetric_no_clipping_selection_min_max,
|
146
|
-
QuantizationMethod.UNIFORM: qpg.uniform_no_clipping_selection_min_max,
|
147
|
-
QuantizationMethod.LUT_POT_QUANTIZER: qpg.lut_kmeans_histogram
|
148
|
-
}
|
149
|
-
|
86
|
+
def update_activation_quantization_params_fn(
|
87
|
+
activation_quant_cfg: NodeActivationQuantizationConfig,
|
88
|
+
nodes_prior_info: NodePriorInfo):
|
89
|
+
"""
|
90
|
+
Update the activation quantization parameters function based on the quantization method
|
91
|
+
and whether the node's output is bounded.
|
150
92
|
|
151
|
-
|
152
|
-
|
93
|
+
Args:
|
94
|
+
activation_quant_cfg: Node's activation quantization configuration.
|
95
|
+
nodes_prior_info: Prior info collected for the node that is being quantized.
|
153
96
|
"""
|
154
|
-
|
97
|
+
if nodes_prior_info.is_output_bounded():
|
98
|
+
if activation_quant_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
|
99
|
+
activation_quant_cfg.set_activation_quantization_params_fn(
|
100
|
+
quantization_params_generation.power_of_two_no_clipping_selection_min_max
|
101
|
+
)
|
102
|
+
elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.SYMMETRIC:
|
103
|
+
activation_quant_cfg.set_activation_quantization_params_fn(
|
104
|
+
quantization_params_generation.symmetric_no_clipping_selection_min_max
|
105
|
+
)
|
106
|
+
elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.UNIFORM:
|
107
|
+
activation_quant_cfg.set_activation_quantization_params_fn(
|
108
|
+
quantization_params_generation.uniform_no_clipping_selection_min_max
|
109
|
+
)
|
110
|
+
|
111
|
+
|
112
|
+
def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
|
113
|
+
nodes_prior_info: NodePriorInfo,
|
114
|
+
out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
|
115
|
+
"""
|
116
|
+
Compute the activations params for a given node in a graph according to a params function.
|
155
117
|
|
156
118
|
Args:
|
157
|
-
|
158
|
-
|
119
|
+
activation_quant_cfg: node's activation quantization configuration.
|
120
|
+
nodes_prior_info: Prior info collected for the node that is being quantized.
|
121
|
+
out_stats_container: Tensor containing output statistics of the node.
|
159
122
|
|
160
123
|
Returns:
|
161
|
-
|
124
|
+
The computed activation quantization params.
|
162
125
|
"""
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
126
|
+
# Update quantization parameters function based on output bounds and quantization method.
|
127
|
+
update_activation_quantization_params_fn(activation_quant_cfg, nodes_prior_info)
|
128
|
+
|
129
|
+
# Extract and filter histogram data from the statistics container.
|
130
|
+
bins_values, bins_counts = get_histogram_data(activation_quant_cfg, out_stats_container)
|
131
|
+
|
132
|
+
# Retrieve the minimum and maximum values from the statistics container.
|
133
|
+
min_value, max_value = out_stats_container.get_min_max_values()
|
134
|
+
|
135
|
+
# Determine if the activations should be considered signed.
|
136
|
+
signed = determine_signedness(
|
137
|
+
activation_quant_cfg,
|
138
|
+
nodes_prior_info,
|
139
|
+
min_value,
|
140
|
+
bins_values,
|
141
|
+
bins_counts
|
142
|
+
)
|
143
|
+
|
144
|
+
# Compute and return the activation quantization parameters.
|
145
|
+
return activation_quant_cfg.activation_quantization_params_fn(
|
146
|
+
bins_values,
|
147
|
+
bins_counts,
|
148
|
+
activation_quant_cfg.l_p_value,
|
149
|
+
activation_quant_cfg.activation_n_bits,
|
150
|
+
min_value,
|
151
|
+
max_value,
|
152
|
+
min_threshold=activation_quant_cfg.min_threshold,
|
153
|
+
quant_error_method=activation_quant_cfg.activation_error_method,
|
154
|
+
is_signed=signed
|
155
|
+
)
|
@@ -18,21 +18,44 @@ from tqdm import tqdm
|
|
18
18
|
from typing import List, Callable, Generator
|
19
19
|
|
20
20
|
from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
|
21
|
-
from model_compression_toolkit.core import QuantizationErrorMethod
|
21
|
+
from model_compression_toolkit.core import QuantizationErrorMethod
|
22
22
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
23
|
-
from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
|
24
23
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
25
24
|
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
|
26
25
|
HessianScoresGranularity
|
27
26
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
|
28
|
-
import
|
27
|
+
import get_activations_qparams
|
29
28
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
|
30
|
-
|
29
|
+
get_weights_qparams
|
31
30
|
from model_compression_toolkit.logger import Logger
|
32
31
|
|
33
32
|
|
33
|
+
def _collect_nodes_for_hmse(nodes_list: List[BaseNode], graph: Graph) -> List[BaseNode]:
|
34
|
+
"""
|
35
|
+
Collects nodes that are compatible for parameters selection search using HMSE,
|
36
|
+
that is, have a kernel attribute that is configured for HMSE error method.
|
37
|
+
|
38
|
+
Args:
|
39
|
+
nodes_list: A list of nodes to search quantization parameters for.
|
40
|
+
graph: Graph to compute its nodes' quantization parameters.
|
41
|
+
|
42
|
+
Returns: A (possibly empty) list of nodes.
|
43
|
+
|
44
|
+
"""
|
45
|
+
hmse_nodes = []
|
46
|
+
for n in nodes_list:
|
47
|
+
kernel_attr_name = graph.fw_info.get_kernel_op_attributes(n.type)
|
48
|
+
kernel_attr_name = None if kernel_attr_name is None or len(kernel_attr_name) == 0 else kernel_attr_name[0]
|
49
|
+
|
50
|
+
if kernel_attr_name is not None and n.is_weights_quantization_enabled(kernel_attr_name) and \
|
51
|
+
all([c.weights_quantization_cfg.get_attr_config(kernel_attr_name).weights_error_method ==
|
52
|
+
QuantizationErrorMethod.HMSE for c in n.candidates_quantization_cfg]):
|
53
|
+
hmse_nodes.append(n)
|
54
|
+
|
55
|
+
return hmse_nodes
|
56
|
+
|
57
|
+
|
34
58
|
def calculate_quantization_params(graph: Graph,
|
35
|
-
quant_cfg: QuantizationConfig,
|
36
59
|
fw_impl: FrameworkImplementation,
|
37
60
|
repr_data_gen_fn: Callable[[], Generator],
|
38
61
|
nodes: List[BaseNode] = None,
|
@@ -47,7 +70,6 @@ def calculate_quantization_params(graph: Graph,
|
|
47
70
|
|
48
71
|
Args:
|
49
72
|
graph: Graph to compute its nodes' thresholds.
|
50
|
-
quant_cfg: quantization config.
|
51
73
|
fw_impl: FrameworkImplementation object.
|
52
74
|
repr_data_gen_fn: callable returning representative dataset generator.
|
53
75
|
nodes: List of nodes to compute their thresholds instead of computing it for all nodes in the graph.
|
@@ -65,16 +87,15 @@ def calculate_quantization_params(graph: Graph,
|
|
65
87
|
# Collecting nodes that are configured to search weights quantization parameters using HMSE optimization
|
66
88
|
# and computing required Hessian information to be used for HMSE parameters selection.
|
67
89
|
# The Hessian scores are computed and stored in the hessian_info_service object.
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
hessian_info_service.fetch_hessian(request)
|
90
|
+
nodes_for_hmse = _collect_nodes_for_hmse(nodes_list, graph)
|
91
|
+
if len(nodes_for_hmse) > 0:
|
92
|
+
dataloader = fw_impl.convert_data_gen_to_dataloader(repr_data_gen_fn, batch_size=1)
|
93
|
+
request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
|
94
|
+
granularity=HessianScoresGranularity.PER_ELEMENT,
|
95
|
+
data_loader=dataloader,
|
96
|
+
n_samples=num_hessian_samples,
|
97
|
+
target_nodes=nodes_for_hmse)
|
98
|
+
hessian_info_service.fetch_hessian(request)
|
78
99
|
|
79
100
|
for n in tqdm(nodes_list, "Calculating quantization parameters"): # iterate only nodes that we should compute their thresholds
|
80
101
|
for candidate_qc in n.candidates_quantization_cfg:
|
@@ -82,34 +103,43 @@ def calculate_quantization_params(graph: Graph,
|
|
82
103
|
if n.is_weights_quantization_enabled(attr):
|
83
104
|
# If the node's weights attribute should be quantized, we compute its quantization parameters
|
84
105
|
attr_cfg = candidate_qc.weights_quantization_cfg.get_attr_config(attr)
|
85
|
-
|
106
|
+
channels_axis = attr_cfg.weights_channels_axis
|
107
|
+
if channels_axis is not None:
|
108
|
+
output_channels_axis = channels_axis[0]
|
109
|
+
else:
|
110
|
+
output_channels_axis = None
|
86
111
|
|
87
|
-
|
88
|
-
|
112
|
+
mod_attr_cfg = attr_cfg
|
113
|
+
|
114
|
+
if attr_cfg.weights_error_method == QuantizationErrorMethod.HMSE:
|
89
115
|
# Although we collected nodes for HMSE before running the loop, we keep this verification to
|
90
116
|
# notify the user in case of HMSE configured for node that is not compatible for this method
|
91
|
-
|
117
|
+
kernel_attr_name = graph.fw_info.get_kernel_op_attributes(n.type)
|
118
|
+
if len(kernel_attr_name) > 0:
|
119
|
+
kernel_attr_name = kernel_attr_name[0]
|
120
|
+
|
121
|
+
if kernel_attr_name is None or kernel_attr_name not in attr:
|
92
122
|
Logger.warning(f"The HMSE error method for parameters selection is only supported for "
|
93
123
|
f"kernel weights attributes. Running parameters selection for attribute "
|
94
124
|
f"'{attr}' in node '{n.name}' with the default MSE error method instead.")
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
attr_cfg.weights_channels_axis =
|
125
|
+
mod_attr_cfg = copy.deepcopy(attr_cfg)
|
126
|
+
mod_attr_cfg.weights_error_method = QuantizationErrorMethod.MSE
|
127
|
+
|
128
|
+
weights_params, output_channels_axis = get_weights_qparams(n.get_weights_by_keys(attr),
|
129
|
+
candidate_qc.weights_quantization_cfg,
|
130
|
+
mod_attr_cfg,
|
131
|
+
output_channels_axis,
|
132
|
+
node=n,
|
133
|
+
hessian_info_service=hessian_info_service,
|
134
|
+
num_hessian_samples=num_hessian_samples)
|
135
|
+
attr_cfg.weights_channels_axis = (output_channels_axis, attr_cfg.weights_channels_axis[1])
|
106
136
|
attr_cfg.set_weights_quantization_param(weights_params)
|
107
137
|
|
108
|
-
if n.is_activation_quantization_enabled()
|
138
|
+
if n.is_activation_quantization_enabled():
|
109
139
|
# If node's activations should be quantized as well, we compute its activation quantization parameters
|
110
|
-
activation_params =
|
111
|
-
|
112
|
-
|
113
|
-
|
140
|
+
activation_params = get_activations_qparams(
|
141
|
+
activation_quant_cfg=candidate_qc.activation_quantization_cfg,
|
142
|
+
nodes_prior_info=n.prior_info,
|
143
|
+
out_stats_container=graph.get_out_stats_collector(n))
|
114
144
|
# Create a NodeQuantizationConfig containing all quantization params and attach it to the node
|
115
145
|
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params)
|
@@ -12,43 +12,35 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from
|
16
|
-
from typing import Dict, Any, Tuple, Callable, TYPE_CHECKING
|
15
|
+
from typing import Dict, Any, Tuple
|
17
16
|
|
18
17
|
import numpy as np
|
19
|
-
from mct_quantizers import QuantizationMethod
|
20
18
|
|
21
|
-
from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
|
22
|
-
from model_compression_toolkit.core import QuantizationErrorMethod
|
19
|
+
from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
|
23
20
|
from model_compression_toolkit.core.common.hessian import HessianInfoService
|
24
|
-
from model_compression_toolkit.
|
25
|
-
|
21
|
+
from model_compression_toolkit.defaultdict import DefaultDict
|
22
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
23
|
+
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
|
24
|
+
WeightsAttrQuantizationConfig
|
26
25
|
from model_compression_toolkit.logger import Logger
|
27
26
|
|
28
|
-
if TYPE_CHECKING:
|
29
|
-
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
|
30
27
|
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
node=None,
|
39
|
-
hessian_info_service: HessianInfoService = None,
|
40
|
-
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Tuple[Dict[Any, Any], int]:
|
28
|
+
def get_weights_qparams(weights_attr_values: np.ndarray,
|
29
|
+
weights_quant_config: NodeWeightsQuantizationConfig,
|
30
|
+
attr_quant_config: WeightsAttrQuantizationConfig,
|
31
|
+
output_channels_axis: int,
|
32
|
+
node=None,
|
33
|
+
hessian_info_service: HessianInfoService = None,
|
34
|
+
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Tuple[Dict[Any, Any], int]:
|
41
35
|
"""
|
42
36
|
Compute thresholds to quantize a kernel according to a NodeWeightsQuantizationConfig
|
43
37
|
instance.
|
44
38
|
|
45
39
|
Args:
|
46
|
-
|
40
|
+
weights_attr_values: Weights attribute parameter to compute the quantization thresholds for.
|
41
|
+
weights_quant_config: Weights quantization configuration to define how the thresholds are computed.
|
47
42
|
attr_quant_config: A specific weights attribute quantization configuration to get its params.
|
48
|
-
weights_error_method: quantization error method.
|
49
|
-
l_p_value: p-norm to use for the Lp-norm distance.
|
50
43
|
output_channels_axis: Index of the kernel output channels dimension.
|
51
|
-
min_threshold: Minimal threshold to use if threshold is too small.
|
52
44
|
node: The node for which the quantization error is computed (used only with HMSE error method).
|
53
45
|
hessian_info_service: HessianInfoService object for retrieving Hessian-based scores (used only with HMSE error method).
|
54
46
|
num_hessian_samples: Number of samples to approximate Hessian-based scores on (used only with HMSE error method).
|
@@ -57,43 +49,22 @@ def compute_weights_qparams(weights_attr_data: np.ndarray,
|
|
57
49
|
A dictionary with the quantization threshold of the kernel.
|
58
50
|
Selected quantization channel axis.
|
59
51
|
"""
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
52
|
+
if attr_quant_config.weights_quantization_params_fn is not None:
|
53
|
+
weights_params, output_channels_axis = attr_quant_config.weights_quantization_params_fn(
|
54
|
+
weights_attr_values,
|
55
|
+
p=attr_quant_config.l_p_value,
|
56
|
+
n_bits=attr_quant_config.weights_n_bits,
|
57
|
+
per_channel=attr_quant_config.weights_per_channel_threshold,
|
58
|
+
channel_axis=output_channels_axis,
|
59
|
+
min_threshold=weights_quant_config.min_threshold,
|
60
|
+
quant_error_method=attr_quant_config.weights_error_method,
|
61
|
+
node=node,
|
62
|
+
hessian_info_service=hessian_info_service,
|
63
|
+
num_hessian_samples=num_hessian_samples)
|
64
|
+
else: # pragma: no cover
|
65
|
+
Logger.error(f"Requested weights quantization parameters computation for node {node.name} without providing a "
|
66
|
+
f"weights_quantization_params_fn."
|
67
|
+
f"Returning an empty dictionary since no quantization parameters were computed.")
|
68
|
+
weights_params = {}
|
72
69
|
|
73
70
|
return weights_params, output_channels_axis
|
74
|
-
|
75
|
-
|
76
|
-
_weights_quant_params_fns = {
|
77
|
-
QuantizationMethod.POWER_OF_TWO: power_of_two_selection_tensor,
|
78
|
-
QuantizationMethod.SYMMETRIC: symmetric_selection_tensor,
|
79
|
-
QuantizationMethod.UNIFORM: uniform_selection_tensor,
|
80
|
-
QuantizationMethod.LUT_POT_QUANTIZER: partial(lut_kmeans_tensor, is_symmetric=False),
|
81
|
-
QuantizationMethod.LUT_SYM_QUANTIZER: partial(lut_kmeans_tensor, is_symmetric=True)
|
82
|
-
}
|
83
|
-
|
84
|
-
|
85
|
-
def _get_weights_quantization_params_fn(weights_quantization_method: QuantizationMethod) -> Callable:
|
86
|
-
"""
|
87
|
-
Generate a function for finding weights quantization parameters.
|
88
|
-
|
89
|
-
Args:
|
90
|
-
weights_quantization_method: Which quantization method to use for weights.
|
91
|
-
Returns:
|
92
|
-
A function to find the quantization parameters.
|
93
|
-
|
94
|
-
"""
|
95
|
-
params_fn = _weights_quant_params_fns.get(weights_quantization_method)
|
96
|
-
if not params_fn:
|
97
|
-
Logger.critical(
|
98
|
-
f"No parameter function found for the specified quantization method: {weights_quantization_method}") # pragma: no cover
|
99
|
-
return params_fn
|
@@ -12,7 +12,8 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
15
|
+
|
16
|
+
|
16
17
|
from model_compression_toolkit.logger import Logger
|
17
18
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
18
19
|
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
|
@@ -46,12 +47,11 @@ def get_quantized_weights_attr_by_qc(attr_name: str,
|
|
46
47
|
output_channels_axis = None
|
47
48
|
|
48
49
|
Logger.debug(f'quantizing layer {n.name} attribute {attr_name} with {weights_qc.weights_n_bits} bits')
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
output_channels_axis=output_channels_axis)
|
50
|
+
quantized_kernel = weights_qc.weights_quantization_fn(n.get_weights_by_keys(attr_name),
|
51
|
+
n_bits=weights_qc.weights_n_bits,
|
52
|
+
signed=True,
|
53
|
+
quantization_params=weights_qc.weights_quantization_params,
|
54
|
+
per_channel=weights_qc.weights_per_channel_threshold,
|
55
|
+
output_channels_axis=output_channels_axis)
|
56
56
|
|
57
57
|
return quantized_kernel, channels_axis
|