mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250927.534__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -12,15 +12,23 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
15
|
+
|
16
|
+
|
17
|
+
from typing import Callable, Any, List, Tuple, Union, Dict, TYPE_CHECKING
|
16
18
|
from enum import Enum, auto
|
19
|
+
import numpy as np
|
17
20
|
|
18
|
-
from model_compression_toolkit.core.common.
|
21
|
+
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn
|
19
22
|
from model_compression_toolkit.logger import Logger
|
23
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
|
24
|
+
get_activation_quantization_params_fn, get_weights_quantization_params_fn
|
20
25
|
|
21
|
-
from model_compression_toolkit.
|
26
|
+
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
|
27
|
+
QuantizationErrorMethod
|
28
|
+
from model_compression_toolkit.target_platform_capabilities.constants import POS_ATTR
|
22
29
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import \
|
23
|
-
AttributeQuantizationConfig,
|
30
|
+
AttributeQuantizationConfig, \
|
31
|
+
OpQuantizationConfig
|
24
32
|
|
25
33
|
if TYPE_CHECKING:
|
26
34
|
from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
|
@@ -38,7 +46,6 @@ class ActivationQuantizationMode(Enum):
|
|
38
46
|
FLN_QUANT = auto()
|
39
47
|
PRESERVE_QUANT = auto()
|
40
48
|
NO_QUANT = auto()
|
41
|
-
FLN_NO_QUANT = auto()
|
42
49
|
|
43
50
|
|
44
51
|
class BaseNodeQuantizationConfig(object):
|
@@ -59,11 +66,12 @@ class BaseNodeQuantizationConfig(object):
|
|
59
66
|
kwargs: A dictionary with additional key arguments.
|
60
67
|
|
61
68
|
"""
|
69
|
+
|
62
70
|
if hasattr(self, config_parameter_name):
|
63
71
|
setattr(self, config_parameter_name, config_parameter_value)
|
64
72
|
else:
|
65
|
-
|
66
|
-
|
73
|
+
Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config and "
|
74
|
+
f"was not updated!")
|
67
75
|
|
68
76
|
def __repr__(self) -> str:
|
69
77
|
"""
|
@@ -77,14 +85,29 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
77
85
|
"""
|
78
86
|
Attributes for configuring the quantization of the activations of a node.
|
79
87
|
"""
|
80
|
-
def __init__(self,
|
88
|
+
def __init__(self,
|
89
|
+
qc: QuantizationConfig,
|
90
|
+
op_cfg: OpQuantizationConfig,
|
91
|
+
activation_quantization_fn: Callable,
|
92
|
+
activation_quantization_params_fn: Callable
|
93
|
+
):
|
81
94
|
"""
|
82
95
|
|
83
96
|
Args:
|
97
|
+
qc: QuantizationConfig to create the node's config from.
|
84
98
|
op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
|
99
|
+
activation_quantization_fn: Function to use when quantizing the node's activations.
|
100
|
+
activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations.
|
85
101
|
"""
|
102
|
+
|
103
|
+
self.activation_quantization_fn = activation_quantization_fn
|
104
|
+
self.activation_quantization_params_fn = activation_quantization_params_fn
|
105
|
+
self.activation_quantization_params = {}
|
86
106
|
self.activation_quantization_method = op_cfg.activation_quantization_method
|
107
|
+
self.activation_error_method = qc.activation_error_method
|
87
108
|
self.activation_n_bits = op_cfg.activation_n_bits
|
109
|
+
self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
|
110
|
+
self.activation_bias_correction_term = None
|
88
111
|
if op_cfg.enable_activation_quantization and op_cfg.quantization_preserving:
|
89
112
|
raise ValueError("An OpQuantizationConfig can't have both enable_activation_quantization and quantization_preserving enabled.")
|
90
113
|
if op_cfg.enable_activation_quantization:
|
@@ -94,13 +117,15 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
94
117
|
else:
|
95
118
|
self.quant_mode = ActivationQuantizationMode.NO_QUANT
|
96
119
|
self.signedness = op_cfg.signedness
|
97
|
-
|
98
|
-
self.
|
99
|
-
|
100
|
-
self.
|
101
|
-
|
102
|
-
|
103
|
-
self.
|
120
|
+
self.activation_channel_equalization = qc.activation_channel_equalization
|
121
|
+
self.input_scaling = qc.input_scaling
|
122
|
+
self.min_threshold = qc.min_threshold
|
123
|
+
self.l_p_value = qc.l_p_value
|
124
|
+
self.shift_negative_activation_correction = qc.shift_negative_activation_correction
|
125
|
+
self.z_threshold = qc.z_threshold
|
126
|
+
self.shift_negative_ratio = qc.shift_negative_ratio
|
127
|
+
self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
|
128
|
+
self.concat_threshold_update = qc.concat_threshold_update
|
104
129
|
|
105
130
|
@property
|
106
131
|
def enable_activation_quantization(self):
|
@@ -113,6 +138,65 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
113
138
|
def fln_quantization(self):
|
114
139
|
return self.quant_mode == ActivationQuantizationMode.FLN_QUANT
|
115
140
|
|
141
|
+
def quantize_node_output(self,
|
142
|
+
tensors: Any) -> Any:
|
143
|
+
"""
|
144
|
+
|
145
|
+
Args:
|
146
|
+
tensors: framework tensor/s
|
147
|
+
|
148
|
+
Returns:
|
149
|
+
Framework tensor/s after applying fake quantization.
|
150
|
+
|
151
|
+
"""
|
152
|
+
fake_quant = self.activation_quantization_fn(self.activation_n_bits,
|
153
|
+
self.activation_quantization_params)
|
154
|
+
|
155
|
+
if fake_quant is None:
|
156
|
+
Logger.critical(
|
157
|
+
"Layer is intended to be quantized, but the fake_quant function is None.") # pragma: no cover
|
158
|
+
|
159
|
+
return fake_quant(tensors)
|
160
|
+
|
161
|
+
@property
|
162
|
+
def activation_error_method(self) -> QuantizationErrorMethod:
|
163
|
+
"""
|
164
|
+
activation_error_method getter.
|
165
|
+
"""
|
166
|
+
return self._activation_error_method
|
167
|
+
|
168
|
+
@activation_error_method.setter
|
169
|
+
def activation_error_method(self, value: QuantizationErrorMethod):
|
170
|
+
"""
|
171
|
+
activation_error_method setter.
|
172
|
+
|
173
|
+
Args:
|
174
|
+
value: New activation_error_method to set to the node activation configuration.
|
175
|
+
|
176
|
+
"""
|
177
|
+
self._activation_error_method = value
|
178
|
+
self.activation_quantization_params_fn = get_activation_quantization_params_fn(activation_quantization_method=self.activation_quantization_method)
|
179
|
+
|
180
|
+
def set_activation_quantization_fn(self, activation_quantization_fn: Callable):
|
181
|
+
"""
|
182
|
+
Sets activation quantization function for the node.
|
183
|
+
|
184
|
+
Args:
|
185
|
+
activation_quantization_fn: Function for quantazing the activations.
|
186
|
+
|
187
|
+
"""
|
188
|
+
self.activation_quantization_fn = activation_quantization_fn
|
189
|
+
|
190
|
+
def set_activation_quantization_params_fn(self, activation_quantization_params_fn:Callable):
|
191
|
+
"""
|
192
|
+
Sets activation params function for the node.
|
193
|
+
|
194
|
+
Args:
|
195
|
+
activation_quantization_params_fn: Function for calculating activation params.
|
196
|
+
|
197
|
+
"""
|
198
|
+
self.activation_quantization_params_fn = activation_quantization_params_fn
|
199
|
+
|
116
200
|
def set_activation_quantization_param(self,
|
117
201
|
activation_params: dict):
|
118
202
|
"""
|
@@ -122,7 +206,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
122
206
|
activation_params: Dictionary that contains weight quantization params.
|
123
207
|
|
124
208
|
"""
|
125
|
-
assert self.quant_mode == ActivationQuantizationMode.QUANT
|
209
|
+
assert self.quant_mode == ActivationQuantizationMode.QUANT
|
126
210
|
for param_name, param_value in activation_params.items():
|
127
211
|
self.activation_quantization_params[param_name] = param_value
|
128
212
|
|
@@ -139,16 +223,36 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
139
223
|
if not isinstance(other, NodeActivationQuantizationConfig):
|
140
224
|
return False # pragma: no cover
|
141
225
|
|
142
|
-
return self.
|
226
|
+
return self.activation_quantization_fn == other.activation_quantization_fn and \
|
227
|
+
self.activation_quantization_params_fn == other.activation_quantization_params_fn and \
|
228
|
+
self.activation_error_method == other.activation_error_method and \
|
229
|
+
self.activation_quantization_method == other.activation_quantization_method and \
|
143
230
|
self.activation_n_bits == other.activation_n_bits and \
|
144
231
|
self.quant_mode == other.quant_mode and \
|
145
|
-
self.
|
232
|
+
self.activation_channel_equalization == other.activation_channel_equalization and \
|
233
|
+
self.input_scaling == other.input_scaling and \
|
234
|
+
self.min_threshold == other.min_threshold and \
|
235
|
+
self.l_p_value == other.l_p_value and \
|
236
|
+
self.shift_negative_activation_correction == other.shift_negative_activation_correction and \
|
237
|
+
self.z_threshold == other.z_threshold and \
|
238
|
+
self.shift_negative_ratio == other.shift_negative_ratio and \
|
239
|
+
self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
|
146
240
|
|
147
241
|
def __hash__(self):
|
148
|
-
return hash((self.
|
242
|
+
return hash((self.activation_quantization_fn,
|
243
|
+
self.activation_quantization_params_fn,
|
244
|
+
self.activation_error_method,
|
245
|
+
self.activation_quantization_method,
|
149
246
|
self.activation_n_bits,
|
150
247
|
self.quant_mode,
|
151
|
-
self.
|
248
|
+
self.activation_channel_equalization,
|
249
|
+
self.input_scaling,
|
250
|
+
self.min_threshold,
|
251
|
+
self.l_p_value,
|
252
|
+
self.shift_negative_activation_correction,
|
253
|
+
self.z_threshold,
|
254
|
+
self.shift_negative_ratio,
|
255
|
+
self.shift_negative_threshold_recalculation))
|
152
256
|
|
153
257
|
|
154
258
|
class WeightsAttrQuantizationConfig:
|
@@ -156,21 +260,65 @@ class WeightsAttrQuantizationConfig:
|
|
156
260
|
Configuration for quantizing a weights attribute of a node.
|
157
261
|
"""
|
158
262
|
def __init__(self,
|
263
|
+
qc: QuantizationConfig,
|
159
264
|
weights_attr_cfg: AttributeQuantizationConfig,
|
160
|
-
weights_channels_axis:
|
265
|
+
weights_channels_axis: Tuple[int, int] = None):
|
161
266
|
"""
|
162
267
|
|
163
268
|
Args:
|
269
|
+
qc: QuantizationConfig to create the node's config from.
|
164
270
|
weights_attr_cfg: AttributeQuantizationConfig with parameters to use when creating the node's attribute quantization config.
|
165
271
|
weights_channels_axis: Axis to quantize a node's attribute when quantizing per-channel (if not quantizing per-channel than expecting None).
|
166
272
|
"""
|
273
|
+
self.weights_quantization_fn = get_weights_quantization_fn(weights_attr_cfg.weights_quantization_method)
|
274
|
+
self.weights_quantization_params_fn = get_weights_quantization_params_fn(weights_attr_cfg.weights_quantization_method)
|
167
275
|
self.weights_channels_axis = weights_channels_axis
|
276
|
+
self.weights_quantization_params = {}
|
168
277
|
self.weights_quantization_method = weights_attr_cfg.weights_quantization_method
|
278
|
+
self.weights_error_method = qc.weights_error_method
|
169
279
|
self.weights_n_bits = weights_attr_cfg.weights_n_bits
|
170
280
|
self.weights_per_channel_threshold = weights_attr_cfg.weights_per_channel_threshold
|
171
281
|
self.enable_weights_quantization = weights_attr_cfg.enable_weights_quantization
|
282
|
+
self.l_p_value = qc.l_p_value
|
172
283
|
|
173
|
-
|
284
|
+
@property
|
285
|
+
def weights_error_method(self) -> QuantizationErrorMethod:
|
286
|
+
"""
|
287
|
+
weights_error_method getter.
|
288
|
+
"""
|
289
|
+
return self._weights_error_method
|
290
|
+
|
291
|
+
@weights_error_method.setter
|
292
|
+
def weights_error_method(self, value: QuantizationErrorMethod):
|
293
|
+
"""
|
294
|
+
weights_error_method setter.
|
295
|
+
|
296
|
+
Args:
|
297
|
+
value: New weights_error_method to set to the node weights configuration.
|
298
|
+
|
299
|
+
"""
|
300
|
+
self._weights_error_method = value
|
301
|
+
self.weights_quantization_params_fn = get_weights_quantization_params_fn(weights_quantization_method=self.weights_quantization_method)
|
302
|
+
|
303
|
+
def set_weights_quantization_fn(self, weights_quantization_fn: Callable):
|
304
|
+
"""
|
305
|
+
Sets weights quantization function for the node.
|
306
|
+
|
307
|
+
Args:
|
308
|
+
weights_quantization_fn: Function for quantazing the weights.
|
309
|
+
|
310
|
+
"""
|
311
|
+
self.weights_quantization_fn = weights_quantization_fn
|
312
|
+
|
313
|
+
def set_weights_quantization_params_fn(self, weights_quantization_params_fn: Callable):
|
314
|
+
"""
|
315
|
+
Sets weights params function for the node.
|
316
|
+
|
317
|
+
Args:
|
318
|
+
weights_quantization_params_fn: Function for calculating the weights params.
|
319
|
+
|
320
|
+
"""
|
321
|
+
self.weights_quantization_params_fn = weights_quantization_params_fn
|
174
322
|
|
175
323
|
def set_weights_quantization_param(self,
|
176
324
|
weights_params: dict):
|
@@ -185,6 +333,31 @@ class WeightsAttrQuantizationConfig:
|
|
185
333
|
for param_name, param_value in weights_params.items():
|
186
334
|
self.weights_quantization_params[param_name] = param_value
|
187
335
|
|
336
|
+
def calculate_and_set_weights_params(self, tensor_data: np.ndarray, min_threshold: float):
|
337
|
+
"""
|
338
|
+
Args:
|
339
|
+
tensor_data: Tensor content as Numpy array.
|
340
|
+
min_threshold: A minimal threshold to set as quantization parameter.
|
341
|
+
|
342
|
+
Returns:
|
343
|
+
Recalculated weights quantization params from the kernel and channel axis.
|
344
|
+
|
345
|
+
"""
|
346
|
+
assert self.enable_weights_quantization
|
347
|
+
assert not (self.weights_per_channel_threshold and self.weights_channels_axis is None), \
|
348
|
+
"Trying to calculate threshold per channel, channel axis in None."
|
349
|
+
if self.weights_quantization_params_fn is not None:
|
350
|
+
self.set_weights_quantization_param(
|
351
|
+
self.weights_quantization_params_fn(tensor_data,
|
352
|
+
p=self.l_p_value,
|
353
|
+
n_bits=self.weights_n_bits,
|
354
|
+
per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
|
355
|
+
channel_axis=self.weights_channels_axis[0], # output channel axis
|
356
|
+
min_threshold=min_threshold)[0] # Take only first output, the q-params, as axis is already chosen.
|
357
|
+
)
|
358
|
+
else:
|
359
|
+
self.set_weights_quantization_param({})
|
360
|
+
|
188
361
|
def __eq__(self, other: Any) -> bool:
|
189
362
|
"""
|
190
363
|
Compares the object to another object to find if they are equal.
|
@@ -198,18 +371,26 @@ class WeightsAttrQuantizationConfig:
|
|
198
371
|
if not isinstance(other, WeightsAttrQuantizationConfig):
|
199
372
|
return False # pragma: no cover
|
200
373
|
|
201
|
-
return self.
|
374
|
+
return self.weights_quantization_fn == other.weights_quantization_fn and \
|
375
|
+
self.weights_quantization_params_fn == other.weights_quantization_params_fn and \
|
376
|
+
self.weights_channels_axis == other.weights_channels_axis and \
|
377
|
+
self.weights_error_method == other.weights_error_method and \
|
202
378
|
self.weights_quantization_method == other.weights_quantization_method and \
|
203
379
|
self.weights_n_bits == other.weights_n_bits and \
|
204
380
|
self.weights_per_channel_threshold == other.weights_per_channel_threshold and \
|
205
|
-
self.enable_weights_quantization == other.enable_weights_quantization
|
381
|
+
self.enable_weights_quantization == other.enable_weights_quantization and \
|
382
|
+
self.l_p_value == other.l_p_value
|
206
383
|
|
207
384
|
def __hash__(self):
|
208
|
-
return hash((self.
|
385
|
+
return hash((self.weights_quantization_fn,
|
386
|
+
self.weights_quantization_params_fn,
|
387
|
+
self.weights_channels_axis,
|
388
|
+
self.weights_error_method,
|
209
389
|
self.weights_quantization_method,
|
210
390
|
self.weights_n_bits,
|
211
391
|
self.weights_per_channel_threshold,
|
212
|
-
self.enable_weights_quantization
|
392
|
+
self.enable_weights_quantization,
|
393
|
+
self.l_p_value))
|
213
394
|
|
214
395
|
|
215
396
|
class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
@@ -217,19 +398,23 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
217
398
|
Holding a mapping between the node's weights attributes and their quantization configurations,
|
218
399
|
in addition to quantization parameters that are global for all attributes of the represented node.
|
219
400
|
"""
|
220
|
-
def __init__(self,
|
401
|
+
def __init__(self, qc: QuantizationConfig,
|
221
402
|
op_cfg: OpQuantizationConfig,
|
222
|
-
weights_channels_axis:
|
403
|
+
weights_channels_axis: Tuple[int, int],
|
223
404
|
node_attrs_list: List[str]):
|
224
405
|
"""
|
225
406
|
|
226
407
|
Args:
|
408
|
+
qc: QuantizationConfig to create the node's config from.
|
227
409
|
op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
|
228
410
|
weights_channels_axis: Axis to quantize a node's weights attribute when quantizing per-channel.
|
229
411
|
node_attrs_list: A list of the node's weights attributes names.
|
230
412
|
|
231
413
|
"""
|
414
|
+
self.min_threshold = qc.min_threshold
|
232
415
|
self.simd_size = op_cfg.simd_size
|
416
|
+
self.weights_second_moment_correction = qc.weights_second_moment_correction
|
417
|
+
self.weights_bias_correction = qc.weights_bias_correction
|
233
418
|
|
234
419
|
# Initialize a quantization configuration for each of the node's attributes
|
235
420
|
self.attributes_config_mapping = {}
|
@@ -241,7 +426,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
241
426
|
# POS_ATTR string. If none are found, it indicates that no specific quantization config is defined for
|
242
427
|
# positional weights, so the default config will be used instead.
|
243
428
|
attrs_included_in_name = {k: v for k, v in op_cfg.attr_weights_configs_mapping.items() if
|
244
|
-
|
429
|
+
POS_ATTR in k}
|
245
430
|
|
246
431
|
if len(attrs_included_in_name) > 1: # pragma: no cover
|
247
432
|
raise ValueError(f"Found multiple attribute in FQC OpConfig that are contained "
|
@@ -257,7 +442,8 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
257
442
|
attr_cfg = list(attrs_included_in_name.values())[0]
|
258
443
|
|
259
444
|
# Register this attribute under the positional attributes config mapping.
|
260
|
-
self.pos_attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(
|
445
|
+
self.pos_attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(qc=qc,
|
446
|
+
weights_attr_cfg=attr_cfg,
|
261
447
|
weights_channels_axis=
|
262
448
|
weights_channels_axis)
|
263
449
|
else:
|
@@ -274,16 +460,9 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
274
460
|
else:
|
275
461
|
attr_cfg = list(attrs_included_in_name.values())[0]
|
276
462
|
|
277
|
-
self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(
|
463
|
+
self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(qc=qc,
|
464
|
+
weights_attr_cfg=attr_cfg,
|
278
465
|
weights_channels_axis=weights_channels_axis)
|
279
|
-
# TODO this is set by batch norm reconstruction substitution when folded batch norms are added back, to mark
|
280
|
-
# the nodes that the correction should be applied to (for some nodes it gets disabled) and BNs removed.
|
281
|
-
# The actual correction is only computed when it's applied in ptq, so it seems that both substitutions could
|
282
|
-
# be unified, and no info need to pass between.
|
283
|
-
self.weights_second_moment_correction = None
|
284
|
-
# TODO: computed corrected bias is injected to the node config. Probably shouldn't be here. Also it can be
|
285
|
-
# computed on the final config, instead of all candidates and then there is no need to save it at all.
|
286
|
-
self.bias_corrected = None
|
287
466
|
|
288
467
|
def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationConfig:
|
289
468
|
"""
|
@@ -420,8 +599,8 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
420
599
|
if hasattr(attr_cfg, config_parameter_name):
|
421
600
|
setattr(attr_cfg, config_parameter_name, config_parameter_value)
|
422
601
|
else:
|
423
|
-
|
424
|
-
|
602
|
+
Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
|
603
|
+
f"weights attribute {attr_name} and was not updated!")
|
425
604
|
else: # pragma: no cover
|
426
605
|
Logger.critical(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
|
427
606
|
|
@@ -438,7 +617,10 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
438
617
|
if not isinstance(other, NodeWeightsQuantizationConfig):
|
439
618
|
return False # pragma: no cover
|
440
619
|
|
441
|
-
return self.
|
620
|
+
return self.min_threshold == other.min_threshold and \
|
621
|
+
self.simd_size == other.simd_size and \
|
622
|
+
self.weights_second_moment_correction == other.weights_second_moment_correction and \
|
623
|
+
self.weights_bias_correction == other.weights_bias_correction and \
|
442
624
|
self.attributes_config_mapping.keys() == other.attributes_config_mapping.keys() and \
|
443
625
|
all([self.attributes_config_mapping[k] == other.attributes_config_mapping[k]
|
444
626
|
for k in self.attributes_config_mapping.keys()]) and \
|
@@ -447,6 +629,9 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
447
629
|
for k in self.pos_attributes_config_mapping.keys()])
|
448
630
|
|
449
631
|
def __hash__(self):
|
450
|
-
return hash((self.
|
632
|
+
return hash((self.min_threshold,
|
633
|
+
self.simd_size,
|
634
|
+
self.weights_second_moment_correction,
|
635
|
+
self.weights_bias_correction,
|
451
636
|
frozenset(self.attributes_config_mapping),
|
452
637
|
frozenset(self.pos_attributes_config_mapping)))
|
@@ -90,6 +90,7 @@ class QuantizationConfig:
|
|
90
90
|
shift_negative_activation_correction: bool = True
|
91
91
|
activation_channel_equalization: bool = False
|
92
92
|
z_threshold: float = math.inf
|
93
|
+
min_threshold: float = MIN_THRESHOLD
|
93
94
|
l_p_value: int = 2
|
94
95
|
linear_collapsing: bool = True
|
95
96
|
residual_collapsing: bool = True
|
@@ -14,35 +14,15 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
from collections.abc import Callable
|
17
|
+
from functools import partial
|
17
18
|
|
18
19
|
from mct_quantizers import QuantizationMethod
|
19
|
-
|
20
|
-
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
|
21
20
|
from model_compression_toolkit.logger import Logger
|
22
21
|
from model_compression_toolkit.core.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
|
23
22
|
from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import power_of_two_quantizer, \
|
24
23
|
symmetric_quantizer, uniform_quantizer
|
25
24
|
|
26
25
|
|
27
|
-
def get_activation_quantization_fn(activation_quantization_cfg: NodeActivationQuantizationConfig,
|
28
|
-
get_activation_quantization_fn_factory: Callable) -> Callable:
|
29
|
-
"""
|
30
|
-
Get activation quantizer based on activation quantization configuration.
|
31
|
-
|
32
|
-
Args:
|
33
|
-
activation_quantization_cfg: activation quantization configuration.
|
34
|
-
get_activation_quantization_fn_factory: activation quantization functions factory.
|
35
|
-
|
36
|
-
Returns:
|
37
|
-
Activation quantizer that accepts a tensor and returns a quantized tensor.
|
38
|
-
"""
|
39
|
-
quantizer_factory = get_activation_quantization_fn_factory(
|
40
|
-
activation_quantization_cfg.activation_quantization_method)
|
41
|
-
quantizer = quantizer_factory(activation_quantization_cfg.activation_n_bits,
|
42
|
-
activation_quantization_cfg.activation_quantization_params)
|
43
|
-
return quantizer
|
44
|
-
|
45
|
-
|
46
26
|
def get_weights_quantization_fn(weights_quantization_method: QuantizationMethod) -> Callable:
|
47
27
|
"""
|
48
28
|
Generate a function for weight quantization.
|
@@ -0,0 +1,78 @@
|
|
1
|
+
# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from collections.abc import Callable
|
17
|
+
from functools import partial
|
18
|
+
|
19
|
+
from mct_quantizers import QuantizationMethod
|
20
|
+
from model_compression_toolkit.logger import Logger
|
21
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import \
|
22
|
+
lut_kmeans_tensor, lut_kmeans_histogram
|
23
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import \
|
24
|
+
symmetric_selection_tensor, symmetric_selection_histogram
|
25
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import \
|
26
|
+
uniform_selection_histogram, uniform_selection_tensor
|
27
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import \
|
28
|
+
power_of_two_selection_tensor, power_of_two_selection_histogram
|
29
|
+
|
30
|
+
|
31
|
+
def get_activation_quantization_params_fn(activation_quantization_method: QuantizationMethod) -> Callable:
|
32
|
+
"""
|
33
|
+
Generate a function for finding activation quantization parameters.
|
34
|
+
|
35
|
+
Args:
|
36
|
+
activation_quantization_method: Which quantization method to use for activations.
|
37
|
+
Returns:
|
38
|
+
A function to find the quantization parameters.
|
39
|
+
|
40
|
+
"""
|
41
|
+
if activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
|
42
|
+
params_fn = power_of_two_selection_histogram
|
43
|
+
elif activation_quantization_method == QuantizationMethod.SYMMETRIC:
|
44
|
+
params_fn = symmetric_selection_histogram
|
45
|
+
elif activation_quantization_method == QuantizationMethod.UNIFORM:
|
46
|
+
params_fn = uniform_selection_histogram
|
47
|
+
elif activation_quantization_method == QuantizationMethod.LUT_POT_QUANTIZER:
|
48
|
+
params_fn = lut_kmeans_histogram
|
49
|
+
else:
|
50
|
+
Logger.critical(
|
51
|
+
f"No parameter function found for the specified quantization method: {activation_quantization_method}") # pragma: no cover
|
52
|
+
return params_fn
|
53
|
+
|
54
|
+
|
55
|
+
def get_weights_quantization_params_fn(weights_quantization_method: QuantizationMethod) -> Callable:
|
56
|
+
"""
|
57
|
+
Generate a function for finding weights quantization parameters.
|
58
|
+
|
59
|
+
Args:
|
60
|
+
weights_quantization_method: Which quantization method to use for weights.
|
61
|
+
Returns:
|
62
|
+
A function to find the quantization parameters.
|
63
|
+
|
64
|
+
"""
|
65
|
+
if weights_quantization_method == QuantizationMethod.POWER_OF_TWO:
|
66
|
+
params_fn = power_of_two_selection_tensor
|
67
|
+
elif weights_quantization_method == QuantizationMethod.SYMMETRIC:
|
68
|
+
params_fn = symmetric_selection_tensor
|
69
|
+
elif weights_quantization_method == QuantizationMethod.UNIFORM:
|
70
|
+
params_fn = uniform_selection_tensor
|
71
|
+
elif weights_quantization_method == QuantizationMethod.LUT_POT_QUANTIZER:
|
72
|
+
params_fn = partial(lut_kmeans_tensor, is_symmetric=False)
|
73
|
+
elif weights_quantization_method == QuantizationMethod.LUT_SYM_QUANTIZER:
|
74
|
+
params_fn = partial(lut_kmeans_tensor, is_symmetric=True)
|
75
|
+
else:
|
76
|
+
Logger.critical(
|
77
|
+
f"No parameter function found for the specified quantization method: {weights_quantization_method}") # pragma: no cover
|
78
|
+
return params_fn
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py
CHANGED
@@ -12,12 +12,9 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import
|
16
|
-
|
17
|
-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import
|
18
|
-
|
19
|
-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.
|
20
|
-
symmetric_no_clipping_selection_min_max, symmetric_selection_histogram, symmetric_selection_tensor)
|
21
|
-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import (
|
22
|
-
uniform_no_clipping_selection_min_max, uniform_selection_histogram, uniform_selection_tensor)
|
15
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import power_of_two_no_clipping_selection_min_max, \
|
16
|
+
power_of_two_selection_histogram, power_of_two_selection_tensor
|
17
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import lut_kmeans_tensor
|
18
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import symmetric_no_clipping_selection_min_max
|
19
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import uniform_no_clipping_selection_min_max
|
23
20
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.outlier_filter import z_score_filter
|