mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/common/graph/base_node.py

@@ -14,37 +14,31 @@
 # ==============================================================================
 
 import copy
-from typing import Dict, Any, NamedTuple, Tuple, List, Type, Union
+from typing import Dict, Any, Tuple, List, Type, Union
 
 import numpy as np
 
-from model_compression_toolkit.core.common.framework_info import get_fw_info, ChannelAxisMapping
 from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \
     ACTIVATION_N_BITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
-from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import NodeQuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
     ActivationQuantizationMode
 from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
+    OpQuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
+    FrameworkQuantizationCapabilities
 
 
 WeightAttrT = Union[str, int]
 
 
-class NodeFrameworkInfo(NamedTuple):
-    """
-    Node's specific framework information.
-    """
-    channel_axis: ChannelAxisMapping
-    out_channel_axis: int
-    minmax: Tuple[float, float]
-    kernel_attr: str
-
-
 class BaseNode:
     """
     Class to represent a node in a graph that represents the model.
     """
+
     def __init__(self,
                  name: str,
                  framework_attr: Dict[str, Any],
@@ -90,84 +84,28 @@ class BaseNode:
         self.inputs_as_list = inputs_as_list
         self.final_weights_quantization_cfg = None
         self.final_activation_quantization_cfg = None
-        self.quantization_cfg: NodeQuantizationConfig = None
+        self.candidates_quantization_cfg = None
         self.prior_info = None
         self.has_activation = has_activation
         self.is_custom = is_custom
-        self.node_fw_info = self._get_fw_node_attrs(layer_class, framework_attr)
-
-    def _get_fw_node_attrs(self, node_type, framework_attr):
-        fw_info = get_fw_info()
-        return None if fw_info is None else NodeFrameworkInfo(
-            fw_info.get_kernel_channels(node_type),
-            fw_info.get_out_channel_axis(node_type),
-            fw_info.get_layer_min_max(node_type, framework_attr),
-            fw_info.get_kernel_op_attribute(node_type),
-        )
-
-    def _assert_fw_info_exists(self):
-        """
-        Verify NodeFrameworkInfo was initialized.
-        """
-        assert self.node_fw_info is not None, f"NodeFrameworkInfo not initialized for node {self.name}"  # pragma: no cover
-
-    @property
-    def channel_axis(self) -> ChannelAxisMapping:
-        """
-        Extract channels axis from node's NodeFrameworkInfo.
-
-        Returns:
-            Channels axis named tuple.
-        """
-        self._assert_fw_info_exists()
-        return self.node_fw_info.channel_axis
-
-    @property
-    def out_channel_axis(self) -> int:
-        """
-        Extract output channel axis from node's NodeFrameworkInfo.
-
-        Returns:
-            Output channel axis.
-        """
-        self._assert_fw_info_exists()
-        return self.node_fw_info.out_channel_axis
 
     @property
-    def minmax(self) -> Tuple[float, float]:
+    def type(self):
         """
-
-
+        A function to get the node's layer_class op for convenient comparison
         Returns:
-
-        """
-        self._assert_fw_info_exists()
-        return self.node_fw_info.minmax
-
-    @property
-    def kernel_attr(self) -> str:
+            the node's layer_class
         """
-
+        return self.layer_class
 
-        Returns:
-            Kernel name.
+    def get_has_activation(self):
         """
-        self._assert_fw_info_exists()
-        return self.node_fw_info.kernel_attr
+        Returns has_activation attribute.
 
-    @property
-    def candidates_quantization_cfg(self):
-        assert self.quantization_cfg
-        return self.quantization_cfg.candidates_quantization_cfg
+        Returns: Whether the node has activation to quantize.
 
-    @property
-    def type(self):
-        """
-        A function to get the node's layer_class op for convenient comparison
-        Returns:
-            the node's layer_class
         """
-        return self.layer_class
+        return self.has_activation
 
     @property
     def has_positional_weights(self):
@@ -195,31 +133,19 @@ class BaseNode:
         Returns: Whether node activation quantization is enabled or not.
         """
         return self._is_single_quant_mode(ActivationQuantizationMode.QUANT)
-
-    def is_fln_no_quantization(self) -> bool:
+
+    def is_fln_quantization(self) -> bool:
         """
-        Returns: Whether node is FLN
+        Returns: Whether the node's activation quantization is FLN
         """
-        return self._is_single_quant_mode(ActivationQuantizationMode.FLN_NO_QUANT)
-
+        return self._is_single_quant_mode(ActivationQuantizationMode.FLN_QUANT)
+
     def is_quantization_preserving(self) -> bool:
         """
         Returns: Whether node activation quantization information is preserved from its inputs.
         """
         return self._is_single_quant_mode(ActivationQuantizationMode.PRESERVE_QUANT)
 
-    def is_no_quantization(self) -> bool:
-        """
-        Returns: Whether node is no quantization.
-        """
-        return self._is_single_quant_mode(ActivationQuantizationMode.NO_QUANT)
-
-    def is_fln_quantization(self) -> bool:
-        """
-        Returns: Whether the node's activation quantization is FLN
-        """
-        return self._is_single_quant_mode(ActivationQuantizationMode.FLN_QUANT)
-
     def is_weights_quantization_enabled(self, attr_name: str) -> bool:
         """
         Checks whether a node's weights attribute quantization is enabled.
@@ -372,11 +298,14 @@ class BaseNode:
 
         return input_tensors
 
-    def get_num_parameters(self) -> Tuple[int,int]:
+    def get_num_parameters(self, fw_info) -> Tuple[int,int]:
         """
         Compute the number of parameters the node holds.
         It returns a tuple: Number of quantized parameters, number of float parameters.
 
+        Args:
+            fw_info: Framework info to decide which attributes should be quantized.
+
         Returns:
             A tuple of (Number of quantized parameters, number of float parameters).
 
@@ -385,10 +314,11 @@ class BaseNode:
 
         q_node_num_params = 0
 
-
-
-
-
+        for attr in fw_info.get_kernel_op_attributes(self.type):
+            if attr is not None:
+                w = self.get_weights_by_keys(attr)
+                if w is not None:
+                    q_node_num_params += w.flatten().shape[0]
 
         f_node_num_params = total_node_params - q_node_num_params
 
@@ -396,19 +326,22 @@ class BaseNode:
         assert int(f_node_num_params) == f_node_num_params
         return int(q_node_num_params), int(f_node_num_params)
 
-    def get_memory_bytes(self) -> float:
+    def get_memory_bytes(self, fw_info) -> float:
         """
         Compute the number of bytes the node's memory requires.
 
+        Args:
+            fw_info: Framework info to decide which attributes should be quantized.
+
         Returns: Number of bytes the node's memory requires.
 
         """
         # TODO: this method is used for tensorboard only. If we want to enable logging of other attributes memory
         # then it needs to be modified. But, it might be better to remove this method from the BaseNode completely.
-        kernel_attr = self.kernel_attr
+        kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
         if kernel_attr is None:
             return 0
-        q_params, f_params = self.get_num_parameters()
+        q_params, f_params = self.get_num_parameters(fw_info)
         if self.final_weights_quantization_cfg is None: # float coefficients
             memory = (f_params+q_params) * FP32_BYTES_PER_PARAMETER
         else:
@@ -418,12 +351,15 @@ class BaseNode:
 
         return memory
 
-    def get_unified_weights_candidates_dict(self) -> Dict[str, Any]:
+    def get_unified_weights_candidates_dict(self, fw_info) -> Dict[str, Any]:
         """
         In Mixed-Precision, a node's kernel can have multiple candidates for weights quantization configuration.
         In order to display a single view of a node (for example, for logging in TensorBoard) we need a way
         to create a single dictionary from all candidates.
-        This method is aimed to build such
+        This method is aimed to build such an unified dictionary for a node.
+
+        Args:
+            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
 
         Returns: A dictionary containing information from node's weight quantization configuration candidates.
 
@@ -433,7 +369,7 @@ class BaseNode:
         # We assume that only the kernel attribute have more than one candidate, since we only allow to
         # quantize the kernel using mixed precision
         # TODO: need to modify if we want to present a unified config for other attributes
-        kernel_attr = self.kernel_attr
+        kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
         if kernel_attr is None:
             # This node doesn't have a kernel attribute
             return {}
@@ -501,13 +437,20 @@ class BaseNode:
         candidates = self.get_all_weights_attr_candidates(attr)
         return all(candidate == candidates[0] for candidate in candidates[1:])
 
-    def has_kernel_weight_to_quantize(self):
+    def has_kernel_weight_to_quantize(self, fw_info):
         """
-        Checks whether the node has kernel attribute that need to be quantized according to the
+        Checks whether the node has kernel attribute that need to be quantized according to the framework info.
+
+        Args:
+            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
 
-        Returns: Whether the node
+        Returns: Whether the node has weights that need to be quantized.
         """
-
+        attrs = fw_info.get_kernel_op_attributes(self.type)
+        for attr in attrs:
+            if attr and self.get_weights_by_keys(attr) is not None:
+                return True
+        return False
 
     def has_any_weight_attr_to_quantize(self) -> bool:
         """
@@ -625,9 +568,8 @@ class BaseNode:
         Returns: True if the node has at list one quantization configuration candidate with activation quantization enabled.
         """
 
-        return
-
-                    for c in self.candidates_quantization_cfg]))
+        return len(self.candidates_quantization_cfg) > 0 and \
+               any([c.activation_quantization_cfg.enable_activation_quantization for c in self.candidates_quantization_cfg])
 
     def get_all_weights_attr_candidates(self, attr: str) -> List[WeightsAttrQuantizationConfig]:
         """
@@ -643,6 +585,79 @@ class BaseNode:
         # the inner method would log an exception.
         return [c.weights_quantization_cfg.get_attr_config(attr) for c in self.candidates_quantization_cfg]
 
+    def get_qco(self, fqc: FrameworkQuantizationCapabilities) -> QuantizationConfigOptions:
+        """
+        Get the QuantizationConfigOptions of the node according
+        to the mappings from layers/LayerFilterParams to the OperatorsSet in the TargetPlatformCapabilities.
+
+        Args:
+            fqc: FQC to extract the QuantizationConfigOptions for the node.
+
+        Returns:
+            QuantizationConfigOptions of the node.
+        """
+
+        if fqc is None:
+            Logger.critical(f'Can not retrieve QC options for None FQC')  # pragma: no cover
+
+        for fl, qco in fqc.filterlayer2qco.items():
+            if self.is_match_filter_params(fl):
+                return qco
+        # Extract qco with is_match_type to overcome mismatch of function types in TF 2.15
+        matching_qcos = [_qco for _type, _qco in fqc.layer2qco.items() if self.is_match_type(_type)]
+        if matching_qcos:
+            if all([_qco == matching_qcos[0] for _qco in matching_qcos]):
+                return matching_qcos[0]
+            else:
+                Logger.critical(f"Found duplicate qco types for node '{self.name}' of type '{self.type}'!")  # pragma: no cover
+        return fqc.tpc.default_qco
+
+    def filter_node_qco_by_graph(self, fqc: FrameworkQuantizationCapabilities,
+                                 next_nodes: List, node_qc_options: QuantizationConfigOptions
+                                 ) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
+        """
+        Filter quantization config options that don't match the graph.
+        A node may have several quantization config options with 'activation_n_bits' values, and
+        the next nodes in the graph may support different bit-width as input activation. This function
+        filters out quantization config that don't comply to these attributes.
+
+        Args:
+            fqc: FQC to extract the QuantizationConfigOptions for the next nodes.
+            next_nodes: Output nodes of current node.
+            node_qc_options: Node's QuantizationConfigOptions.
+
+        Returns:
+
+        """
+        # Filter quantization config options that don't match the graph.
+        _base_config = node_qc_options.base_config
+        _node_qc_options = node_qc_options.quantization_configurations
+        if len(next_nodes):
+            next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
+            next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
+                                                       for qc_opts in next_nodes_qc_options
+                                                       for op_cfg in qc_opts.quantization_configurations])
+
+            # Filter node's QC options that match next nodes input bit-width.
+            _node_qc_options = [_option for _option in _node_qc_options
+                                if _option.activation_n_bits <= next_nodes_supported_input_bitwidth]
+            if len(_node_qc_options) == 0:
+                Logger.critical(f"Graph doesn't match FQC bit configurations: {self} -> {next_nodes}.")  # pragma: no cover
+
+            # Verify base config match
+            if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
+                    for qc_opt in next_nodes_qc_options]):
+                # base_config activation bits doesn't match next node supported input bit-width -> replace with
+                # a qco from quantization_configurations with maximum activation bit-width.
+                if len(_node_qc_options) > 0:
+                    output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
+                    _base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
+                    Logger.warning(f"Node {self} base quantization config changed to match Graph and FQC configuration.\nCause: {self} -> {next_nodes}.")
+                else:
+                    Logger.critical(f"Graph doesn't match FQC bit configurations: {self} -> {next_nodes}.")  # pragma: no cover
+
+        return _base_config, _node_qc_options
+
     def is_match_type(self, _type: Type) -> bool:
         """
         Check if input type matches the node type, either in instance type or in type name.
@@ -675,7 +690,7 @@ class BaseNode:
             return False
 
         # Get attributes from node to filter
-        layer_config = self.framework_attr
+        layer_config = self.framework_attr
         if hasattr(self, "op_call_kwargs"):
             layer_config.update(self.op_call_kwargs)
 
@@ -709,7 +724,7 @@ class BaseNode:
             Logger.critical(f"SIMD is expected to be a non-positive integer but found: {_simd}")
         return _simd
 
-    def sort_node_candidates(self):
+    def sort_node_candidates(self, fw_info):
         """
         Sorts the node candidates.
         We assume that the candidates are ordered in the following way (for mixed precision purposes):
@@ -718,12 +733,17 @@ class BaseNode:
         - If the node doesn't have a kernel we only consider the candidate activation number of bits to sort
           the candidates in descending order.
         The operation is done inplace.
+
+        Args:
+            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
+
         """
-        if self.
-
-
-
+        if self.candidates_quantization_cfg is not None:
+            kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
+            if kernel_attr is not None:
+                self.candidates_quantization_cfg.sort(
+                    key=lambda c: (c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits,
                                    c.activation_quantization_cfg.activation_n_bits), reverse=True)
             else:
-                self.
-
+                self.candidates_quantization_cfg.sort(key=lambda c: c.activation_quantization_cfg.activation_n_bits,
+                                                      reverse=True)
model_compression_toolkit/core/common/graph/functional_node.py

@@ -1,21 +1,6 @@
-# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
 from typing import Dict, Any, Tuple, Type, List, Union
 
-from model_compression_toolkit.
+from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 import numpy as np
 
@@ -60,7 +45,6 @@ class FunctionalNode(BaseNode):
         inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
         has_activation: Whether the node has activations that we might want to quantize.
         tensor_input_allocs: A list of indices and strings for allocations input tensors in the node's args and kwargs.
-
         """
 
         super().__init__(name,
@@ -79,7 +63,6 @@ class FunctionalNode(BaseNode):
         self.op_call_args = list(op_call_args)
         self.functional_op = functional_op
         self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs
-        self.node_fw_info = self._get_fw_node_attrs(functional_op, framework_attr)
 
     @property
     def type(self):
@@ -103,4 +86,4 @@ class FunctionalNode(BaseNode):
 
         """
         names_match = _type.__name__ == self.type.__name__
-        return
+        return super().is_match_type(_type) or names_match
model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py

@@ -15,11 +15,13 @@
 import abc
 import uuid
 
+from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.constants import VIRTUAL_ACTIVATION_WEIGHTS_NODE_PREFIX, \
     VIRTUAL_WEIGHTS_SUFFIX, VIRTUAL_ACTIVATION_SUFFIX, FLOAT_BITWIDTH
+from model_compression_toolkit.core.common.framework_info import DEFAULT_KERNEL_ATTRIBUTES
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
-    CandidateNodeQuantizationConfig
+    CandidateNodeQuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
 
 
@@ -75,11 +77,8 @@ class VirtualSplitWeightsNode(VirtualSplitNode):
 
         self.name = origin_node.name + VIRTUAL_WEIGHTS_SUFFIX
 
-        self.quantization_cfg = NodeQuantizationConfig(
-            candidates_quantization_cfg=origin_node.get_unique_weights_candidates(kernel_attr),
-            base_quantization_cfg=None, validate=False
-        )
-        for c in self.quantization_cfg.candidates_quantization_cfg:
+        self.candidates_quantization_cfg = origin_node.get_unique_weights_candidates(kernel_attr)
+        for c in self.candidates_quantization_cfg:
             c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
             c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH
 
@@ -108,9 +107,10 @@ class VirtualSplitActivationNode(VirtualSplitNode):
         self.weights = {}
         self.layer_class = activation_class
 
-        self.quantization_cfg = NodeQuantizationConfig(
-            candidates_quantization_cfg=origin_node.get_unique_activation_candidates(),
-            base_quantization_cfg=None, validate=False)
+        self.candidates_quantization_cfg = origin_node.get_unique_activation_candidates()
+        for c in self.candidates_quantization_cfg:
+            c.weights_quantization_cfg.enable_weights_quantization = False
+            c.weights_quantization_cfg.weights_n_bits = FLOAT_BITWIDTH
 
 
 class VirtualActivationWeightsNode(VirtualNode):
@@ -128,23 +128,28 @@ class VirtualActivationWeightsNode(VirtualNode):
 
     def __init__(self,
                  act_node: BaseNode,
-                 weights_node: BaseNode):
+                 weights_node: BaseNode,
+                 fw_info: FrameworkInfo):
        """
        Init a VirtualActivationWeightsNode object.

        Args:
            act_node: The original activation node.
            weights_node: The original weights node.
+            fw_info: A FrameworkInfo object with framework specific information.
        """
        # Validate weights node
+        kernel_attrs = fw_info.get_kernel_op_attributes(weights_node.type)
+        assert len(kernel_attrs) == 1 and kernel_attrs[0] is not None, f'Expected exactly one kernel attr, {kernel_attrs}'
+        kernel_attr = kernel_attrs[0]
        conf_weights = [attr for attr in weights_node.weights if weights_node.is_configurable_weight(attr)]
-        if len(conf_weights) > 1 or len(conf_weights) == 1 and not weights_node.is_configurable_weight(
+        if len(conf_weights) > 1 or len(conf_weights) == 1 and not weights_node.is_configurable_weight(kernel_attr):
            raise NotImplementedError(f'Only kernel weight can be configurable. Got configurable {conf_weights}.')

        weights = weights_node.weights.copy()
        act_node_w_rename = {}
        if act_node.weights:
-            if act_node
+            if fw_info.get_kernel_op_attributes(act_node) != DEFAULT_KERNEL_ATTRIBUTES:
                raise NotImplementedError(f'Node {act_node} with kernel cannot be used as activation for '
                                          f'VirtualActivationWeightsNode.')
            if act_node.has_any_configurable_weight():
@@ -152,7 +157,7 @@ class VirtualActivationWeightsNode(VirtualNode):
                                          'VirtualActivationWeightsNode.')
            # combine weights from activation and weights
            for w_id, w in act_node.weights.items():
-                if w_id not in weights and not (isinstance(w_id, str) and
+                if w_id not in weights and not (isinstance(w_id, str) and kernel_attr in w_id):
                    weights[w_id] = w
                    continue
                # if same identifier is used as in weight nodes (or contains the kernel substring), generate a new
@@ -180,7 +185,7 @@ class VirtualActivationWeightsNode(VirtualNode):
        self.original_weights_node = weights_node

        v_candidates = []
-        weights_candidates_quantization_cfg = weights_node.get_unique_weights_candidates(
+        weights_candidates_quantization_cfg = weights_node.get_unique_weights_candidates(kernel_attr)
        for c_a in act_node.candidates_quantization_cfg:
            for c_w in weights_candidates_quantization_cfg:
                composed_candidate = CandidateNodeQuantizationConfig(activation_quantization_cfg=c_a.activation_quantization_cfg,
@@ -198,8 +203,7 @@ class VirtualActivationWeightsNode(VirtualNode):
                v_candidates.append(composed_candidate)

        # sorting the candidates by weights number of bits first and then by activation number of bits (reversed order)
-        v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.get_attr_config(
+        v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits,
                                         c.activation_quantization_cfg.activation_n_bits), reverse=True)

-        self.quantization_cfg = NodeQuantizationConfig(candidates_quantization_cfg=v_candidates,
-                                                       base_quantization_cfg=None, validate=False)
+        self.candidates_quantization_cfg = v_candidates
model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py

@@ -37,18 +37,20 @@ def set_bit_widths(mixed_precision_enable: bool,
     """
     if mixed_precision_enable:
         assert all([len(n.candidates_quantization_cfg) > 0
-                    for n in graph.get_configurable_sorted_nodes()]), \
+                    for n in graph.get_configurable_sorted_nodes(graph.fw_info)]), \
             "All configurable nodes in graph should have at least one candidate configuration in mixed precision mode"
 
         # Get a list of nodes' names we need to finalize (that they have at least one weight qc candidate).
-        sorted_nodes_names = graph.get_configurable_sorted_nodes_names()
+        sorted_nodes_names = graph.get_configurable_sorted_nodes_names(graph.fw_info)
 
         for node in graph.nodes: # set a specific node qc for each node final qc
             # If it's reused, take the configuration that the base node has
             node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2])
             if node_name in sorted_nodes_names: # only configurable nodes are in this list
                 node_index_in_graph = sorted_nodes_names.index(node_name)
-                _set_node_final_qc(bit_widths_config[node_index_in_graph],
+                _set_node_final_qc(bit_widths_config[node_index_in_graph],
+                                   node,
+                                   graph.fw_info)
             else:
                 if node.is_activation_quantization_enabled():
                     # If we are here, this means that we are in weights-only mixed-precision
@@ -81,7 +83,8 @@ def set_bit_widths(mixed_precision_enable: bool,
 
 
 def _get_node_qc_by_bit_widths(node: BaseNode,
-                               node_bit_width_cfg: int
+                               node_bit_width_cfg: int,
+                               fw_info) -> Any:
     """
     Get the node's quantization configuration that
     matches to the bit width index as in the MP configuration bit_width_cfg.
@@ -90,18 +93,21 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
     Args:
         node: Node to get its quantization configuration candidate.
         node_bit_width_cfg: Configuration which determines the node's desired bit width.
+        fw_info: Information relevant to a specific framework about how layers should be quantized.
 
     Returns:
         Node quantization configuration if it was found, or None otherwise.
     """
     # only the weights kernel attribute is quantized in weights mixed precision at the moment
+    kernel_attr = fw_info.get_kernel_op_attributes(node.type)
+
     if node.is_activation_quantization_enabled():
         qc = node.candidates_quantization_cfg[node_bit_width_cfg]
 
         return qc
 
-    elif
-        if node.is_weights_quantization_enabled(
+    elif kernel_attr is not None:
+        if node.is_weights_quantization_enabled(kernel_attr[0]):
             qc = node.candidates_quantization_cfg[node_bit_width_cfg]
 
             return qc
@@ -110,7 +116,8 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
 
 
 def _set_node_final_qc(node_bit_width_cfg: int,
-                       node: BaseNode
+                       node: BaseNode,
+                       fw_info):
     """
     Get the node's quantization configuration that
     matches to the bit width index as in the MP configuration bit_width_cfg, and use it to finalize the node's
@@ -120,9 +127,12 @@ def _set_node_final_qc(node_bit_width_cfg: int,
     Args:
         node_bit_width_cfg: Configuration which determines the node's desired bit width.
         node: Node to set its node quantization configuration.
+        fw_info: Information relevant to a specific framework about how layers should be quantized.
 
     """
-    node_qc = _get_node_qc_by_bit_widths(node,
+    node_qc = _get_node_qc_by_bit_widths(node,
+                                         node_bit_width_cfg,
+                                         fw_info)
 
     if node_qc is None:
         Logger.critical(f'Node {node.name} quantization configuration from configuration file' # pragma: no cover