mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py
CHANGED

```diff
@@ -18,8 +18,6 @@ import numpy as np
 
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
-from model_compression_toolkit.core.common.quantization.quantization_fn_selection import (get_activation_quantization_fn,
-                                                                                          get_weights_quantization_fn)
 
 
 def verify_candidates_descending_order(node_q_cfg: List[CandidateNodeQuantizationConfig],
@@ -79,21 +77,20 @@ def init_quantized_weights(node_q_cfg: List[CandidateNodeQuantizationConfig],
     quantized_weights = []
     for qc in node_q_cfg:
         qc_weights_attr = qc.weights_quantization_cfg.get_attr_config(kernel_attr)
-
-
-
-
-
-
-
+        q_weight = qc_weights_attr.weights_quantization_fn(float_weights,
+                                                           qc_weights_attr.weights_n_bits,
+                                                           True,
+                                                           qc_weights_attr.weights_quantization_params,
+                                                           qc_weights_attr.weights_per_channel_threshold,
+                                                           qc_weights_attr.weights_channels_axis[
+                                                               0])  # output channel axis
 
         quantized_weights.append(fw_tensor_convert_func(q_weight))
 
     return quantized_weights
 
 
-def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig],
-                               get_activation_quantization_fn_factory: Callable) -> List:
+def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]) -> List:
     """
     Builds a list of quantizers for each of the bitwidth candidates for activation quantization,
     to be stored and used during MP search.
@@ -101,7 +98,6 @@ def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]
     Args:
         node_q_cfg: Quantization configuration candidates of the node that generated the layer that will
             use this quantizer.
-        get_activation_quantization_fn_factory: activation quantization functions factory.
 
     Returns: a list of activation quantizers - for each bitwidth and layer's attribute to be quantized.
     """
@@ -109,7 +105,6 @@ def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]
     activation_quantizers = []
     for index, qc in enumerate(node_q_cfg):
         q_activation = node_q_cfg[index].activation_quantization_cfg
-
-        activation_quantizers.append(quantizer)
+        activation_quantizers.append(q_activation.quantize_node_output)
 
     return activation_quantizers
```
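Note: the net effect of these hunks is that quantizer construction no longer goes through the removed `quantization_fn_selection` factories; each candidate's weights-attribute config now carries its own `weights_quantization_fn`, and the activation quantizer is simply the candidate's `quantize_node_output` callable. A minimal sketch of the new pattern, using stub stand-ins rather than MCT's real config classes:

```python
# Minimal sketch (stub classes, not MCT's) of the candidate-carried quantizer pattern:
# each weights-attribute config holds its own quantization callable and bit-width.
from dataclasses import dataclass
from typing import Callable, List

@dataclass
class StubWeightsAttrCfg:
    weights_n_bits: int
    weights_quantization_fn: Callable  # replaces the removed get_weights_quantization_fn lookup

def toy_symmetric_quant(weights: List[float], n_bits: int) -> List[float]:
    # Toy symmetric quantizer, for demonstration only.
    scale = max(abs(w) for w in weights) / (2 ** (n_bits - 1))
    return [round(w / scale) * scale for w in weights]

candidates = [StubWeightsAttrCfg(8, toy_symmetric_quant), StubWeightsAttrCfg(4, toy_symmetric_quant)]
float_weights = [0.1, -0.7, 0.5]
quantized_per_candidate = [c.weights_quantization_fn(float_weights, c.weights_n_bits)
                           for c in candidates]
print(quantized_per_candidate)  # one quantized copy per bit-width candidate
```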
model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py
CHANGED

```diff
@@ -12,12 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+import numpy as np
+
+from model_compression_toolkit.core import ResourceUtilization, FrameworkInfo
 from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
+    FrameworkQuantizationCapabilities
 
 
 def filter_candidates_for_mixed_precision(graph: Graph,
-                                          target_resource_utilization: ResourceUtilization
+                                          target_resource_utilization: ResourceUtilization,
+                                          fw_info: FrameworkInfo,
+                                          fqc: FrameworkQuantizationCapabilities):
     """
     Filters out candidates in case of mixed precision search for only weights or activation compression.
     For instance, if running only weights compression - filters out candidates of activation configurable nodes
@@ -29,6 +35,9 @@ def filter_candidates_for_mixed_precision(graph: Graph,
     Args:
         graph: A graph representation of the model to be quantized.
         target_resource_utilization: The resource utilization of the target device.
+        fw_info: fw_info: Information needed for quantization about the specific framework.
+        fqc: FrameworkQuantizationCapabilities object that describes the desired inference target platform.
+
     """
 
     tru = target_resource_utilization
@@ -40,21 +49,21 @@
         # filter out candidates activation only configurable node
         activation_configurable_nodes = [n for n in graph.get_activation_configurable_nodes()]
         for n in activation_configurable_nodes:
-            base_cfg_nbits = n.
-
+            base_cfg_nbits = n.get_qco(fqc).base_config.activation_n_bits
+            filtered_conf = [c for c in n.candidates_quantization_cfg if
                              c.activation_quantization_cfg.enable_activation_quantization and
                              c.activation_quantization_cfg.activation_n_bits == base_cfg_nbits]
 
-            n.
+            n.candidates_quantization_cfg = filtered_conf
 
     elif tru.activation_restricted() and not tru.weight_restricted():
         # Running mixed precision for activation compression only -
         # filter out candidates weights only configurable node
-        weight_configurable_nodes = [n for n in graph.get_weights_configurable_nodes()]
+        weight_configurable_nodes = [n for n in graph.get_weights_configurable_nodes(fw_info)]
         for n in weight_configurable_nodes:
-
-
-
-            c.weights_quantization_cfg.get_attr_config(
-            c.weights_quantization_cfg.get_attr_config(
-            n.
+            kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
+            base_cfg_nbits = n.get_qco(fqc).base_config.attr_weights_configs_mapping[kernel_attr].weights_n_bits
+            filtered_conf = [c for c in n.candidates_quantization_cfg if
+                             c.weights_quantization_cfg.get_attr_config(kernel_attr).enable_weights_quantization and
+                             c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == base_cfg_nbits]
+            n.candidates_quantization_cfg = filtered_conf
```
model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py
CHANGED

```diff
@@ -30,10 +30,11 @@ from model_compression_toolkit.core.common.quantization.node_quantization_config
 class MixedPrecisionRUHelper:
     """ Helper class for resource utilization computations for mixed precision optimization. """
 
-    def __init__(self, graph: Graph, fw_impl: FrameworkImplementation):
+    def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation):
         self.graph = graph
+        self.fw_info = fw_info
         self.fw_impl = fw_impl
-        self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl)
+        self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
 
     def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Dict[BaseNode, int]) -> Dict[RUTarget, np.ndarray]:
         """
```
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py
CHANGED

```diff
@@ -35,6 +35,7 @@ class BitWidthSearchMethod(Enum):
 
 
 def search_bit_width(graph: Graph,
+                     fw_info: FrameworkInfo,
                      fw_impl: FrameworkImplementation,
                      target_resource_utilization: ResourceUtilization,
                      mp_config: MixedPrecisionQuantizationConfig,
@@ -51,6 +52,7 @@ def search_bit_width(graph: Graph,
 
     Args:
         graph: Graph to search a MP configuration for.
+        fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
        fw_impl: FrameworkImplementation object with specific framework methods implementation.
        target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
        mp_config: Mixed-precision quantization configuration.
@@ -77,7 +79,7 @@ def search_bit_width(graph: Graph,
 
     # Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
     # even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
-    se = SensitivityEvaluation(graph, mp_config, representative_data_gen=representative_data_gen,
+    se = SensitivityEvaluation(graph, mp_config, representative_data_gen=representative_data_gen, fw_info=fw_info,
                                fw_impl=fw_impl, disable_activation_for_metric=disable_activation_for_metric,
                                hessian_info_service=hessian_info_service)
 
@@ -91,6 +93,7 @@ def search_bit_width(graph: Graph,
 
     # Search manager and LP are highly coupled, so LP search method was moved inside search manager.
     search_manager = MixedPrecisionSearchManager(graph,
+                                                 fw_info=fw_info,
                                                  fw_impl=fw_impl,
                                                  sensitivity_evaluator=se,
                                                  target_resource_utilization=target_resource_utilization,
@@ -102,6 +105,6 @@ def search_bit_width(graph: Graph,
     if mp_config.refine_mp_solution:
         nodes_bit_cfg = greedy_solution_refinement_procedure(nodes_bit_cfg, search_manager, target_resource_utilization)
 
-    topo_bit_cfg = [nodes_bit_cfg[n] for n in graph.get_configurable_sorted_nodes()]
+    topo_bit_cfg = [nodes_bit_cfg[n] for n in graph.get_configurable_sorted_nodes(fw_info)]
     assert len(topo_bit_cfg) == len(nodes_bit_cfg)
     return topo_bit_cfg
```
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py
CHANGED

```diff
@@ -53,6 +53,7 @@ class MixedPrecisionSearchManager:
 
     def __init__(self,
                  graph: Graph,
+                 fw_info: FrameworkInfo,
                  fw_impl: FrameworkImplementation,
                  sensitivity_evaluator: SensitivityEvaluation,
                  target_resource_utilization: ResourceUtilization,
@@ -61,12 +62,14 @@ class MixedPrecisionSearchManager:
 
         Args:
             graph: Graph to search for its MP configuration.
+            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
            fw_impl: FrameworkImplementation object with specific framework methods implementation.
            sensitivity_evaluator: A SensitivityEvaluation which provides a function that evaluates the sensitivity of
                a bit-width configuration for the MP model.
            target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
        """
 
+        self.fw_info = fw_info
         self.fw_impl = fw_impl
 
         self.original_graph = graph
@@ -78,12 +81,12 @@ class MixedPrecisionSearchManager:
         self.target_resource_utilization = target_resource_utilization
         self.mp_config = mp_config
 
-        self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes()
+        self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes(fw_info)
 
         self.ru_targets = target_resource_utilization.get_restricted_targets()
-        self.orig_graph_ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_impl)
+        self.orig_graph_ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_info, fw_impl)
 
-        self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config()
+        self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config(fw_info)
 
         self.config_reconstructor = None
         orig_min_config = self.min_ru_config
```
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py
CHANGED

```diff
@@ -124,9 +124,10 @@ class ResourceUtilizationCalculator:
     unexpected_qc_error = 'Custom quantization configuration is not expected for non-custom bit mode.'
     unexpected_qc_nodes_error = 'Custom quantization configuration contains unexpected node names.'
 
-    def __init__(self, graph: Graph, fw_impl: FrameworkImplementation):
+    def __init__(self, graph: Graph, fw_impl: FrameworkImplementation, fw_info: FrameworkInfo):
         self.graph = graph
         self.fw_impl = fw_impl
+        self.fw_info = fw_info
 
         # Currently we go over the full graph even if utilization won't be requested for all nodes.
         # We could fill the cache on the fly only for requested nodes, but it's probably negligible.
@@ -543,10 +544,14 @@
             self._validate_custom_qcs(w_qc, bitwidth_mode)
 
         # check if the node has kernel
-
+        kernel_attrs = self.fw_info.get_kernel_op_attributes(n.type)
+        if len(kernel_attrs) > 1:  # pragma: no cover
+            raise NotImplementedError('Multiple kernel attributes are not supported for BOPS computation.')
+        if not kernel_attrs or not kernel_attrs[0]:
             return 0
 
-
+        kernel_attr = kernel_attrs[0]
+        node_mac = self.fw_impl.get_node_mac_operations(n, self.fw_info)
         if node_mac == 0:
             return node_mac
 
@@ -554,12 +559,12 @@
         assert len(prev_nodes) == 1, f'Weights node is expected to have exactly one input, {n} has {len(prev_nodes)}'
         a_node = prev_nodes[0]
         if (target_criterion == TargetInclusionCriterion.AnyQuantized and
-                not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(
+                not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(kernel_attr))):
             return 0
 
         act_qc = self._extract_qc(a_node, act_qcs)
         a_nbits = self._get_activation_nbits(a_node, bitwidth_mode, act_qc)
-        w_nbits = self._get_weight_nbits(n,
+        w_nbits = self._get_weight_nbits(n, kernel_attr, bitwidth_mode, w_qc)
         node_bops = a_nbits * w_nbits * node_mac
         return node_bops
```
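Note: the BOPS count computed at the end of the second hunk is simply activation bits × weight bits × MACs. A worked example with illustrative numbers:

```python
# Worked example of node_bops = a_nbits * w_nbits * node_mac (values are illustrative).
a_nbits = 8            # bit-width of the activation feeding the kernel op
w_nbits = 4            # bit-width of the kernel weights
node_mac = 1_000_000   # MAC operations the node performs
node_bops = a_nbits * w_nbits * node_mac
print(node_bops)       # 32_000_000 bit-operations
```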
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py
CHANGED

```diff
@@ -15,7 +15,7 @@
 import copy
 from typing import Callable, Any
 
-from model_compression_toolkit.core import ResourceUtilization, CoreConfig, QuantizationErrorMethod
+from model_compression_toolkit.core import FrameworkInfo, ResourceUtilization, CoreConfig, QuantizationErrorMethod
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
     ResourceUtilizationCalculator, BitwidthMode, TargetInclusionCriterion
@@ -27,6 +27,7 @@ def compute_resource_utilization_data(in_model: Any,
                                       representative_data_gen: Callable,
                                       core_config: CoreConfig,
                                       fqc: FrameworkQuantizationCapabilities,
+                                      fw_info: FrameworkInfo,
                                       fw_impl: FrameworkImplementation) -> ResourceUtilization:
     """
     Compute Resource Utilization of a model with the default single precision quantization.
@@ -38,6 +39,7 @@ def compute_resource_utilization_data(in_model: Any,
         core_config: CoreConfig containing parameters of how the model should be quantized.
         fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
             the attached framework operator's information.
+        fw_info: Information needed for quantization about the specific framework.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -53,11 +55,12 @@
     transformed_graph = graph_preparation_runner(in_model,
                                                  representative_data_gen=representative_data_gen,
                                                  quantization_config=core_config.quantization_config,
+                                                 fw_info=fw_info,
                                                  fw_impl=fw_impl,
                                                  fqc=fqc,
                                                  bit_width_config=core_config.bit_width_config,
                                                  mixed_precision_enable=False,
                                                  running_gptq=False)
 
-    ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl)
+    ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info)
     return ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused, BitwidthMode.QDefaultSP)
```
model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py
CHANGED

```diff
@@ -15,7 +15,7 @@
 import numpy as np
 from typing import runtime_checkable, Protocol, Callable, Any, List, Tuple
 
-from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, MpDistanceWeighting
+from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig, MpDistanceWeighting
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
     HessianScoresGranularity
@@ -62,12 +62,15 @@ class DistanceMetricCalculator(MetricCalculator):
                  graph: Graph,
                  mp_config: MixedPrecisionQuantizationConfig,
                  representative_data_gen: Callable,
+                 fw_info: FrameworkInfo,
                  fw_impl: Any,
                  hessian_info_service: HessianInfoService = None):
         """
         Args:
             graph: Graph to search for its MP configuration.
             mp_config: MP Quantization configuration for how the graph should be quantized.
+            fw_info: FrameworkInfo object about the specific framework
+                (e.g., attributes of different layers' weights to quantize).
             fw_impl: FrameworkImplementation object with a specific framework methods implementation.
             representative_data_gen: Dataset used for getting batches for inference.
             hessian_info_service: HessianInfoService to fetch Hessian approximation information.
@@ -75,13 +78,14 @@ class DistanceMetricCalculator(MetricCalculator):
         self.graph = graph
         self.mp_config = mp_config
         self.representative_data_gen = representative_data_gen
+        self.fw_info = fw_info
         self.fw_impl = fw_impl
 
         if self.mp_config.distance_weighting_method == MpDistanceWeighting.HESSIAN:
             assert hessian_info_service is not None, ('Expected HessianInfoService object to be passed with Hessian '
                                                       'distance weighting')
 
-        self.sorted_configurable_nodes_names = graph.get_configurable_sorted_nodes_names()
+        self.sorted_configurable_nodes_names = graph.get_configurable_sorted_nodes_names(self.fw_info)
 
         # Get interest points and output points set for distance measurement and set other helper datasets
         # We define a separate set of output nodes of the model for the purpose of sensitivity computation.
@@ -392,8 +396,9 @@ class DistanceMetricCalculator(MetricCalculator):
         """
 
         return [n.node for n in graph.get_outputs()
-                if (
-
+                if (graph.fw_info.is_kernel_op(n.node.type) and
+                    n.node.is_weights_quantization_enabled(graph.fw_info.get_kernel_op_attributes(n.node.type)[0])) or
+                n.node.is_activation_quantization_enabled()]
 
     @staticmethod
     def bound_num_interest_points(sorted_ip_list: List[BaseNode], num_ip_factor: float) -> List[BaseNode]:
```
model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py
CHANGED

```diff
@@ -38,6 +38,7 @@ class SensitivityEvaluation:
                  graph: Graph,
                  mp_config: MixedPrecisionQuantizationConfig,
                  representative_data_gen: Callable,
+                 fw_info: FrameworkInfo,
                  fw_impl: Any,
                  disable_activation_for_metric: bool = False,
                  hessian_info_service: HessianInfoService = None
@@ -45,6 +46,8 @@ class SensitivityEvaluation:
         """
         Args:
             graph: Graph to search for its MP configuration.
+            fw_info: FrameworkInfo object about the specific framework
+                (e.g., attributes of different layers' weights to quantize).
             mp_config: MP Quantization configuration for how the graph should be quantized.
             representative_data_gen: Dataset used for getting batches for inference.
             fw_impl: FrameworkImplementation object with a specific framework methods implementation.
@@ -54,13 +57,14 @@ class SensitivityEvaluation:
         """
         self.mp_config = mp_config
         self.representative_data_gen = representative_data_gen
+        self.fw_info = fw_info
         self.fw_impl = fw_impl
 
         if self.mp_config.custom_metric_fn:
             self.metric_calculator = CustomMetricCalculator(graph, self.mp_config.custom_metric_fn)
         else:
             self.metric_calculator = DistanceMetricCalculator(graph, mp_config, representative_data_gen,
-                                                              fw_impl=fw_impl,
+                                                              fw_info=fw_info, fw_impl=fw_impl,
                                                               hessian_info_service=hessian_info_service)
 
         # Build a mixed-precision model which can be configured to use different bitwidth in different layers.
@@ -107,7 +111,8 @@ class SensitivityEvaluation:
 
         model_mp, _, conf_node2layers = self.fw_impl.model_builder(evaluation_graph,
                                                                    mode=ModelBuilderMode.MIXEDPRECISION,
-                                                                   append2output=outputs
+                                                                   append2output=outputs,
+                                                                   fw_info=self.fw_info)
 
         # Disable all configurable quantizers. They will be activated one at a time during sensitivity evaluation.
         for layer in itertools.chain(*conf_node2layers.values()):
```
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py
CHANGED

```diff
@@ -50,11 +50,8 @@ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
     if target_resource_utilization.bops_restricted():
         Logger.info(f'Target resource utilization constraint BOPs - Skipping MP greedy solution refinement')
         return mp_solution
-    assert search_manager.using_virtual_graph is False
 
-
-    activation_restricted = tru.activation_restricted() or tru.total_mem_restricted() or tru.bops_restricted()
-    weights_restricted = tru.weight_restricted() or tru.total_mem_restricted() or tru.bops_restricted()
+    assert search_manager.using_virtual_graph is False
 
     new_solution = mp_solution.copy()
     changed = True
@@ -65,7 +62,7 @@ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
         nodes_next_candidate = {}
 
         for node in search_manager.mp_topo_configurable_nodes:
-            if new_solution[node] ==
+            if new_solution[node] == 0:
                 # layer has max config in the given solution, nothing to optimize
                 continue
 
@@ -74,8 +71,9 @@ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
             # only weights kernel attribute is quantized with weights mixed precision
             valid_candidates = _get_valid_candidates_indices(node_candidates,
                                                              new_solution[node],
-                                                             activation_restricted,
-
+                                                             target_resource_utilization.activation_restricted(),
+                                                             target_resource_utilization.weight_restricted()
+                                                             )
 
             # Create a list of ru for the valid candidates.
             updated_ru = []
```
model_compression_toolkit/core/common/model_collector.py
CHANGED

```diff
@@ -18,7 +18,7 @@ import numpy as np
 from typing import List, Union, Tuple, Optional
 
 from networkx.algorithms.dag import topological_sort
-from model_compression_toolkit.core import QuantizationErrorMethod
+from model_compression_toolkit.core import FrameworkInfo, QuantizationErrorMethod
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -30,6 +30,7 @@ from model_compression_toolkit.core.common.collectors.statistics_collector impor
 
 
 def create_stats_collector_for_node(node: common.BaseNode,
+                                    fw_info: FrameworkInfo,
                                     quant_node_in_fln: bool) -> BaseStatsCollector:
     """
     Gets a node and a groups list and create and return a statistics collector for a node
@@ -38,7 +39,7 @@ def create_stats_collector_for_node(node: common.BaseNode,
 
     Args:
         node: Node to create its statistics collector.
-
+        fw_info: Information relevant to a specific framework about what is out channel axis (for statistics per-channel).
 
     Returns:
         Statistics collector for statistics collection for the node.
@@ -47,7 +48,7 @@
     if node.is_activation_quantization_enabled() or quant_node_in_fln:
         min_output = getattr(node.prior_info, 'min_output', None)
         max_output = getattr(node.prior_info, 'max_output', None)
-        stats_collector = common.StatsCollector(out_channel_axis=node.
+        stats_collector = common.StatsCollector(out_channel_axis=fw_info.out_channel_axis_mapping.get(node.type),
                                                 init_min_value=min_output,
                                                 init_max_value=max_output)
     else:
@@ -58,20 +59,20 @@
 
 def create_tensor2node(graph: common.Graph,
                        node: common.BaseNode,
-
+                       fw_info: common.FrameworkInfo):
     """
     Force statistic collector creation and assignment for a node.
     Args:
         graph: Graph of the node (for retrieving the current tensor).
         node: Node to create a tensor for.
-
+        fw_info: Specific framework information (for example, output channels index).
 
     """
     current_sc = graph.get_out_stats_collector(node)
     is_list_nostat_collectors = isinstance(current_sc, list) and len(
         [sc for sc in current_sc if not isinstance(sc, common.NoStatsCollector)]) == 0
     if isinstance(current_sc, common.NoStatsCollector) or current_sc is None or is_list_nostat_collectors:
-        stats_collector = common.StatsCollector(
+        stats_collector = common.StatsCollector(fw_info.out_channel_axis_mapping.get(node.type))
         graph.set_out_stats_collector_to_node(node, stats_collector)
 
 
@@ -139,6 +140,7 @@ class ModelCollector:
 
     def __init__(self, graph: Graph,
                  fw_impl: FrameworkImplementation,
+                 fw_info: FrameworkInfo,
                  hessian_info_service: HessianInfoService = None,
                  qc: common.QuantizationConfig = common.DEFAULTCONFIG):
         """
@@ -147,10 +149,12 @@
         Args:
             graph: Graph to build a model from it.
            fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+            fw_info: FrameworkInfo object with a specific framework information.
            qc: Quantization configuration containing parameters for how the graph should be quantized.
        """
 
         self.fw_impl = fw_impl
+        self.fw_info = fw_info
         self.hessian_service = hessian_info_service
         self.qc = qc
         self.model_outputs = [out.node for out in graph.get_outputs()]
@@ -158,27 +162,17 @@
         # Assign statistics collectors to nodes
         for n in graph.get_topo_sorted_nodes():
             quant_node_in_fln = n.is_fln_quantization() and graph.fusing_info.is_quantized_node_in_fln(n)
-            sc = create_stats_collector_for_node(n, quant_node_in_fln=quant_node_in_fln) # Get static collector for the node
-            if isinstance(sc, common.StatsCollector) and (sc.mc.axis is None or sc.mpcc.axis is None):
-                # Missing output channel axis info, so try to extract it from previous and next nodes output channel axis.
-                possible_output_channel_axis_set = {nn.out_channel_axis for nn in graph.get_next_nodes(n) + graph.get_prev_nodes(n)}
-                # Filter out None values.
-                possible_output_channel_axis_list = list(filter(lambda x: x is not None, possible_output_channel_axis_set))
-                if len(possible_output_channel_axis_list) > 0:
-                    if len(possible_output_channel_axis_list) > 1:
-                        Logger.warning(f'Ambiguous input channel data from next nodes for {n.name}.')
-                    sc.mc.axis = possible_output_channel_axis_list[0]
-                    sc.mpcc.axis = possible_output_channel_axis_list[0]
-
+            sc = create_stats_collector_for_node(n, fw_info=fw_info, quant_node_in_fln=quant_node_in_fln) # Get static collector for the node
             # If we use bias correction, and the node has kernel weights to quantize, we need to make sure
             # its previous nodes' tensors are consistent with this node.
-
-
+            kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
+            if qc.weights_bias_correction and kernel_attr is not None and n.is_weights_quantization_enabled(
+                    kernel_attr):
                 for ie in graph.incoming_edges(n):
                     input_node = ie.source_node
                     create_tensor2node(graph,
                                        input_node,
-
+                                       fw_info)
             if sc is not None:
                 graph.set_out_stats_collector_to_node(n, sc)
 
@@ -211,11 +205,13 @@
         # TODO: Add integration test for this case
         append2output = outputs_nodes + [n for n in self.model_outputs if n not in outputs_nodes]
 
+
         # Build a float model and output all layers' outputs
         # (that should be collected) as the model's outputs
         self.model, _ = self.fw_impl.model_builder(graph,
                                                    mode=ModelBuilderMode.FLOAT,
-                                                   append2output=append2output
+                                                   append2output=append2output,
+                                                   fw_info=self.fw_info)
 
     def infer(self, inputs_list: List[np.ndarray]):
         """
```
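Note: the removed block above used neighboring nodes to guess a missing output-channel axis; with this change the axis comes straight from the framework info's type-to-axis mapping. A toy version of that lookup (the mapping values are illustrative, e.g. channels-last layers with axis -1):

```python
# Toy stand-in for fw_info.out_channel_axis_mapping (illustrative channels-last values).
out_channel_axis_mapping = {'Conv2D': -1, 'DepthwiseConv2D': -1, 'Dense': -1}

def axis_for(node_type: str):
    # Returns None for unknown types instead of inferring from neighboring nodes.
    return out_channel_axis_mapping.get(node_type)

print(axis_for('Conv2D'))   # -1
print(axis_for('Reshape'))  # None
```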
model_compression_toolkit/core/common/model_validation.py
ADDED

```diff
@@ -0,0 +1,44 @@
+from abc import abstractmethod
+from typing import Any
+
+from model_compression_toolkit.core import FrameworkInfo
+
+
+class ModelValidation:
+    """
+    Class to define validation methods in order to validate the received model to quantize.
+    """
+
+    def __init__(self,
+                 model: Any,
+                 fw_info:FrameworkInfo):
+        """
+        Initialize a ModelValidation object.
+
+        Args:
+            model: Model to check its validity.
+            fw_info: Information about the specific framework of the model.
+        """
+        self.model = model
+        self.fw_info = fw_info
+
+    @abstractmethod
+    def validate_output_channel_consistency(self):
+        """
+
+        Validate that output channels index in all layers of the model are the same.
+        If the model has layers with different output channels index, it should throw an exception.
+
+        """
+        raise NotImplemented(
+            f'Framework validation class did not implement validate_output_channel_consistency')  # pragma: no cover
+
+    def validate(self):
+        """
+
+        Run all validation methods before the quantization process starts.
+
+        """
+        self.validate_output_channel_consistency()
+
+
```
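Note: `ModelValidation` is an abstract base; frameworks subclass it and implement `validate_output_channel_consistency` (the file list above also adds a Keras implementation in `core/keras/keras_model_validation.py`). A hypothetical minimal subclass, assuming the `ModelValidation` class from this hunk is in scope and treating the "model" as a plain list of per-layer axes:

```python
# Hypothetical subclass sketch (stub "model" = a list of per-layer output-channel axes;
# the real Keras implementation lives in core/keras/keras_model_validation.py).
class StubModelValidation(ModelValidation):
    def validate_output_channel_consistency(self):
        axes = set(self.model)
        if len(axes) > 1:
            raise ValueError(f'Inconsistent output channel axes across layers: {axes}')

StubModelValidation(model=[-1, -1, -1], fw_info=None).validate()  # passes silently
```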
model_compression_toolkit/core/common/network_editors/__init__.py
CHANGED

```diff
@@ -13,14 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-from model_compression_toolkit.core.common.network_editors.actions import
-    ChangeCandidatesWeightsQuantConfigAttr,
-    ChangeFinalWeightsQuantConfigAttr,
-    ChangeCandidatesActivationQuantConfigAttr,
-    ChangeCandidatesActivationQuantizationMethod,
-    ChangeFinalWeightsQuantizationMethod,
-    ChangeCandidatesWeightsQuantizationMethod,
-    ChangeFinalActivationQuantConfigAttr)
+from model_compression_toolkit.core.common.network_editors.actions import ChangeCandidatesWeightsQuantConfigAttr, ChangeFinalWeightsQuantConfigAttr, ChangeCandidatesActivationQuantConfigAttr, ChangeQuantizationParamFunction, ChangeCandidatesActivationQuantizationMethod, ChangeFinalWeightsQuantizationMethod, ChangeCandidatesWeightsQuantizationMethod, ChangeFinalActivationQuantConfigAttr
 from model_compression_toolkit.core.common.network_editors.actions import EditRule
 from model_compression_toolkit.core.common.network_editors.node_filters import NodeTypeFilter, NodeNameScopeFilter, \
     NodeNameFilter
```