mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250927.534__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py

@@ -12,22 +12,74 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-[…]
+import copy
+from typing import List, Tuple, Dict, Optional

+from mct_quantizers.common.constants import WEIGHTS_N_BITS, ACTIVATION_N_BITS
+from model_compression_toolkit.constants import WEIGHTS, ACTIVATION
 from model_compression_toolkit.core.common import BaseNode
-from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
-from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.[…]
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.core.common.graph.base_graph import Graph
+from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
+    CandidateNodeQuantizationConfig
+from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig, \
+    ActivationQuantizationMode
+from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
+    QuantizationErrorMethod
+from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
+    get_activation_quantization_params_fn
+from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
     QuantizationConfigOptions
-from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
     FrameworkQuantizationCapabilities
+from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
+
+
+def set_quantization_configuration_to_graph(graph: Graph,
+                                            quant_config: QuantizationConfig,
+                                            bit_width_config: BitWidthConfig = None,
+                                            mixed_precision_enable: bool = False,
+                                            running_gptq: bool = False) -> Graph:
+    """
+    Add quantization configuration for each graph node.
+
+    Args:
+        graph (Graph): Graph for which to add quantization info to each node.
+        quant_config (QuantizationConfig): Quantization configuration containing parameters for how the graph should be quantized.
+        bit_width_config (BitWidthConfig): Configuration for manual bit width selection. Defaults to None.
+        mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
+        running_gptq (bool): Whether or not a GPTQ optimization is planned to run after the PTQ process. Defaults to False.
+
+    Returns:
+        Graph: The graph with quantization configurations attached to each node in it.
+    """
+
+    if quant_config.weights_error_method == QuantizationErrorMethod.HMSE:
+        if not running_gptq:
+            raise ValueError(f"The HMSE error method for parameters selection is only supported when running GPTQ "
+                             f"optimization due to long execution time that is not suitable for basic PTQ.")
+        Logger.warning("Using the HMSE error method for weights quantization parameters search. "
+                       "Note: This method may significantly increase runtime during the parameter search process.")
+
+    nodes_to_manipulate_activation_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
+    nodes_to_manipulate_weights_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_weights_bit_widths(graph)
+
+    for n in graph.get_topo_sorted_nodes():
+        manual_bit_width_override = {ACTIVATION: nodes_to_manipulate_activation_bit_widths.get(n),
+                                     WEIGHTS: nodes_to_manipulate_weights_bit_widths.get(n)}
+        set_quantization_configs_to_node(node=n,
+                                         graph=graph,
+                                         quant_config=quant_config,
+                                         fw_info=graph.fw_info,
+                                         fqc=graph.fqc,
+                                         mixed_precision_enable=mixed_precision_enable,
+                                         manual_bit_width_override=manual_bit_width_override)
+    return graph


-# TODO irena refactor (if needed) and move to load_fqc
 def filter_node_qco_by_graph(node: BaseNode,
                              fqc: FrameworkQuantizationCapabilities,
                              graph: Graph,
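For orientation: the new set_quantization_configuration_to_graph rejects the HMSE weights-error method unless a GPTQ run is planned afterwards. Below is a minimal, self-contained mirror of that gating, using a toy enum rather than MCT's own classes; it is a sketch of the rule shown in the hunk above, not MCT code.

from enum import Enum


class WeightsErrorMethod(Enum):
    # Toy stand-in for MCT's QuantizationErrorMethod.
    MSE = 0
    HMSE = 1


def check_hmse_allowed(weights_error_method: WeightsErrorMethod, running_gptq: bool) -> None:
    # Mirrors the gating above: HMSE parameter search is only allowed when GPTQ will run after PTQ.
    if weights_error_method is WeightsErrorMethod.HMSE and not running_gptq:
        raise ValueError("HMSE parameter search is only supported together with GPTQ.")


check_hmse_allowed(WeightsErrorMethod.MSE, running_gptq=False)  # accepted
try:
    check_hmse_allowed(WeightsErrorMethod.HMSE, running_gptq=False)
except ValueError as e:
    print(e)  # rejected, as in the diff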
@@ -50,8 +102,6 @@ def filter_node_qco_by_graph(node: BaseNode,
     that are compatible with next nodes supported input bit-widths.

     """
-    from model_compression_toolkit.quantization_preparation.load_fqc import fetch_qc_options_for_node
-
     # Filter quantization config options that don't match the graph.
     _base_config = node_qc_options.base_config
     _node_qc_options = node_qc_options.quantization_configurations
@@ -61,7 +111,7 @@ def filter_node_qco_by_graph(node: BaseNode,
     next_nodes = []
     while len(_next_nodes):
         n = _next_nodes.pop(0)
-        qco = […]
+        qco = n.get_qco(fqc)
         qp = [qc.quantization_preserving for qc in qco.quantization_configurations]
         if not all(qp) and any(qp):
             Logger.error(f'Attribute "quantization_preserving" should be the same for all QuantizaionConfigOptions in {n}.')
@@ -71,8 +121,7 @@ def filter_node_qco_by_graph(node: BaseNode,

     if len(next_nodes) == 0:
         return _base_config, _node_qc_options
-
-    next_nodes_qc_options = [fetch_qc_options_for_node(_node, fqc) for _node in next_nodes]
+    next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
     all_next_nodes_supported_input_bitwidth = [max_input_activation_n_bits(op_cfg)
                                                for qc_opts in next_nodes_qc_options
                                                for op_cfg in qc_opts.quantization_configurations
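These three hunks replace the old fetch_qc_options_for_node helper with the node's own get_qco(fqc), while the loop keeps walking successors to find the consumers whose supported input bit-widths constrain the current node. Judging from the surviving lines, successors whose options are all quantization-preserving appear to be walked through rather than collected; the following self-contained sketch illustrates that walk under that assumption, with a toy graph and hypothetical helper names (not MCT code).

from collections import deque


def collect_quantizing_consumers(start: str, next_of: dict, preserving: dict) -> list:
    # Walk successors breadth-first, skipping over quantization-preserving ops.
    queue, consumers = deque(next_of[start]), []
    while queue:
        n = queue.popleft()
        if preserving[n]:
            queue.extend(next_of[n])  # preserving op: keep walking to its consumers
        else:
            consumers.append(n)       # first genuinely quantizing consumer
    return consumers


# conv -> reshape (quantization-preserving) -> dense (quantizing)
next_of = {"conv": ["reshape"], "reshape": ["dense"], "dense": []}
preserving = {"reshape": True, "dense": False}
print(collect_quantizing_consumers("conv", next_of, preserving))  # ['dense']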
@@ -102,98 +151,368 @@ def filter_node_qco_by_graph(node: BaseNode,
     return _base_config, _node_qc_options


-def […]
+def set_quantization_configs_to_node(node: BaseNode,
+                                     graph: Graph,
+                                     quant_config: QuantizationConfig,
+                                     fw_info: FrameworkInfo,
+                                     fqc: FrameworkQuantizationCapabilities,
+                                     mixed_precision_enable: bool = False,
+                                     manual_bit_width_override: Optional[Dict] = None):
     """
-[…]
+    Create and set quantization configurations to a node (for both weights and activation).

     Args:
-[…]
-[…]
+        node (BaseNode): Node to set its quantization configurations.
+        graph (Graph): Model's internal representation graph.
+        quant_config (QuantizationConfig): Quantization configuration to generate the node's configurations from.
+        fw_info (FrameworkInfo): Information needed for quantization about the specific framework.
+        fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to get default OpQuantizationConfig.
+        mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
+        manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
     """
-[…]
-[…]
+    node_qc_options = node.get_qco(fqc)
+    base_config, node_qc_options_list = filter_node_qco_by_graph(node, fqc, graph, node_qc_options)

-[…]
-[…]
+    # If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation and weights bits equal to manual_bit_width_override,
+    # and update base_config accordingly.
+    if manual_bit_width_override is None:
+        manual_bit_width_override = {ACTIVATION: None, WEIGHTS: None}
+
+    base_config, node_qc_options_list = filter_qc_options_with_manual_bit_width(
+        node=node,
+        node_qc_options_list=node_qc_options_list,
+        base_config=base_config,
+        manual_bit_width_override=manual_bit_width_override,
+        mixed_precision_enable=mixed_precision_enable)

-[…]
-[…]
+    # Create QC candidates for weights and activation combined
+    weight_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
+    node.candidates_quantization_cfg = _create_node_candidates_qc(quant_config,
+                                                                  fw_info,
+                                                                  weight_channel_axis,
+                                                                  node_qc_options_list,
+                                                                  base_config,
+                                                                  node,
+                                                                  mixed_precision_enable=mixed_precision_enable)

+    # sorting the candidates by kernel attribute weights number of bits first and then by activation number of bits
+    # (in reversed order). since only kernel attribute is quantized in weights mixed precision,
+    # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
+    node.sort_node_candidates(fw_info)

-[…]
-[…]
+    for candidate_qc in node.candidates_quantization_cfg:
+        if candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.QUANT and \
+                not node.get_has_activation():
+            candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
+        elif candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.PRESERVE_QUANT:
+            prev_nodes = graph.get_prev_nodes(node)
+            if len(prev_nodes) != 1:
+                # Preserving the quantization of more than 1 previous node is ambiguous, so disable it.
+                Logger.info(f"Disabling Quantization-Preserving for node {node.name} because it has more than 1 input activations.")
+                candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
+            elif not prev_nodes[0].is_quantization_preserving() and not prev_nodes[0].is_activation_quantization_enabled():
+                # Preserving the quantization of an unquantized node isn't possible, so disable it.
+                Logger.info(f"Disabling Quantization-Preserving for node {node.name} because previous node activation quantization is disabled.")
+                candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
+
+
+def create_node_activation_qc(qc: QuantizationConfig,
+                              fw_info: FrameworkInfo,
+                              op_cfg: OpQuantizationConfig) -> NodeActivationQuantizationConfig:
     """
-[…]
-    […] activation bitwidth in the base quantization config.
+    Create an activation quantization configuration from a QuantizationConfig object.

     Args:
-[… old lines 130-158: removed docstring and implementation, not rendered in the source diff …]
+        qc: QuantizationConfig to create the node's config from.
+        fw_info: Information about the specific framework the node was created from (e.g., whether or not its
+            weights/activations should be quantized)
+        op_cfg: OpQuantizationConfig with quantizers types to set in node quantization configuration.
+
+    Returns:
+        Activation quantization configuration of a node.
+    """
+
+    activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
+    if activation_quantization_fn is None:
+        Logger.critical('Unknown activation quantization method specified.')  # pragma: no cover
+
+    activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
+
+    return NodeActivationQuantizationConfig(qc,
+                                            op_cfg,
+                                            activation_quantization_fn,
+                                            activation_quantization_params_fn)
+
+
+def _create_node_single_candidate_qc(qc: QuantizationConfig,
+                                     fw_info: FrameworkInfo,
+                                     weight_channel_axis: Tuple[int, int],
+                                     op_cfg: OpQuantizationConfig,
+                                     node_attrs_list: List[str]) -> CandidateNodeQuantizationConfig:
+    """
+    Create quantization configuration candidate from a QuantizationConfig object.
+    Creates both weights and activation quantization configurations
+    and initialize a candidate object that encapsulates both.

     Args:
-[… old lines 161-199: removed docstring and implementation, not rendered in the source diff …]
+        qc: QuantizationConfig to create the node's config from.
+        fw_info: Information about the specific framework the node was created from (e.g., whether its
+            weights/activations should be quantized)
+        weight_channel_axis: (Output, Input) channel index of the node's kernel.
+        op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
+        node_attrs_list: A list of the node's weights attributes names.
+
+    Returns: a CandidateNodeQuantizationConfig object with both weights and activation quantization config objects.
+
+    """
+
+    # parameters for weights attributes quantization are set within CandidateNodeQuantizationConfig initialization
+
+    # get parameters for activation quantization
+    activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
+    if activation_quantization_fn is None:
+        Logger.critical('Unknown activation quantization method specified.')  # pragma: no cover
+
+    activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
+
+    # TODO: remove this validation and warning once enabling all attributes quantization by default
+    attrs_with_enabled_quantization = [attr for attr, cfg in op_cfg.attr_weights_configs_mapping.items()
+                                       if cfg.enable_weights_quantization]
+    if len(attrs_with_enabled_quantization) > 1:
+        Logger.warning(f"Multiple weights attributes quantization is enabled via the provided FQC."
+                       f"Quantizing any attribute other than the kernel is experimental "
+                       f"and may be subject to unstable behavior."
+                       f"Attributes with enabled weights quantization: {attrs_with_enabled_quantization}.")
+
+    return CandidateNodeQuantizationConfig(qc=qc,
+                                           op_cfg=op_cfg,
+                                           activation_quantization_fn=activation_quantization_fn,
+                                           activation_quantization_params_fn=activation_quantization_params_fn,
+                                           weights_channels_axis=weight_channel_axis,
+                                           node_attrs_list=node_attrs_list)
+
+
+def _create_node_candidates_qc(qc: QuantizationConfig,
+                               fw_info: FrameworkInfo,
+                               weight_channel_axis: Tuple[int, int],
+                               node_qc_options_list: List[OpQuantizationConfig],
+                               base_config: OpQuantizationConfig,
+                               node: BaseNode,
+                               mixed_precision_enable: bool = False) -> List[CandidateNodeQuantizationConfig]:
+    """
+    Create a list of candidates of weights and activation quantization configurations for a node.
+
+    Args:
+        qc (QuantizationConfig): Quantization configuration the quantization process should follow.
+        fw_info (FrameworkInfo): Framework information (e.g., which layers should have their kernels quantized).
+        weight_channel_axis (Tuple[int, int]): (Output, Input) channel index of the node's kernel.
+        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs of node.
+        base_config (OpQuantizationConfig): Base quantization config for node.
+        node (BaseNode): A node to set quantization configuration candidates to.
+        mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
+
+    Returns:
+        List[CandidateNodeQuantizationConfig]: List of candidates of weights quantization configurations to set for a node.
+    """
+
+    candidates = []
+    node_attrs_list = node.get_node_weights_attributes()
+
+    if mixed_precision_enable:
+        for op_cfg in node_qc_options_list:
+            candidate_qc = copy.deepcopy(qc)
+            candidates.append(_create_node_single_candidate_qc(candidate_qc,
+                                                               fw_info,
+                                                               weight_channel_axis,
+                                                               op_cfg,
+                                                               node_attrs_list))
+
+    else:
+        candidates.append(_create_node_single_candidate_qc(qc,
+                                                           fw_info,
+                                                           weight_channel_axis,
+                                                           base_config,
+                                                           node_attrs_list))
+
+    return candidates
+
+
+def filter_qc_options_with_manual_bit_width(
+        node: BaseNode,
+        node_qc_options_list: List[OpQuantizationConfig],
+        base_config: OpQuantizationConfig,
+        manual_bit_width_override: Optional[Dict],
+        mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
+    """
+    Update the quantization configurations for a node, allowing manual bit-width overrides if specified.
+
+    Args:
+        node (BaseNode): A node to set quantization configuration candidates to.
+        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
+        base_config (OpQuantizationConfig): Base quantization config for the node.
+        manual_bit_width_override (Optional[Dict]): Specifies a custom bit-width to override the node's activation and weights bit-width.
+        mixed_precision_enable (bool): Whether mixed precision is enabled.
+
+    Returns:
+        Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
+    """
+    base_config, node_qc_options_list = filter_activation_qc_options_with_manual_bit_width(node,
+                                                                                           node_qc_options_list,
+                                                                                           base_config,
+                                                                                           manual_bit_width_override.get(ACTIVATION),
+                                                                                           mixed_precision_enable)
+
+    base_config, node_qc_options_list = filter_weights_qc_options_with_manual_bit_width(node,
+                                                                                        node_qc_options_list,
+                                                                                        base_config,
+                                                                                        manual_bit_width_override.get(WEIGHTS),
+                                                                                        mixed_precision_enable)
+    return base_config, node_qc_options_list
+
+
+def filter_activation_qc_options_with_manual_bit_width(
+        node: BaseNode,
+        node_qc_options_list: List[OpQuantizationConfig],
+        base_config: OpQuantizationConfig,
+        activation_manual_bit_width_override: Optional[int],
+        mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
+    """
+    Update the activation quantization configurations for a node, allowing manual bit-width overrides if specified.
+
+    Args:
+        node (BaseNode): A node to set quantization configuration candidates to.
+        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
+        base_config (OpQuantizationConfig): Base quantization config for the node.
+        activation_manual_bit_width_override (Optional[Dict]): Specifies a custom bit-width to override the node's activation bit-width.
+        mixed_precision_enable (bool): Whether mixed precision is enabled.
+
+    Returns:
+        Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
+    """
+    if activation_manual_bit_width_override is None:
+        return base_config, node_qc_options_list
+
+    # Filter node_qc_options_list to retain only the options with activation bits equal to activation_manual_bit_width_override.
+    node_qc_options_list = [op_cfg for op_cfg in node_qc_options_list if
+                            activation_manual_bit_width_override == op_cfg.activation_n_bits]
+    if len(node_qc_options_list) == 0:
+        Logger.critical(f"Manually selected activation bit-width {activation_manual_bit_width_override} is invalid for node {node}.")
+    else:
+        # Update the base_config to one of the values from the filtered node_qc_options_list.
+        # First, check if a configuration similar to the original base_config but with activation bits equal to activation_manual_bit_width_override exists.
+        # If it does, use it as the base_config. If not, choose a different configuration from node_qc_options_list.
+        Logger.info(f"Setting node {node} bit-width to manually selected bit-width: {activation_manual_bit_width_override} bits.")
+        updated_base_config = base_config.clone_and_edit({ACTIVATION_N_BITS, activation_manual_bit_width_override})
+        if updated_base_config in node_qc_options_list:
+            # If a base_config with the specified activation_manual_bit_width_override exists in the node_qc_options_list,
+            # point the base_config to this option.
+            base_config = node_qc_options_list[node_qc_options_list.index(updated_base_config)]
+        else:
+            # Choose a different configuration from node_qc_options_list. If multiple options exist, issue a warning.
+            base_config = node_qc_options_list[0]
+            if len(node_qc_options_list) > 0 and not mixed_precision_enable:
+                Logger.info(
+                    f"Request received to select {activation_manual_bit_width_override} activation bits. However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
+                    f" Overriding base_config with an option that uses {activation_manual_bit_width_override} bit activations.")  # pragma: no cover
+
+    return base_config, node_qc_options_list
+
+
+def filter_weights_qc_options_with_manual_bit_width(
+        node: BaseNode,
+        node_qc_options_list: List[OpQuantizationConfig],
+        base_config: OpQuantizationConfig,
+        weights_manual_bit_width_override: Optional[Tuple[int, WeightAttrT]],
+        mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
+    """
+    Update the weights quantization configurations for a node, allowing manual bit-width overrides if specified.
+
+    Args:
+        node (BaseNode): A node to set quantization configuration candidates to.
+        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
+        base_config (OpQuantizationConfig): Base quantization config for the node.
+        weights_manual_bit_width_override (Optional[[int, WeightAttrT]]): Specifies a custom bit-width to override the node's weights bit-width.
+        mixed_precision_enable (bool): Whether mixed precision is enabled.
+
+    Returns:
+        Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
+    """
+    if not weights_manual_bit_width_override:
+        return base_config, node_qc_options_list
+
+    # Filter node_qc_options_list to retain only the options with weights bits equal to weights_manual_bit_width_override.
+    node_qc_options_weights_list = _filter_options(node_qc_options_list, weights_manual_bit_width_override)
+
+    if len(node_qc_options_weights_list) == 0:
+        Logger.critical(f"Manually selected weights bit-width {weights_manual_bit_width_override} is invalid for node {node}.")
+    else:
+        # Update the base_config to one of the values from the filtered node_qc_options_list.
+        # First, check if a configuration similar to the original base_config but with weights bits equal to weights_manual_bit_width_override exists.
+        # If it does, use it as the base_config. If not, choose a different configuration from node_qc_options_list.
+        updated_base_config = base_config.clone_and_edit()
+
+        for bit_width, attr in weights_manual_bit_width_override:
+            Logger.info(f"Setting node {node} bit-width to manually selected {attr} bit-width: {bit_width} bits.")
+            updated_base_config = updated_base_config.clone_and_edit(attr_to_edit={attr: {WEIGHTS_N_BITS: bit_width}})
+
+        if updated_base_config in node_qc_options_weights_list:
+            # If a base_config with the specified weights_manual_bit_width_override exists in the node_qc_options_list,
+            # point the base_config to this option.
+            base_config = node_qc_options_weights_list[node_qc_options_weights_list.index(updated_base_config)]
+        else:
+            # Choose a different configuration from node_qc_options_list. If multiple options exist, issue a warning.
+            base_config = node_qc_options_weights_list[0]
+            if len(node_qc_options_weights_list) > 0 and not mixed_precision_enable:
+                Logger.info(
+                    f"Request received to select weights bit-widths {weights_manual_bit_width_override}."
+                    f"However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
+                    f" Overriding base_config with an option that uses manually selected weights bit-widths {weights_manual_bit_width_override}.")  # pragma: no cover
+
+    return base_config, node_qc_options_weights_list
+
+
+def _is_valid_option(
+        op_cfg: OpQuantizationConfig,
+        attr: WeightAttrT,
+        bit_width: int) -> bool:
+    """
+    Judge whether the specified option is valid based on the specified attribute and bit width.
+
+    Args:
+        op_cfg (OpQuantizationConfig): The quantization configuration to be judged.
+        attr (WeightAttrT): The filtered node's attributes to apply bit-width manipulation to.
+        bit_width (int): The bit width to be applied to the selected nodes.
+
+    Returns:
+        Result to judge whether the specified option is valid based on the specified attribute and bit width
+    """
+    weights_attrs = op_cfg.attr_weights_configs_mapping.keys()
+
+    if attr not in weights_attrs:
+        return False
+
+    weights_n_bits = op_cfg.attr_weights_configs_mapping[attr].weights_n_bits
+    return weights_n_bits == bit_width
+
+
+def _filter_options(
+        node_qc_options_list: List[OpQuantizationConfig],
+        weights_manual_bit_width_override: Tuple[int, WeightAttrT]) -> List[OpQuantizationConfig]:
+    """
+    Filter the options based on the specified bit width and attribute.
+
+    Args:
+        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
+        weights_manual_bit_width_override (Tuple[int, WeightAttrT])): Specifies a custom bit-width to override the node's weights bit-width.
+
+    Returns:
+        List[OpQuantizationConfig]: Filtered the options based on the specified bit width and attribute.
+    """
+    filtered_options = []
+
+    for bit_width, attr in weights_manual_bit_width_override:
+        for op_cfg in node_qc_options_list:
+            if _is_valid_option(op_cfg, attr, bit_width):
+                filtered_options.append(op_cfg)
+
+    return filtered_options
model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py

@@ -38,16 +38,18 @@ def apply_activation_bias_correction_to_graph(graph: Graph,

     for n in graph.nodes:
         # Activation bias correction is only relevant for nodes with kernel op
-[…]
+        kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
+        if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \
                 n.final_activation_quantization_cfg.activation_bias_correction_term is not None:
             # If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was
             # calculated during model preparation, and is used now in the node's bias term.
-            _apply_activation_bias_correction_to_node(n, fw_impl)
+            _apply_activation_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
     return graph


 def _apply_activation_bias_correction_to_node(node: BaseNode,
-                                              fw_impl: FrameworkImplementation[…]
+                                              fw_impl: FrameworkImplementation,
+                                              qc: QuantizationConfig):
     """
     Set new bias to node using the activation bias correction term that is stored in the
     final activation quantization configuration.
@@ -55,6 +57,7 @@ def _apply_activation_bias_correction_to_node(node: BaseNode,
     Args:
         node: Node to set its corrected bias after activation bias correction.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+        qc: QuantizationConfig containing parameters of how the model should be quantized.

     """
     correction = node.final_activation_quantization_cfg.activation_bias_correction_term
@@ -70,6 +73,7 @@ def _apply_activation_bias_correction_to_node(node: BaseNode,
     # Configure the quantization of the bias as disabled.
     node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
                                                         WeightsAttrQuantizationConfig(
+                                                            qc,
                                                             AttributeQuantizationConfig(
                                                                 enable_weights_quantization=False)))
     else:
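The change in this second file threads the QuantizationConfig through to the bias attribute's rebuilt quantization config: WeightsAttrQuantizationConfig now receives qc as its first argument when bias quantization is disabled after the correction term is folded in. Below is a toy, self-contained mirror of that tail, with dataclasses standing in for MCT's config classes; it sketches the shape of the new call and is not MCT code.

from dataclasses import dataclass


@dataclass
class AttributeQuantizationConfig:
    enable_weights_quantization: bool = True


@dataclass
class WeightsAttrQuantizationConfig:
    qc: dict                              # the QuantizationConfig now passed through (new in 2.4.2)
    attr_cfg: AttributeQuantizationConfig


def disable_bias_quantization(attr_configs: dict, qc: dict) -> None:
    # Mirrors the hunk above: after the correction term is applied, the bias
    # attribute gets a fresh config with quantization switched off.
    attr_configs["bias"] = WeightsAttrQuantizationConfig(
        qc, AttributeQuantizationConfig(enable_weights_quantization=False))


configs: dict = {}
disable_bias_quantization(configs, qc={"weights_error_method": "MSE"})
print(configs["bias"].attr_cfg.enable_weights_quantization)  # False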