mct-nightly 2.2.0.20250113.134913__py3-none-any.whl → 2.2.0.20250114.134534__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/RECORD +102 -104
- model_compression_toolkit/__init__.py +2 -2
- model_compression_toolkit/core/common/framework_info.py +1 -3
- model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
- model_compression_toolkit/core/common/graph/base_graph.py +20 -21
- model_compression_toolkit/core/common/graph/base_node.py +44 -17
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +187 -0
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +35 -162
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +668 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +74 -51
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/pruner.py +5 -3
- model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/graph_prep_runner.py +12 -11
- model_compression_toolkit/core/keras/default_framework_info.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
- model_compression_toolkit/core/runner.py +33 -60
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/metadata.py +11 -10
- model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
- model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
- model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
- model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
- model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
- model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
- model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
- model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +0 -528
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
@@ -25,7 +25,7 @@ from model_compression_toolkit.constants import MIN_THRESHOLD
|
|
25
25
|
class CustomOpsetLayers(NamedTuple):
|
26
26
|
"""
|
27
27
|
This struct defines a set of operators from a specific framework, which will be used to configure a custom operator
|
28
|
-
set in the TPC.
|
28
|
+
set in the FQC.
|
29
29
|
|
30
30
|
Args:
|
31
31
|
operators: a list of framework operators to map to a certain custom opset name.
|
@@ -16,8 +16,8 @@
|
|
16
16
|
from collections.abc import Callable
|
17
17
|
from functools import partial
|
18
18
|
|
19
|
+
from mct_quantizers import QuantizationMethod
|
19
20
|
from model_compression_toolkit.logger import Logger
|
20
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
21
21
|
from model_compression_toolkit.core.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
|
22
22
|
from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import power_of_two_quantizer, \
|
23
23
|
symmetric_quantizer, uniform_quantizer
|
@@ -16,8 +16,8 @@
|
|
16
16
|
from collections.abc import Callable
|
17
17
|
from functools import partial
|
18
18
|
|
19
|
+
from mct_quantizers import QuantizationMethod
|
19
20
|
from model_compression_toolkit.logger import Logger
|
20
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
21
21
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import \
|
22
22
|
lut_kmeans_tensor, lut_kmeans_histogram
|
23
23
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import \
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py
CHANGED
@@ -16,11 +16,11 @@ from copy import deepcopy
|
|
16
16
|
from typing import Tuple, Callable, List, Iterable, Optional
|
17
17
|
import numpy as np
|
18
18
|
import model_compression_toolkit.core.common.quantization.quantization_config as qc
|
19
|
+
from mct_quantizers import QuantizationMethod
|
19
20
|
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, HessianScoresGranularity, \
|
20
21
|
HessianInfoService
|
21
22
|
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_mae, compute_lp_norm
|
22
23
|
from model_compression_toolkit.logger import Logger
|
23
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
24
24
|
from model_compression_toolkit.constants import FLOAT_32, NUM_QPARAM_HESSIAN_SAMPLES
|
25
25
|
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor, \
|
26
26
|
reshape_tensor_for_per_channel_search
|
@@ -16,6 +16,7 @@ import numpy as np
|
|
16
16
|
from typing import Union, Tuple, Dict
|
17
17
|
|
18
18
|
import model_compression_toolkit.core.common.quantization.quantization_config as qc
|
19
|
+
from mct_quantizers import QuantizationMethod
|
19
20
|
from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES, SIGNED
|
20
21
|
from model_compression_toolkit.core.common.hessian import HessianInfoService
|
21
22
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_search import \
|
@@ -23,7 +24,6 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
|
|
23
24
|
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import max_power_of_two, get_tensor_max
|
24
25
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
|
25
26
|
get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function
|
26
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
27
27
|
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
|
28
28
|
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor
|
29
29
|
|
@@ -15,7 +15,7 @@
|
|
15
15
|
import numpy as np
|
16
16
|
from typing import Dict, Union
|
17
17
|
|
18
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
18
|
+
from mct_quantizers import QuantizationMethod
|
19
19
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
|
20
20
|
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
|
21
21
|
from model_compression_toolkit.core.common.quantization import quantization_params_generation
|
@@ -25,7 +25,7 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
|
|
25
25
|
qparams_symmetric_selection_histogram_search, kl_qparams_symmetric_selection_histogram_search
|
26
26
|
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import \
|
27
27
|
get_tensor_max
|
28
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
28
|
+
from mct_quantizers import QuantizationMethod
|
29
29
|
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
|
30
30
|
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor
|
31
31
|
|
@@ -24,7 +24,7 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
|
|
24
24
|
get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function
|
25
25
|
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import get_tensor_max, \
|
26
26
|
get_tensor_min
|
27
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
27
|
+
from mct_quantizers import QuantizationMethod
|
28
28
|
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
|
29
29
|
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor
|
30
30
|
|
@@ -33,9 +33,10 @@ from model_compression_toolkit.core.common.quantization.quantization_params_fn_s
|
|
33
33
|
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
|
34
34
|
get_weights_quantization_fn
|
35
35
|
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
|
36
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
37
36
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
|
38
37
|
QuantizationConfigOptions
|
38
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
|
39
|
+
FrameworkQuantizationCapabilities
|
39
40
|
|
40
41
|
|
41
42
|
def set_quantization_configuration_to_graph(graph: Graph,
|
@@ -71,14 +72,14 @@ def set_quantization_configuration_to_graph(graph: Graph,
|
|
71
72
|
graph=graph,
|
72
73
|
quant_config=quant_config,
|
73
74
|
fw_info=graph.fw_info,
|
74
|
-
|
75
|
+
fqc=graph.fqc,
|
75
76
|
mixed_precision_enable=mixed_precision_enable,
|
76
77
|
manual_bit_width_override=nodes_to_manipulate_bit_widths.get(n))
|
77
78
|
return graph
|
78
79
|
|
79
80
|
|
80
81
|
def filter_node_qco_by_graph(node: BaseNode,
|
81
|
-
|
82
|
+
fqc: FrameworkQuantizationCapabilities,
|
82
83
|
graph: Graph,
|
83
84
|
node_qc_options: QuantizationConfigOptions
|
84
85
|
) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
|
@@ -90,7 +91,7 @@ def filter_node_qco_by_graph(node: BaseNode,
|
|
90
91
|
|
91
92
|
Args:
|
92
93
|
node: Node for filtering.
|
93
|
-
|
94
|
+
fqc: FQC to extract the QuantizationConfigOptions for the next nodes.
|
94
95
|
graph: Graph object.
|
95
96
|
node_qc_options: Node's QuantizationConfigOptions.
|
96
97
|
|
@@ -108,7 +109,7 @@ def filter_node_qco_by_graph(node: BaseNode,
|
|
108
109
|
next_nodes = []
|
109
110
|
while len(_next_nodes):
|
110
111
|
n = _next_nodes.pop(0)
|
111
|
-
qco = n.get_qco(
|
112
|
+
qco = n.get_qco(fqc)
|
112
113
|
qp = [qc.quantization_preserving for qc in qco.quantization_configurations]
|
113
114
|
if not all(qp) and any(qp):
|
114
115
|
Logger.error(f'Attribute "quantization_preserving" should be the same for all QuantizaionConfigOptions in {n}.')
|
@@ -117,7 +118,7 @@ def filter_node_qco_by_graph(node: BaseNode,
|
|
117
118
|
next_nodes.append(n)
|
118
119
|
|
119
120
|
if len(next_nodes):
|
120
|
-
next_nodes_qc_options = [_node.get_qco(
|
121
|
+
next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
|
121
122
|
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
|
122
123
|
for qc_opts in next_nodes_qc_options
|
123
124
|
for op_cfg in qc_opts.quantization_configurations])
|
@@ -126,7 +127,7 @@ def filter_node_qco_by_graph(node: BaseNode,
|
|
126
127
|
_node_qc_options = [_option for _option in _node_qc_options
|
127
128
|
if _option.activation_n_bits <= next_nodes_supported_input_bitwidth]
|
128
129
|
if len(_node_qc_options) == 0:
|
129
|
-
Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.")
|
130
|
+
Logger.critical(f"Graph doesn't match FQC bit configurations: {node} -> {next_nodes}.")
|
130
131
|
|
131
132
|
# Verify base config match
|
132
133
|
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
|
@@ -136,9 +137,9 @@ def filter_node_qco_by_graph(node: BaseNode,
|
|
136
137
|
if len(_node_qc_options) > 0:
|
137
138
|
output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
|
138
139
|
_base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
|
139
|
-
Logger.warning(f"Node {node} base quantization config changed to match Graph and TPC configuration.\nCause: {node} -> {next_nodes}.")
|
140
|
+
Logger.warning(f"Node {node} base quantization config changed to match Graph and FQC configuration.\nCause: {node} -> {next_nodes}.")
|
140
141
|
else:
|
141
|
-
Logger.critical(f"Graph doesn't match
|
142
|
+
Logger.critical(f"Graph doesn't match FQC bit configurations: {node} -> {next_nodes}.") # pragma: no cover
|
142
143
|
|
143
144
|
return _base_config, _node_qc_options
|
144
145
|
|
@@ -147,7 +148,7 @@ def set_quantization_configs_to_node(node: BaseNode,
|
|
147
148
|
graph: Graph,
|
148
149
|
quant_config: QuantizationConfig,
|
149
150
|
fw_info: FrameworkInfo,
|
150
|
-
|
151
|
+
fqc: FrameworkQuantizationCapabilities,
|
151
152
|
mixed_precision_enable: bool = False,
|
152
153
|
manual_bit_width_override: Optional[int] = None):
|
153
154
|
"""
|
@@ -158,12 +159,12 @@ def set_quantization_configs_to_node(node: BaseNode,
|
|
158
159
|
graph (Graph): Model's internal representation graph.
|
159
160
|
quant_config (QuantizationConfig): Quantization configuration to generate the node's configurations from.
|
160
161
|
fw_info (FrameworkInfo): Information needed for quantization about the specific framework.
|
161
|
-
|
162
|
+
fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to get default OpQuantizationConfig.
|
162
163
|
mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
|
163
164
|
manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
|
164
165
|
"""
|
165
|
-
node_qc_options = node.get_qco(
|
166
|
-
base_config, node_qc_options_list = filter_node_qco_by_graph(node,
|
166
|
+
node_qc_options = node.get_qco(fqc)
|
167
|
+
base_config, node_qc_options_list = filter_node_qco_by_graph(node, fqc, graph, node_qc_options)
|
167
168
|
|
168
169
|
# If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override,
|
169
170
|
# and update base_config accordingly.
|
@@ -257,7 +258,7 @@ def _create_node_single_candidate_qc(qc: QuantizationConfig,
|
|
257
258
|
attrs_with_enabled_quantization = [attr for attr, cfg in op_cfg.attr_weights_configs_mapping.items()
|
258
259
|
if cfg.enable_weights_quantization]
|
259
260
|
if len(attrs_with_enabled_quantization) > 1:
|
260
|
-
Logger.warning(f"Multiple weights attributes quantization is enabled via the provided TPC."
|
261
|
+
Logger.warning(f"Multiple weights attributes quantization is enabled via the provided FQC."
|
261
262
|
f"Quantizing any attribute other than the kernel is experimental "
|
262
263
|
f"and may be subject to unstable behavior."
|
263
264
|
f"Attributes with enabled weights quantization: {attrs_with_enabled_quantization}.")
|
@@ -26,7 +26,7 @@ from model_compression_toolkit.logger import Logger
|
|
26
26
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
27
27
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
28
28
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
29
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
29
|
+
from mct_quantizers import QuantizationMethod
|
30
30
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig
|
31
31
|
|
32
32
|
|
@@ -22,7 +22,7 @@ from model_compression_toolkit.core import common
|
|
22
22
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
23
23
|
from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatcher, NodeOperationMatcher
|
24
24
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
25
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
25
|
+
from mct_quantizers import QuantizationMethod
|
26
26
|
from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX
|
27
27
|
from model_compression_toolkit.logger import Logger
|
28
28
|
|
@@ -22,7 +22,7 @@ from model_compression_toolkit.logger import Logger
|
|
22
22
|
from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
|
23
23
|
from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
|
24
24
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
25
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
25
|
+
from mct_quantizers import QuantizationMethod
|
26
26
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig
|
27
27
|
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \
|
28
28
|
set_quantization_configs_to_node
|
@@ -359,7 +359,7 @@ def shift_negative_function(graph: Graph,
|
|
359
359
|
node=pad_node,
|
360
360
|
graph=graph,
|
361
361
|
quant_config=core_config.quantization_config,
|
362
|
-
|
362
|
+
fqc=graph.fqc,
|
363
363
|
mixed_precision_enable=core_config.is_mixed_precision_enabled)
|
364
364
|
|
365
365
|
for candidate_qc in pad_node.candidates_quantization_cfg:
|
@@ -376,7 +376,7 @@ def shift_negative_function(graph: Graph,
|
|
376
376
|
node=add_node,
|
377
377
|
graph=graph,
|
378
378
|
quant_config=core_config.quantization_config,
|
379
|
-
|
379
|
+
fqc=graph.fqc,
|
380
380
|
mixed_precision_enable=core_config.is_mixed_precision_enabled)
|
381
381
|
|
382
382
|
original_non_linear_activation_nbits = non_linear_node_cfg_candidate.activation_n_bits
|
@@ -392,7 +392,7 @@ def shift_negative_function(graph: Graph,
|
|
392
392
|
bypass_candidate_qc.activation_quantization_cfg.activation_quantization_params[SIGNED] = False
|
393
393
|
graph.shift_stats_collector(bypass_node, np.array(shift_value))
|
394
394
|
|
395
|
-
add_node_qco = add_node.get_qco(graph.
|
395
|
+
add_node_qco = add_node.get_qco(graph.fqc).quantization_configurations
|
396
396
|
for op_qc_idx, candidate_qc in enumerate(add_node.candidates_quantization_cfg):
|
397
397
|
for attr in add_node.get_node_weights_attributes():
|
398
398
|
candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
|
@@ -533,7 +533,7 @@ def apply_shift_negative_correction(graph: Graph,
|
|
533
533
|
nodes = list(graph.nodes())
|
534
534
|
for n in nodes:
|
535
535
|
# Skip substitution if QuantizationMethod is uniform.
|
536
|
-
node_qco = n.get_qco(graph.
|
536
|
+
node_qco = n.get_qco(graph.fqc)
|
537
537
|
if any([op_qc.activation_quantization_method is QuantizationMethod.UNIFORM
|
538
538
|
for op_qc in node_qco.quantization_configurations]):
|
539
539
|
continue
|
@@ -29,8 +29,9 @@ from model_compression_toolkit.core.common.quantization.set_node_quantization_co
|
|
29
29
|
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
|
30
30
|
from model_compression_toolkit.core.common.substitutions.linear_collapsing_substitution import \
|
31
31
|
linear_collapsing_substitute
|
32
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
33
32
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
|
33
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
|
34
|
+
FrameworkQuantizationCapabilities
|
34
35
|
|
35
36
|
|
36
37
|
def graph_preparation_runner(in_model: Any,
|
@@ -38,7 +39,7 @@ def graph_preparation_runner(in_model: Any,
|
|
38
39
|
quantization_config: QuantizationConfig,
|
39
40
|
fw_info: FrameworkInfo,
|
40
41
|
fw_impl: FrameworkImplementation,
|
41
|
-
|
42
|
+
fqc: FrameworkQuantizationCapabilities,
|
42
43
|
bit_width_config: BitWidthConfig = None,
|
43
44
|
tb_w: TensorboardWriter = None,
|
44
45
|
mixed_precision_enable: bool = False,
|
@@ -58,7 +59,7 @@ def graph_preparation_runner(in_model: Any,
|
|
58
59
|
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
59
60
|
groups of layers by how they should be quantized, etc.).
|
60
61
|
fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
|
61
|
-
|
62
|
+
fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities object that models the inference target platform and
|
62
63
|
the attached framework operator's information.
|
63
64
|
bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
|
64
65
|
tb_w (TensorboardWriter): TensorboardWriter object for logging.
|
@@ -71,7 +72,7 @@ def graph_preparation_runner(in_model: Any,
|
|
71
72
|
|
72
73
|
graph = read_model_to_graph(in_model,
|
73
74
|
representative_data_gen,
|
74
|
-
|
75
|
+
fqc,
|
75
76
|
fw_info,
|
76
77
|
fw_impl)
|
77
78
|
|
@@ -79,7 +80,7 @@ def graph_preparation_runner(in_model: Any,
|
|
79
80
|
tb_w.add_graph(graph, 'initial_graph')
|
80
81
|
|
81
82
|
transformed_graph = get_finalized_graph(graph,
|
82
|
-
|
83
|
+
fqc,
|
83
84
|
quantization_config,
|
84
85
|
bit_width_config,
|
85
86
|
fw_info,
|
@@ -92,7 +93,7 @@ def graph_preparation_runner(in_model: Any,
|
|
92
93
|
|
93
94
|
|
94
95
|
def get_finalized_graph(initial_graph: Graph,
|
95
|
-
|
96
|
+
fqc: FrameworkQuantizationCapabilities,
|
96
97
|
quant_config: QuantizationConfig = DEFAULTCONFIG,
|
97
98
|
bit_width_config: BitWidthConfig = None,
|
98
99
|
fw_info: FrameworkInfo = None,
|
@@ -106,7 +107,7 @@ def get_finalized_graph(initial_graph: Graph,
|
|
106
107
|
|
107
108
|
Args:
|
108
109
|
initial_graph (Graph): Graph to apply the changes to.
|
109
|
-
|
110
|
+
fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities object that describes the desired inference target platform (includes fusing patterns MCT should handle).
|
110
111
|
quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
|
111
112
|
quantized.
|
112
113
|
bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
|
@@ -160,7 +161,7 @@ def get_finalized_graph(initial_graph: Graph,
|
|
160
161
|
######################################
|
161
162
|
# Layer fusing
|
162
163
|
######################################
|
163
|
-
transformed_graph = fusion(transformed_graph,
|
164
|
+
transformed_graph = fusion(transformed_graph, fqc)
|
164
165
|
|
165
166
|
######################################
|
166
167
|
# Channel equalization
|
@@ -185,7 +186,7 @@ def get_finalized_graph(initial_graph: Graph,
|
|
185
186
|
|
186
187
|
def read_model_to_graph(in_model: Any,
|
187
188
|
representative_data_gen: Callable,
|
188
|
-
|
189
|
+
fqc: FrameworkQuantizationCapabilities,
|
189
190
|
fw_info: FrameworkInfo = None,
|
190
191
|
fw_impl: FrameworkImplementation = None) -> Graph:
|
191
192
|
|
@@ -195,7 +196,7 @@ def read_model_to_graph(in_model: Any,
|
|
195
196
|
Args:
|
196
197
|
in_model: Model to optimize and prepare for quantization.
|
197
198
|
representative_data_gen: Dataset used for calibration.
|
198
|
-
|
199
|
+
fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
|
199
200
|
the attached framework operator's information.
|
200
201
|
fw_info: Information needed for quantization about the specific framework (e.g.,
|
201
202
|
kernel channels indices, groups of layers by how they should be quantized, etc.)
|
@@ -207,5 +208,5 @@ def read_model_to_graph(in_model: Any,
|
|
207
208
|
graph = fw_impl.model_reader(in_model,
|
208
209
|
representative_data_gen)
|
209
210
|
graph.set_fw_info(fw_info)
|
210
|
-
graph.
|
211
|
+
graph.set_fqc(fqc)
|
211
212
|
return graph
|
@@ -26,7 +26,7 @@ else:
|
|
26
26
|
|
27
27
|
from model_compression_toolkit.defaultdict import DefaultDict
|
28
28
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
29
|
-
from
|
29
|
+
from mct_quantizers import QuantizationMethod
|
30
30
|
from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
|
31
31
|
from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
|
32
32
|
KERNEL, DEPTHWISE_KERNEL, GELU
|
@@ -20,8 +20,7 @@ from model_compression_toolkit.core.common.mixed_precision.configurable_quantize
|
|
20
20
|
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
|
21
21
|
CandidateNodeQuantizationConfig
|
22
22
|
from model_compression_toolkit.logger import Logger
|
23
|
-
from
|
24
|
-
from mct_quantizers import QuantizationTarget
|
23
|
+
from mct_quantizers import QuantizationTarget, QuantizationMethod
|
25
24
|
from mct_quantizers import mark_quantizer
|
26
25
|
|
27
26
|
import tensorflow as tf
|
@@ -18,18 +18,17 @@ from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, Cor
|
|
18
18
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
19
19
|
from model_compression_toolkit.logger import Logger
|
20
20
|
from model_compression_toolkit.constants import TENSORFLOW
|
21
|
-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
|
22
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
21
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
|
23
22
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
|
24
23
|
from model_compression_toolkit.verify_packages import FOUND_TF
|
25
24
|
|
26
25
|
if FOUND_TF:
|
26
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
|
27
|
+
AttachTpcToKeras
|
27
28
|
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
28
29
|
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
29
30
|
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
|
30
31
|
from tensorflow.keras.models import Model
|
31
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
|
32
|
-
AttachTpcToKeras
|
33
32
|
|
34
33
|
from model_compression_toolkit import get_target_platform_capabilities
|
35
34
|
|
@@ -39,7 +38,7 @@ if FOUND_TF:
|
|
39
38
|
representative_data_gen: Callable,
|
40
39
|
core_config: CoreConfig = CoreConfig(
|
41
40
|
mixed_precision_config=MixedPrecisionQuantizationConfig()),
|
42
|
-
target_platform_capabilities:
|
41
|
+
target_platform_capabilities: TargetPlatformCapabilities = KERAS_DEFAULT_TPC
|
43
42
|
) -> ResourceUtilization:
|
44
43
|
"""
|
45
44
|
Computes resource utilization data that can be used to calculate the desired target resource utilization
|
@@ -51,7 +50,7 @@ if FOUND_TF:
|
|
51
50
|
in_model (Model): Keras model to quantize.
|
52
51
|
representative_data_gen (Callable): Dataset used for calibration.
|
53
52
|
core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision of how the model should be quantized.
|
54
|
-
target_platform_capabilities (
|
53
|
+
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
|
55
54
|
|
56
55
|
Returns:
|
57
56
|
|
@@ -225,7 +225,7 @@ class PytorchModel(torch.nn.Module):
|
|
225
225
|
"""
|
226
226
|
super(PytorchModel, self).__init__()
|
227
227
|
self.graph = copy.deepcopy(graph)
|
228
|
-
delattr(self.graph, '
|
228
|
+
delattr(self.graph, 'fqc')
|
229
229
|
|
230
230
|
self.node_sort = list(topological_sort(self.graph))
|
231
231
|
self.node_to_activation_quantization_holder = {}
|
@@ -19,7 +19,7 @@ from torch import sigmoid
|
|
19
19
|
|
20
20
|
from model_compression_toolkit.defaultdict import DefaultDict
|
21
21
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
22
|
-
from
|
22
|
+
from mct_quantizers import QuantizationMethod
|
23
23
|
from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
|
24
24
|
from model_compression_toolkit.core.pytorch.constants import KERNEL
|
25
25
|
from model_compression_toolkit.core.pytorch.quantizer.fake_quant_builder import power_of_two_quantization, \
|
@@ -21,7 +21,7 @@ from model_compression_toolkit.core.common.mixed_precision.configurable_quantize
|
|
21
21
|
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
|
22
22
|
CandidateNodeQuantizationConfig
|
23
23
|
from model_compression_toolkit.logger import Logger
|
24
|
-
from
|
24
|
+
from mct_quantizers import QuantizationMethod
|
25
25
|
from mct_quantizers import QuantizationTarget
|
26
26
|
from mct_quantizers import mark_quantizer
|
27
27
|
|
@@ -20,7 +20,7 @@ from model_compression_toolkit.core.common.mixed_precision.configurable_quantize
|
|
20
20
|
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
|
21
21
|
CandidateNodeQuantizationConfig
|
22
22
|
from model_compression_toolkit.logger import Logger
|
23
|
-
from
|
23
|
+
from mct_quantizers import QuantizationMethod
|
24
24
|
from mct_quantizers import QuantizationTarget
|
25
25
|
|
26
26
|
from mct_quantizers import mark_quantizer
|
@@ -17,8 +17,7 @@ from typing import Callable
|
|
17
17
|
|
18
18
|
from model_compression_toolkit.logger import Logger
|
19
19
|
from model_compression_toolkit.constants import PYTORCH
|
20
|
-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
|
21
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
|
22
21
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
23
22
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
|
24
23
|
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
|
@@ -30,7 +29,7 @@ if FOUND_TORCH:
|
|
30
29
|
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
31
30
|
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
|
32
31
|
from torch.nn import Module
|
33
|
-
from model_compression_toolkit.target_platform_capabilities.
|
32
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
|
34
33
|
AttachTpcToPytorch
|
35
34
|
|
36
35
|
from model_compression_toolkit import get_target_platform_capabilities
|
@@ -41,7 +40,7 @@ if FOUND_TORCH:
|
|
41
40
|
def pytorch_resource_utilization_data(in_model: Module,
|
42
41
|
representative_data_gen: Callable,
|
43
42
|
core_config: CoreConfig = CoreConfig(),
|
44
|
-
target_platform_capabilities:
|
43
|
+
target_platform_capabilities: TargetPlatformCapabilities= PYTORCH_DEFAULT_TPC
|
45
44
|
) -> ResourceUtilization:
|
46
45
|
"""
|
47
46
|
Computes resource utilization data that can be used to calculate the desired target resource utilization for mixed-precision quantization.
|
@@ -51,7 +50,7 @@ if FOUND_TORCH:
|
|
51
50
|
in_model (Module): PyTorch model to quantize.
|
52
51
|
representative_data_gen (Callable): Dataset used for calibration.
|
53
52
|
core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision
|
54
|
-
target_platform_capabilities (
|
53
|
+
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
|
55
54
|
|
56
55
|
Returns:
|
57
56
|
|