mct-nightly 2.4.0.20250924.535__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py
CHANGED

@@ -14,6 +14,8 @@
 # ==============================================================================
 import copy
 
+from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
+from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
@@ -21,6 +23,7 @@ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_s
 
 
 def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
+                                   core_config: CoreConfig,
                                    fw_impl: FrameworkImplementation) -> Graph:
     """
     Get a graph, where each node has a final weights quantization configuration (with a bias
@@ -28,6 +31,7 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
 
     Args:
         graph_to_apply_bias_correction: Graph to apply bias correction to.
+        core_config: CoreConfig containing parameters of how the model should be quantized.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -36,14 +40,21 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
 
     graph = copy.deepcopy(graph_to_apply_bias_correction)
     for n in graph.nodes:
-
-
-
+        # bias correction is only relevant for nodes with kernel op
+        kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
+        if core_config.quantization_config.weights_bias_correction and kernel_attr is not None and \
+                n.is_weights_quantization_enabled(kernel_attr) and \
+                not n.final_weights_quantization_cfg.weights_second_moment_correction:
+            # If a kernel was quantized and weights bias correction is enabled in n.quantization_cfg,
+            # a bias correction term was calculated during model preparation, and is used now in the node's bias term.
+            if n.final_weights_quantization_cfg.weights_bias_correction:
+                _apply_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
     return graph
 
 
 def _apply_bias_correction_to_node(node: BaseNode,
-                                   fw_impl: FrameworkImplementation
+                                   fw_impl: FrameworkImplementation,
+                                   qc: QuantizationConfig):
     """
     Set new bias to node using the bias-correction term that is stored in the
     final weights quantization configuration.
@@ -67,5 +78,7 @@ def _apply_bias_correction_to_node(node: BaseNode,
     node.set_weights_by_keys(fw_impl.constants.BIAS, - correction)
     node.framework_attr[fw_impl.constants.USE_BIAS] = True  # Mark the use_bias attribute of the node.
     node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
-                                                        WeightsAttrQuantizationConfig(
-
+                                                        WeightsAttrQuantizationConfig(
+                                                            qc,
+                                                            AttributeQuantizationConfig(
+                                                                enable_weights_quantization=False)))
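The new core_config parameter threads the user's quantization settings down to the node loop, which now gates on weights_bias_correction and skips nodes already handled by second-moment correction. The arithmetic behind the stored correction term is worth seeing in isolation. Below is a minimal NumPy sketch of the idea (names, shapes, and the symmetric 8-bit quantizer are illustrative assumptions, not MCT's API): quantizing the kernel shifts the layer's expected output by (Wq - W) @ E[x], and folding that shift into the bias cancels it on average.

    # Minimal NumPy sketch of bias correction (illustrative; not MCT's API).
    import numpy as np

    rng = np.random.default_rng(0)
    W = rng.normal(size=(8, 16))   # float kernel: (out_channels, in_channels)
    b = rng.normal(size=8)         # float bias
    mu_x = rng.normal(size=16)     # per-channel input mean from calibration stats

    # Symmetric 8-bit quantization of the kernel.
    t = np.abs(W).max()
    Wq = np.round(W / t * 127) / 127 * t

    # Quantization shifts the expected output by (Wq - W) @ E[x];
    # subtracting that shift from the bias cancels it on average.
    correction = (Wq - W) @ mu_x
    b_corrected = b - correction

    print(np.allclose(Wq @ mu_x + b_corrected, W @ mu_x + b))  # True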
model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py
CHANGED

@@ -24,7 +24,7 @@ from model_compression_toolkit.core.common.model_builder_mode import ModelBuilde
 from model_compression_toolkit.core.common.model_collector import ModelCollector
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
-    import
+    import get_activations_qparams
 from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
 from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
@@ -32,6 +32,7 @@ from model_compression_toolkit.core.common.substitutions.apply_substitutions imp
 def _collect_and_assign_act_threshold(graph: Graph,
                                       representative_data_gen: Callable,
                                       core_config: CoreConfig,
+                                      fw_info: FrameworkInfo,
                                       fw_impl: FrameworkImplementation):
     """
     Collect statistics after second moment correction and assign new thresholds to activations.
@@ -40,32 +41,36 @@
        representative_data_gen (Callable): Dataset used for calibration.
        core_config (CoreConfig): Configuration object containing parameters of how the model should be
        quantized, including mixed precision parameters.
+       fw_info: FrameworkInfo object with information about the specific framework's model.
       fw_impl: FrameworkImplementation object with a specific framework methods implementation.
     """
 
     mi = ModelCollector(graph,
                         fw_impl,
-
+                        fw_info,
+                        core_config.quantization_config)  # Mark points for statistics collection
 
     for _data in tqdm(representative_data_gen()):
         mi.infer(_data)
 
-    for n in graph.nodes:
+    for n in list(graph.nodes):
         if n.is_activation_quantization_enabled():
-            activation_params =
-
-
+            activation_params = get_activations_qparams(
+                activation_quant_cfg=n.final_activation_quantization_cfg,
+                nodes_prior_info=n.prior_info,
+                out_stats_container=graph.get_out_stats_collector(n))
             n.final_activation_quantization_cfg.set_activation_quantization_param(activation_params)
 
 
 def quantized_model_builder_for_second_moment_correction(graph: common.Graph,
+                                                         fw_info: FrameworkInfo,
                                                          fw_impl: Any):
     """
     Build a framework model from a graph for second moment correction.
 
     Args:
-        graph: Graph to build from.
+        graph: Graph to build the from.
+        fw_info: FrameworkInfo object with information about the specific framework's model.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -74,13 +79,15 @@
     quantized_tg = quantize_graph_weights(graph)
 
     quantized_model, user_info = fw_impl.model_builder(quantized_tg,
-                                                       mode=ModelBuilderMode.FLOAT
+                                                       mode=ModelBuilderMode.FLOAT,
+                                                       fw_info=fw_info)
     return quantized_model
 
 
 def apply_second_moment_correction_to_graph(graph: Graph,
                                             representative_data_gen: Callable,
                                             core_config: CoreConfig,
+                                            fw_info: FrameworkInfo,
                                             fw_impl: FrameworkImplementation) -> Graph:
     """
     Apply second moment correction on graph.
@@ -89,14 +96,15 @@
        representative_data_gen (Callable): Dataset used for calibration.
        core_config (CoreConfig): Configuration object containing parameters of how the model should be
        quantized, including mixed precision parameters.
+       fw_info: FrameworkInfo object with information about the specific framework's model.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
        Graph after second moment correction.
     """
-    semi_quantized_model = quantized_model_builder_for_second_moment_correction(graph, fw_impl)
+    semi_quantized_model = quantized_model_builder_for_second_moment_correction(graph, fw_info, fw_impl)
    fw_impl.apply_second_moment_correction(semi_quantized_model, core_config, representative_data_gen, graph)
    graph = substitute(graph, fw_impl.get_substitutions_after_second_moment_correction(core_config.quantization_config))
-    _collect_and_assign_act_threshold(graph, representative_data_gen, core_config, fw_impl)
+    _collect_and_assign_act_threshold(graph, representative_data_gen, core_config, fw_info, fw_impl)
 
    return graph
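_collect_and_assign_act_threshold now receives fw_info and re-collects statistics through ModelCollector before recomputing each activation's quantization parameters via get_activations_qparams. Conceptually the recalibration loop looks like the sketch below, with a stand-in model and a simple power-of-two threshold rule; MCT's actual parameter search is more elaborate.

    # Sketch of the recalibration step: run calibration data through the
    # (weight-quantized) model, track activation statistics, then derive a
    # fresh activation threshold. All names here are assumptions.
    import numpy as np

    def collect_max_abs(model_fn, data_batches):
        # Largest absolute activation value seen during calibration.
        max_abs = 0.0
        for x in data_batches:
            max_abs = max(max_abs, float(np.abs(model_fn(x)).max()))
        return max_abs

    def power_of_two_threshold(max_abs):
        # Smallest power of two covering the observed range.
        return 2.0 ** np.ceil(np.log2(max_abs))

    batches = [np.random.default_rng(i).normal(size=32) for i in range(4)]
    stat = collect_max_abs(lambda x: 3.0 * x, batches)
    print(power_of_two_threshold(stat))  # e.g. 16.0 for stats in (8, 16]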
model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py
CHANGED

@@ -18,7 +18,7 @@ from typing import Any, Callable
 from model_compression_toolkit.core import QuantizationConfig
 from model_compression_toolkit.core.common import BaseNode, Graph
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
-from model_compression_toolkit.core.common.
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 
 
 def get_previous_node_with_activation_quantization(linear_node: BaseNode,
@@ -64,11 +64,11 @@ def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray:
 
 def compute_activation_bias_correction(graph: Graph,
                                        quant_config: QuantizationConfig,
+                                       fw_info: FrameworkInfo,
                                        fw_impl: FrameworkImplementation,
                                        linear_node: BaseNode,
                                        prev_node: BaseNode,
-                                       kernel_size: str
-                                       get_activation_quantization_fn_factory: Callable) -> Graph:
+                                       kernel_size: str) -> Graph:
     """
     Compute the activation bias correction term, and store it in the final activation
     quantization configuration.
@@ -76,11 +76,11 @@ def compute_activation_bias_correction(graph: Graph,
     Args:
         graph: Graph with nodes to compute the activation bias correction for each node's final activation quantization configuration.
         quant_config: QuantizationConfig of how the model should be quantized.
+        fw_info: Framework info like lists of nodes their kernel should quantized.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
        linear_node: Node to compute the activation bias correction for.
        prev_node: Node to compute the activation error caused by his activation quantization.
        kernel_size: The framework specific attribute name of the convolution layer's kernel size.
-        get_activation_quantization_fn_factory: activation quantization functions factory.
 
     Returns:
        Graph with activation bias correction term for each node.
@@ -107,9 +107,7 @@ def compute_activation_bias_correction(graph: Graph,
     float_centers = calculate_bin_centers(float_bins)
 
     # Quantize the bin edges and calculate the centers of the quantized bins
-
-                                                 get_activation_quantization_fn_factory)
-    quant_bins = activation_quantizer(fw_impl.to_tensor(float_bins))
+    quant_bins = prev_node_act_quant_cfg.quantize_node_output(fw_impl.to_tensor(float_bins))
     quant_bins = fw_impl.to_numpy(quant_bins)
     quant_centers = calculate_bin_centers(quant_bins)
 
@@ -129,18 +127,19 @@ def compute_activation_bias_correction(graph: Graph,
     if normalized_bias < quant_config.activation_bias_correction_threshold:
         return graph
 
-    kernel = linear_node.get_weights_by_keys(linear_node.
+    kernel = linear_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(linear_node.type)[0])
 
     # Compute the activation bias correction by applying the quantization error to the kernel, resulting in an output
     # size matching the number of output channels.
     if kernel is not None:
 
         # Get the axes that are not the output channel.
+        output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type)
         axis_not_output_channel = list(range(len(kernel.shape)))
-        axis_not_output_channel.remove(
+        axis_not_output_channel.remove(output_channel_index)
 
         # Special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters.
-        if
+        if output_channel_index == input_channel_index:
             axis_not_output_channel.remove(3)  # 3 is the depth multiplier index.
 
         activation_bias_correction_term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel))
@@ -151,20 +150,21 @@ def compute_activation_bias_correction(graph: Graph,
 
 def compute_activation_bias_correction_of_graph(graph: Graph,
                                                 quant_config: QuantizationConfig,
+                                                fw_info: FrameworkInfo,
                                                 fw_impl: FrameworkImplementation,
                                                 activation_bias_correction_node_matchers: Callable,
-                                                kernel_size: str
-                                                get_activation_quantization_fn_factory: Callable) -> Graph:
+                                                kernel_size: str) -> Graph:
     """
     Compute the activation bias correction term for the graph.
 
     Args:
        graph: Graph with nodes to compute the activation bias correction.
        quant_config: QuantizationConfig of how the model should be quantized.
+       fw_info: Framework info like lists of nodes their kernel should quantized.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
        activation_bias_correction_node_matchers: Function to match the layers for activation bias correction.
        kernel_size: The framework specific attribute name of the convolution layer's kernel size.
-
+
 
     Returns:
        Graph with activation bias correction term for each relevant node.
@@ -177,9 +177,9 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
     if prev_node is not None:
         graph = compute_activation_bias_correction(graph=graph,
                                                    quant_config=quant_config,
+                                                   fw_info=fw_info,
                                                    fw_impl=fw_impl,
                                                    linear_node=n,
                                                    prev_node=prev_node,
-                                                   kernel_size=kernel_size
-                                                   get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
+                                                   kernel_size=kernel_size)
     return graph
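The hunks above replace the quantizer-factory plumbing with a direct call to prev_node_act_quant_cfg.quantize_node_output and pull the kernel axes from fw_info. The underlying computation compares histogram-bin centers before and after activation quantization and pushes the mean shift through the kernel. A self-contained NumPy sketch, where the uniform 8-bit quantizer and the dense-kernel shape are assumptions:

    # Illustrative activation-bias-correction computation (not MCT's API).
    import numpy as np

    def calculate_bin_centers(bin_edges):
        return (bin_edges[:-1] + bin_edges[1:]) / 2.0

    rng = np.random.default_rng(1)
    acts = np.abs(rng.normal(size=10_000))           # previous node's activations
    counts, float_bins = np.histogram(acts, bins=64)

    # Quantize the bin edges with a simple 8-bit uniform quantizer (a stand-in
    # for the previous node's activation quantizer).
    t = float_bins.max()
    quant_bins = np.round(float_bins / t * 255) / 255 * t

    float_centers = calculate_bin_centers(float_bins)
    quant_centers = calculate_bin_centers(quant_bins)

    # Probability-weighted mean shift introduced by activation quantization.
    p = counts / counts.sum()
    mean_diff = np.sum(p * (quant_centers - float_centers))

    # Fold the shift through a dense kernel: sum over all axes except out-channels.
    kernel = rng.normal(size=(16, 8))                # (in_channels, out_channels)
    bias_correction = mean_diff * kernel.sum(axis=0) # one term per output channel
    print(bias_correction.shape)                     # (8,)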
model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py
CHANGED

@@ -18,6 +18,7 @@ from typing import Any
 import numpy as np
 
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common import BaseNode, Graph
 from model_compression_toolkit.core.common.quantization.quantize_node import get_quantized_weights_attr_by_qc
 from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
@@ -25,6 +26,7 @@ from model_compression_toolkit.logger import Logger
 
 
 def compute_bias_correction_of_graph(graph: Graph,
+                                     fw_info: FrameworkInfo,
                                      fw_impl: FrameworkImplementation) -> Graph:
     """
     For each node in a graph, and for each candidate weights quantization configuration,
@@ -33,6 +35,7 @@ def compute_bias_correction_of_graph(graph: Graph,
     Args:
         graph: Graph with nodes to compute the bias correction for
         each node's weights quantization configuration candidates.
+        fw_info: Framework info like lists of nodes their kernel should quantized.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -43,14 +46,25 @@ def compute_bias_correction_of_graph(graph: Graph,
     for n in graph.nodes:
         # Bias correction is computed based on the quantized kernel, so we need to get the specific kernel attribute
         # name out of all the weights attributes of the node.
-        if
-
-
+        if fw_info.is_kernel_op(n.type):
+            kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
+            if n.is_weights_quantization_enabled(kernel_attr):
+                # Bias correction is not applied to layers with constant inputs.
+                if n.has_positional_weights:
+                    for candidate_qc in n.candidates_quantization_cfg:
+                        candidate_qc.weights_quantization_cfg.weights_bias_correction = False
+                else:
+                    _compute_bias_correction_per_candidate_qc(n,
+                                                              kernel_attr,
+                                                              fw_info,
+                                                              graph.get_in_stats_collector(n),
+                                                              fw_impl=fw_impl)
     return graph
 
 
 def _compute_bias_correction_per_candidate_qc(node: BaseNode,
                                               kernel_attr: str,
+                                              fw_info: FrameworkInfo,
                                               node_in_stats_collector: BaseStatsCollector,
                                               fw_impl: FrameworkImplementation):
     """
@@ -60,13 +74,15 @@ def _compute_bias_correction_per_candidate_qc(node: BaseNode,
     Args:
         node: Node to compute the bias correction for its different candidates.
         kernel_attr: The name of the kernel attribute of the node.
+        fw_info: Framework info like lists of nodes their kernel should quantized.
        node_in_stats_collector: Statistics collector of the node for the mean per-channel.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     """
 
     for candidate_qc in node.candidates_quantization_cfg:
-        if
+        if candidate_qc.weights_quantization_cfg.weights_bias_correction and not \
+                candidate_qc.weights_quantization_cfg.weights_second_moment_correction:
 
             quantized_kernel, io_channels_axes = get_quantized_weights_attr_by_qc(kernel_attr,
                                                                                   node,
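For each candidate configuration that passes the gate above, the correction is derived from the quantized kernel and the collected per-input-channel means. A hedged sketch of that per-output-channel computation for a conv kernel (the HWIO layout and the quantizer are assumptions; MCT reads the real channel axes from fw_info.kernel_channels_mapping):

    # Per-channel bias correction for a conv kernel (illustrative shapes).
    import numpy as np

    rng = np.random.default_rng(2)
    W = rng.normal(size=(3, 3, 16, 8))    # conv kernel, HWIO layout
    mu_x = rng.normal(size=16)            # per-input-channel mean from stats collector

    t = np.abs(W).max()
    Wq = np.round(W / t * 127) / 127 * t  # quantized kernel for this candidate

    # Contract the spatial dims of the quantization error, then the input
    # channels against the per-channel mean, leaving one term per out-channel.
    eps = Wq - W
    correction = np.einsum('io,i->o', eps.sum(axis=(0, 1)), mu_x)
    print(correction.shape)               # (8,)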
model_compression_toolkit/core/common/statistics_correction/statistics_correction.py
CHANGED

@@ -32,6 +32,7 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
 
 def statistics_correction_runner(transformed_graph: Graph,
                                  core_config: CoreConfig,
+                                 fw_info: FrameworkInfo,
                                  fw_impl: FrameworkImplementation,
                                  tb_w: TensorboardWriter = None, ) -> Graph:
     """
@@ -40,6 +41,7 @@ def statistics_correction_runner(transformed_graph: Graph,
        transformed_graph: Graph to add statistics correction.
        core_config (CoreConfig): Configuration object containing parameters of how the model should be
        quantized, including mixed precision parameters.
+       fw_info: FrameworkInfo object with information about the specific framework's model.
       fw_impl: FrameworkImplementation object with a specific framework methods implementation.
       tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
 
@@ -56,9 +58,9 @@ def statistics_correction_runner(transformed_graph: Graph,
     ########################################################
     # Compute bias correction to nodes' config candidates
     ########################################################
-
-
-
+    tg_with_bias = compute_bias_correction_of_graph(tg_with_bias,
+                                                    fw_info,
+                                                    fw_impl)
 
     if tb_w is not None:
         tb_w.add_graph(tg_with_bias, 'statistics_computation')
@@ -69,6 +71,7 @@ def statistics_correction_runner(transformed_graph: Graph,
 def apply_statistics_correction(transformed_graph: Graph,
                                 representative_data_gen: Callable,
                                 core_config: CoreConfig,
+                                fw_info: FrameworkInfo,
                                 fw_impl: FrameworkImplementation,
                                 tb_w: TensorboardWriter = None, ) -> Graph:
     """
@@ -78,6 +81,7 @@ def apply_statistics_correction(transformed_graph: Graph,
        representative_data_gen (Callable): Dataset used for calibration.
        core_config (CoreConfig): Configuration object containing parameters of how the model should be
        quantized, including mixed precision parameters.
+       fw_info: FrameworkInfo object with information about the specific framework's model.
       fw_impl: FrameworkImplementation object with a specific framework methods implementation.
       tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
 
@@ -90,13 +94,14 @@ def apply_statistics_correction(transformed_graph: Graph,
     #############################################
     if core_config.quantization_config.weights_second_moment_correction:
         transformed_graph = apply_second_moment_correction_to_graph(transformed_graph, representative_data_gen,
-                                                                    core_config, fw_impl)
+                                                                    core_config, fw_info, fw_impl)
 
     #############################################
     # Apply Bias Correction
     #############################################
     if core_config.quantization_config.weights_bias_correction:
         transformed_graph = apply_bias_correction_to_graph(transformed_graph,
+                                                           core_config,
                                                            fw_impl=fw_impl)
     if tb_w is not None:
         tb_w.add_graph(transformed_graph, 'after_statistics_correction')
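The runner fixes the order of the two corrections: second-moment correction, which rebuilds activation statistics, must run before bias correction is applied. Reduced to a toy sketch with placeholder transforms (the config flag names mirror the diff; everything else is hypothetical):

    # Orchestration order of apply_statistics_correction, as a sketch.
    from dataclasses import dataclass

    @dataclass
    class QuantCfg:
        weights_second_moment_correction: bool = True
        weights_bias_correction: bool = True

    def apply_statistics_correction(graph, cfg: QuantCfg):
        if cfg.weights_second_moment_correction:
            graph = graph + ['second_moment_correction']  # placeholder transform
        if cfg.weights_bias_correction:
            graph = graph + ['bias_correction']
        return graph

    print(apply_statistics_correction([], QuantCfg()))
    # ['second_moment_correction', 'bias_correction']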
model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py
CHANGED

@@ -20,6 +20,7 @@ from typing import Callable
 import numpy as np
 
 from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
     ActivationQuantizationMode
@@ -83,28 +84,30 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
         # If the linear operator is part of a reused group (it is the "base" node, or a reused node),
         # we should skip the substitution.
         if source_node.is_reused():
+            for qc in source_node.candidates_quantization_cfg:
+                qc.weights_quantization_cfg.weights_second_moment_correction = False
             return graph
 
         # We apply only on nodes with folded BatchNormalization.
         if source_node.prior_info.std_output is None or source_node.prior_info.mean_output is None:
+            for qc in source_node.candidates_quantization_cfg:
+                qc.weights_quantization_cfg.weights_second_moment_correction = False
             return graph
 
         # This feature disabled for models with weights quantization method of Power of 2
         for qc in source_node.candidates_quantization_cfg:
             # this feature is relevant only for layers with kernel op
-
+            kernel_attr = graph.fw_info.get_kernel_op_attributes(source_node.type)
+            if kernel_attr is None:
                 Logger.error(f"Can't preform BatchNorm reconstruction on a node {source_node.name} without a kernel op.")
-            if (qc.weights_quantization_cfg.get_attr_config(
+            if (qc.weights_quantization_cfg.get_attr_config(kernel_attr[0]).weights_quantization_method
                     == QuantizationMethod.POWER_OF_TWO):
                 Logger.warning("Second moment statistics correction feature disabled for models with weights "
                                "quantization method of Power of 2")
+                for qc_inner in source_node.candidates_quantization_cfg:
+                    qc_inner.weights_quantization_cfg.weights_second_moment_correction = False
                 return graph
 
-        # turn on second moment correction flag
-        def set_second_moment_correction(qc):
-            qc.weights_quantization_cfg.weights_second_moment_correction = True
-        source_node.quantization_cfg.update_all(set_second_moment_correction)
-
         eps = self.epsilon_val
 
         original_gamma = source_node.prior_info.std_output
@@ -122,7 +125,7 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
 
         bn_node.prior_info = copy.deepcopy(source_node.prior_info)
 
-        bn_node.
+        bn_node.candidates_quantization_cfg = copy.deepcopy(source_node.candidates_quantization_cfg)
 
         for qc in bn_node.candidates_quantization_cfg:
             qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
@@ -137,6 +140,7 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
             # reconstructed node BN attributes need to be quantized and how.
             qc.weights_quantization_cfg.set_attr_config(attr,
                                                         WeightsAttrQuantizationConfig(
+                                                            QuantizationConfig(),
                                                             AttributeQuantizationConfig(
                                                                 enable_weights_quantization=False)))
 
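Numerically, "reconstructing" the folded BatchNorm means building a BN layer from the node's prior output statistics so that it computes the identity; second-moment correction can then re-estimate its statistics. A NumPy sketch of that identity construction (the gamma = std, beta = mean choice follows the original_gamma line visible in the diff; eps and shapes are assumptions):

    # Identity BatchNorm reconstructed from prior output statistics (sketch).
    import numpy as np

    rng = np.random.default_rng(3)
    mean_output = rng.normal(size=8)             # source_node.prior_info.mean_output
    std_output = np.abs(rng.normal(size=8)) + 1  # source_node.prior_info.std_output
    eps = 1e-5

    # Normalize with (mean, std^2 - eps), rescale with gamma = std, beta = mean.
    gamma, beta = std_output, mean_output
    moving_mean, moving_var = mean_output, std_output ** 2 - eps

    x = rng.normal(loc=mean_output, scale=std_output, size=(32, 8))
    bn = gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta
    print(np.allclose(bn, x))  # True: the reconstructed BN is an identity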
model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py
CHANGED

@@ -157,7 +157,7 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
         graph.remove_node(bn_node)
         graph.remove_node(source_node)
 
-        self._calc_weights_quantization_params(conv_bn, weights_scale)
+        self._calc_weights_quantization_params(conv_bn, weights_scale, graph.fw_info)
 
         assert num_nodes_before_substitution - len(graph.nodes) == 1
         assert num_edges_before_substitution - len(graph.edges) == 1
@@ -165,15 +165,18 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
 
     def _calc_weights_quantization_params(self,
                                           conv_bn: BaseNode,
-                                          weights_scale: np.ndarray
+                                          weights_scale: np.ndarray,
+                                          fw_info):
         """
         Update node weights quantization params.
         Args:
             conv_bn: Convolution node to update the weights quantization params.
            weights_scale: Weight scale factor in which to multiply the conv node's weight.
+            fw_info: FrameworkInfo object with information about the specific framework's model
         """
         # Conv layer is ensured to have a kernel attribute
-
+        kernel_attr = fw_info.get_kernel_op_attributes(conv_bn.type)[0]
+        conv_bn_kernel_cfg = conv_bn.final_weights_quantization_cfg.get_attr_config(kernel_attr)
         # In case of SYMMETRIC weight quantization method, we update the threshold by weights_scale
         if conv_bn_kernel_cfg.weights_quantization_method == QuantizationMethod.SYMMETRIC:
             original_threshold = conv_bn_kernel_cfg.weights_quantization_params[THRESHOLD]
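When the refused conv kernel is multiplied by weights_scale, a SYMMETRIC threshold stored in the final weights quantization config must scale by the same factor, or the quantization grid no longer covers the weights. A sketch (scalar scale for brevity; in the code above weights_scale is an array):

    # Why the SYMMETRIC threshold scales with the kernel (sketch).
    import numpy as np

    w = np.array([-0.8, 0.3, 0.6])
    weights_scale = 2.5
    threshold = np.abs(w).max()

    w_refused = w * weights_scale
    new_threshold = threshold * weights_scale  # keeps |w| / threshold invariant

    print(np.abs(w_refused).max() <= new_threshold)  # True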
model_compression_toolkit/core/common/substitutions/scale_equalization.py
CHANGED

@@ -20,6 +20,8 @@ import scipy
 
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import Graph, BaseNode
+from model_compression_toolkit.defaultdict import DefaultDict
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 
 
@@ -75,6 +77,7 @@ def fixed_second_moment_after_relu(mu: np.ndarray,
 
 def scale_reshaping(scale: np.ndarray,
                     op2d: common.BaseNode,
+                    kernel_channel_mapping: DefaultDict,
                     kernel_str: str,
                     in_channels: bool = True) -> np.ndarray:
     """
@@ -86,6 +89,7 @@ def scale_reshaping(scale: np.ndarray,
     Args:
         scale: Scale factor to scale the kernel channels by.
         op2d: Node to scale its kernel.
+        kernel_channel_mapping: Mapping from a layer to a tuple of indices of its output/input kernel channels.
        kernel_str: The framework specific attribute name of the convolution layer's weight/kernel.
        in_channels: Kernel's index of input channels.
 
@@ -95,11 +99,12 @@ def scale_reshaping(scale: np.ndarray,
 
     op_ndims = op2d.get_weights_by_keys(kernel_str).ndim
     reshape_target = np.ones(op_ndims, dtype=np.int32)
-    reshape_target[op2d.
+    reshape_target[kernel_channel_mapping.get(op2d.type)[int(in_channels)]] = -1
     return np.reshape(scale, reshape_target)
 
 
-def update_linear_nodes(
+def update_linear_nodes(fw_info: FrameworkInfo,
+                        first_op2d_node: BaseNode,
                         second_op2d_node: BaseNode,
                         scale_factor: np.ndarray,
                         kernel_str: str,
@@ -111,6 +116,7 @@ def update_linear_nodes(first_op2d_node: BaseNode,
     The scale factor contain a scale value per-channel.
 
     Args:
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
        groups of layers by how they should be quantized, etc.)
        first_op2d_node: Node to multiply its kernel by the scale factor.
        second_op2d_node: Node to divide its kernel by the scale factor.
@@ -119,12 +125,15 @@ def update_linear_nodes(first_op2d_node: BaseNode,
     kernel_str: The framework specific attribute name of the convolution layer's weight/kernel.
 
     """
+
     w2_fixed = second_op2d_node.get_weights_by_keys(kernel_str) / scale_reshaping(scale_factor,
                                                                                   second_op2d_node,
+                                                                                  fw_info.kernel_channels_mapping,
                                                                                   kernel_str)
 
     w1_fixed = first_op2d_node.get_weights_by_keys(kernel_str) * scale_reshaping(scale_factor,
                                                                                  first_op2d_node,
+                                                                                 fw_info.kernel_channels_mapping,
                                                                                  kernel_str,
                                                                                  in_channels=False)
 
@@ -159,7 +168,8 @@ def calculate_scale_correction(first_op2d_node: BaseNode) -> tuple:
     return scale_factor
 
 
-def scale_equalization_lnl(
+def scale_equalization_lnl(fw_info: FrameworkInfo,
+                           first_op2d_node: BaseNode,
                            second_op2d_node: BaseNode,
                            kernel_str: str,
                            bias_str: str):
@@ -169,6 +179,7 @@ def scale_equalization_lnl(first_op2d_node: BaseNode,
     follows the activation node to get the same expected output without the scaling.
 
     Args:
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
        groups of layers by how they should be quantized, etc.)
        first_op2d_node: Node to multiply its kernel by the scale factor.
        second_op2d_node: Node to divide its kernel by the scale factor.
@@ -178,7 +189,8 @@ def scale_equalization_lnl(first_op2d_node: BaseNode,
     """
     scale_factor = calculate_scale_correction(first_op2d_node)
 
-    update_linear_nodes(
+    update_linear_nodes(fw_info,
+                        first_op2d_node,
                         second_op2d_node,
                         scale_factor,
                         kernel_str,
@@ -194,6 +206,7 @@ class BaseScaleEqualization(common.BaseSubstitution):
 
     def __init__(self,
                  quant_config: QuantizationConfig,
+                 fw_info: FrameworkInfo,
                  matcher_instance,
                  kernel_str: str,
                  bias_str: str):
@@ -201,11 +214,13 @@ class BaseScaleEqualization(common.BaseSubstitution):
         Initialize a ScaleEqualization object.
         Args:
             quant_config: QuantizationConfig containing parameters of how the model should be quantized.
+            fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
            groups of layers by how they should be quantized, etc.)
            matcher_instance: Per substitution matcher instance of type WalkMatcher
         """
 
         self.quant_config = quant_config
+        self.fw_info = fw_info
         self.kernel_str = kernel_str
         self.bias_str = bias_str
         super().__init__(matcher_instance=matcher_instance)
@@ -228,7 +243,8 @@ class BaseScaleEqualization(common.BaseSubstitution):
         act_node = nodes_list[1]
         second_op2d_node = nodes_list[-1]
         if first_op2d_node.prior_info.std_output is not None and act_node.is_activation_quantization_enabled():
-            scale_equalization_lnl(
+            scale_equalization_lnl(self.fw_info,
+                                   first_op2d_node,
                                    second_op2d_node,
                                    self.kernel_str,
                                    self.bias_str)
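The reason update_linear_nodes can rescale two back-to-back linear layers without changing the network: the activation between them (e.g., ReLU) is positively homogeneous, so a positive per-channel factor applied to the first layer's output channels and divided out of the second layer's input channels cancels exactly. A NumPy sketch with dense layers (the scale factor here is random for brevity; calculate_scale_correction derives it from prior output statistics):

    # Scale-equalization invariance check (illustrative, dense layers).
    import numpy as np

    rng = np.random.default_rng(4)
    W1, b1 = rng.normal(size=(8, 16)), rng.normal(size=8)
    W2 = rng.normal(size=(4, 8))
    x = rng.normal(size=16)

    scale = rng.uniform(0.5, 2.0, size=8)        # positive per-channel factors

    def net(W1, b1, W2):
        return W2 @ np.maximum(W1 @ x + b1, 0)   # linear -> ReLU -> linear

    W1_eq = W1 * scale[:, None]                  # scale output channels of layer 1
    b1_eq = b1 * scale                           # bias scales with its channel
    W2_eq = W2 / scale[None, :]                  # undo on layer 2 input channels

    print(np.allclose(net(W1, b1, W2), net(W1_eq, b1_eq, W2_eq)))  # True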