mct-nightly 2.2.0.20250113.134913__py3-none-any.whl → 2.2.0.20250114.134534__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/RECORD +102 -104
- model_compression_toolkit/__init__.py +2 -2
- model_compression_toolkit/core/common/framework_info.py +1 -3
- model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
- model_compression_toolkit/core/common/graph/base_graph.py +20 -21
- model_compression_toolkit/core/common/graph/base_node.py +44 -17
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +187 -0
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +35 -162
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +668 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +74 -51
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/pruner.py +5 -3
- model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/graph_prep_runner.py +12 -11
- model_compression_toolkit/core/keras/default_framework_info.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
- model_compression_toolkit/core/runner.py +33 -60
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/metadata.py +11 -10
- model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
- model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
- model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
- model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
- model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
- model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
- model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
- model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +0 -528
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
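Most of this nightly is a rename of the target-platform-capabilities (TPC) API: the `target_platform` sub-package is flattened into `targetplatform2framework`, the schema class previously named `TargetPlatformModel` is now exported as `TargetPlatformCapabilities` from `schema.mct_current_schema`, `OperatorSetConcat` becomes `OperatorSetGroup`, the `tp_model.py` TPC definitions become `tpc.py`, and the framework-attached object is now called `FrameworkQuantizationCapabilities`. The sketch below is inferred only from the import lines added and removed in the hunks that follow; it is not taken from MCT documentation.

```python
# Illustrative only: how downstream imports move in this nightly (inferred from the diff).
# Old (2.2.0.20250113.134913):
#   from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
# New (2.2.0.20250114.134534):
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import (
    TargetPlatformCapabilities,  # schema-level hardware description (formerly TargetPlatformModel)
    OperatorSetGroup,            # formerly OperatorSetConcat
)
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import (
    FrameworkQuantizationCapabilities,  # TPC attached to a framework; passed around internally as `fqc`
)
```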
@@ -19,9 +19,8 @@ from typing import Callable
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import PYTORCH
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
 from model_compression_toolkit.verify_packages import FOUND_TORCH
-from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
@@ -40,7 +39,7 @@ if FOUND_TORCH:
 from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
 from model_compression_toolkit import get_target_platform_capabilities
 from mct_quantizers.pytorch.metadata import add_metadata
-from model_compression_toolkit.target_platform_capabilities.
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
 AttachTpcToPytorch
 
 DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
@@ -49,11 +48,11 @@ if FOUND_TORCH:
 representative_data_gen: Callable,
 target_resource_utilization: ResourceUtilization = None,
 core_config: CoreConfig = CoreConfig(),
-target_platform_capabilities:
+target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
 """
 Quantize a trained Pytorch module using post-training quantization.
 By default, the module is quantized using a symmetric constraint quantization thresholds
-(power of two) as defined in the default
+(power of two) as defined in the default FrameworkQuantizationCapabilities.
 The module is first optimized using several transformations (e.g. BatchNormalization folding to
 preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
 being collected for each layer's output (and input, depends on the quantization configuration).
@@ -112,7 +111,7 @@ if FOUND_TORCH:
 
 # Attach tpc model to framework
 attach2pytorch = AttachTpcToPytorch()
-
+framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
 core_config.quantization_config.custom_tpc_opset_to_layer)
 
 # Ignore hessian info service as it is not used here yet.
@@ -121,7 +120,7 @@ if FOUND_TORCH:
 core_config=core_config,
 fw_info=fw_info,
 fw_impl=fw_impl,
-
+fqc=framework_platform_capabilities,
 target_resource_utilization=target_resource_utilization,
 tb_w=tb_w)
 
@@ -149,9 +148,9 @@ if FOUND_TORCH:
 fw_info)
 
 exportable_model, user_info = get_exportable_pytorch_model(graph_with_stats_correction)
-if
+if framework_platform_capabilities.tpc.add_metadata:
 exportable_model = add_metadata(exportable_model,
-create_model_metadata(
+create_model_metadata(fqc=framework_platform_capabilities,
 scheduling_info=scheduling_info))
 return exportable_model, user_info
 
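The hunks above appear to belong to the PyTorch PTQ facade (`model_compression_toolkit/ptq/pytorch/quantization_facade.py` in the file list): the `target_platform_capabilities` argument is now annotated with the schema-level `TargetPlatformCapabilities`, the facade attaches it to PyTorch via `AttachTpcToPytorch().attach(...)`, and the result is forwarded to `core_runner` as `fqc`. From the caller's side nothing changes beyond the type of the TPC object. A hedged sketch; the facade name `pytorch_post_training_quantization` and the string arguments of `get_target_platform_capabilities` are assumptions, only the parameter and its `DEFAULT_PYTORCH_TPC` default come from the hunks:

```python
import torch
import model_compression_toolkit as mct

# Hypothetical caller-side usage; function name and exact signature are assumptions.
float_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

def representative_data_gen():
    # a couple of random calibration batches
    for _ in range(2):
        yield [torch.randn(1, 3, 32, 32)]

# Mirrors the facade default shown above: DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
tpc = mct.get_target_platform_capabilities("pytorch", "default")

quantized_model, quantization_info = mct.ptq.pytorch_post_training_quantization(
    float_model,
    representative_data_gen,
    target_platform_capabilities=tpc,
)
```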
@@ -19,13 +19,14 @@ from functools import partial
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
+AttachTpcToKeras
 from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
 MixedPrecisionQuantizationConfig
 from mct_quantizers import KerasActivationQuantizationHolder
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.ptq.runner import ptq_runner
 
@@ -55,8 +56,6 @@ if FOUND_TF:
 from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
 get_activation_quantizer_holder
 from model_compression_toolkit.qat.common.qat_config import QATConfig
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
-AttachTpcToKeras
 
 DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
@@ -93,7 +92,7 @@ if FOUND_TF:
 target_resource_utilization: ResourceUtilization = None,
 core_config: CoreConfig = CoreConfig(),
 qat_config: QATConfig = QATConfig(),
-target_platform_capabilities:
+target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
 """
 Prepare a trained Keras model for quantization aware training. First the model quantization is optimized
 with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -200,7 +199,7 @@ if FOUND_TF:
 core_config=core_config,
 fw_info=DEFAULT_KERAS_INFO,
 fw_impl=fw_impl,
-
+fqc=target_platform_capabilities,
 target_resource_utilization=target_resource_utilization,
 tb_w=tb_w)
 
@@ -21,7 +21,7 @@ from tensorflow.python.framework.tensor_shape import TensorShape
 
 from model_compression_toolkit.trainable_infrastructure import TrainingMethod
 
-from
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
 from mct_quantizers import QuantizationTarget, mark_quantizer
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
@@ -22,7 +22,7 @@ from model_compression_toolkit.trainable_infrastructure.common.constants import
 
 from model_compression_toolkit.trainable_infrastructure import TrainingMethod
 
-from
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
 from mct_quantizers import QuantizationTarget, mark_quantizer
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
@@ -12,13 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import copy
 from typing import Callable
 from functools import partial
 
 from model_compression_toolkit.constants import PYTORCH
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
-from model_compression_toolkit.target_platform_capabilities.
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
 AttachTpcToPytorch
 from model_compression_toolkit.verify_packages import FOUND_TORCH
 
@@ -26,12 +25,9 @@ from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
 MixedPrecisionQuantizationConfig
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
-TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.ptq.runner import ptq_runner
 
@@ -82,7 +78,7 @@ if FOUND_TORCH:
 target_resource_utilization: ResourceUtilization = None,
 core_config: CoreConfig = CoreConfig(),
 qat_config: QATConfig = QATConfig(),
-target_platform_capabilities:
+target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
 """
 Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized
 with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -159,7 +155,7 @@ if FOUND_TORCH:
 
 # Attach tpc model to framework
 attach2pytorch = AttachTpcToPytorch()
-
+framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
 core_config.quantization_config.custom_tpc_opset_to_layer)
 
 # Ignore hessian scores service as we do not use it here
@@ -168,7 +164,7 @@ if FOUND_TORCH:
 core_config=core_config,
 fw_info=DEFAULT_PYTORCH_INFO,
 fw_impl=fw_impl,
-
+fqc=framework_platform_capabilities,
 target_resource_utilization=target_resource_utilization,
 tb_w=tb_w)
 
@@ -18,7 +18,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from
+from mct_quantizers import QuantizationMethod
 from mct_quantizers import PytorchQuantizationWrapper
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
 from model_compression_toolkit import constants as C
@@ -28,7 +28,7 @@ from model_compression_toolkit.trainable_infrastructure.pytorch.quantizer_utils
 from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
 TrainableQuantizerWeightsConfig
-from
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero
 from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer
@@ -18,7 +18,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from
+from mct_quantizers import QuantizationMethod
 from mct_quantizers import PytorchQuantizationWrapper
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
 from model_compression_toolkit import constants as C
@@ -20,7 +20,7 @@ from torch import Tensor
 from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN
 from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
 
-from
+from mct_quantizers import QuantizationMethod
 from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
 
 from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer
@@ -12,3 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import AttributeFilter
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import (
+FrameworkQuantizationCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater,
+LayerFilterParams, OperationsToLayers, get_current_tpc)
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities, OperatorsSet, \
+OperatorSetGroup, Signedness, AttributeQuantizationConfig, OpQuantizationConfig, QuantizationConfigOptions, Fusing
+
+from mct_quantizers import QuantizationMethod
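This hunk adds public re-exports to `model_compression_toolkit/target_platform_capabilities/__init__.py` (listed above as +9 lines), so both the schema classes and the framework-attachment helpers can be imported from one package. A minimal sketch using only names visible in the added lines; whether these are the recommended import paths going forward is an assumption:

```python
# Assumes only that the re-exports added above resolve as written.
from model_compression_toolkit.target_platform_capabilities import (
    TargetPlatformCapabilities,          # schema-level TPC
    OperatorsSet, OperatorSetGroup, Fusing,
    OpQuantizationConfig, QuantizationConfigOptions,
    FrameworkQuantizationCapabilities,   # framework-attached capabilities
    LayerFilterParams, Eq, GreaterEq,    # layer/attribute filters
    QuantizationMethod,                  # re-exported from mct_quantizers
)
```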
@@ -7,6 +7,6 @@ OpQuantizationConfig = schema.OpQuantizationConfig
 QuantizationConfigOptions = schema.QuantizationConfigOptions
 OperatorsSetBase = schema.OperatorsSetBase
 OperatorsSet = schema.OperatorsSet
-
+OperatorSetGroup = schema.OperatorSetGroup
 Fusing = schema.Fusing
-
+TargetPlatformCapabilities = schema.TargetPlatformCapabilities
@@ -16,7 +16,7 @@ from logging import Logger
 from typing import Optional
 
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
-
+TargetPlatformCapabilities, QuantizationConfigOptions, OperatorsSetBase
 
 
 def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) -> int:
@@ -32,31 +32,31 @@ def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) ->
 return max(op_quantization_config.supported_input_activation_n_bits)
 
 
-def get_config_options_by_operators_set(
+def get_config_options_by_operators_set(tpc: TargetPlatformCapabilities,
 operators_set_name: str) -> QuantizationConfigOptions:
 """
 Get the QuantizationConfigOptions of an OperatorsSet by its name.
 
 Args:
-
+tpc (TargetPlatformCapabilities): The target platform model containing the operator sets and their configurations.
 operators_set_name (str): The name of the OperatorsSet whose quantization configuration options are to be retrieved.
 
 Returns:
 QuantizationConfigOptions: The quantization configuration options associated with the specified OperatorsSet,
 or the default quantization configuration options if the OperatorsSet is not found.
 """
-for op_set in
+for op_set in tpc.operator_set:
 if operators_set_name == op_set.name:
 return op_set.qc_options
-return
+return tpc.default_qco
 
 
-def get_default_op_quantization_config(
+def get_default_op_quantization_config(tpc: TargetPlatformCapabilities) -> OpQuantizationConfig:
 """
-Get the default OpQuantizationConfig of the
+Get the default OpQuantizationConfig of the TargetPlatformCapabilities.
 
 Args:
-
+tpc (TargetPlatformCapabilities): The target platform model containing the default quantization configuration.
 
 Returns:
 OpQuantizationConfig: The default quantization configuration.
@@ -64,32 +64,32 @@ def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuant
 Raises:
 AssertionError: If the default quantization configuration list contains more than one configuration option.
 """
-assert len(
+assert len(tpc.default_qco.quantization_configurations) == 1, \
 f"Default quantization configuration options must contain only one option, " \
-f"but found {len(
-return
+f"but found {len(tpc.default_qco.quantization_configurations)} configurations." # pragma: no cover
+return tpc.default_qco.quantization_configurations[0]
 
 
-def is_opset_in_model(
+def is_opset_in_model(tpc: TargetPlatformCapabilities, opset_name: str) -> bool:
 """
 Check whether an OperatorsSet is defined in the model.
 
 Args:
-
+tpc (TargetPlatformCapabilities): The target platform model containing the list of operator sets.
 opset_name (str): The name of the OperatorsSet to check for existence.
 
 Returns:
 bool: True if an OperatorsSet with the given name exists in the target platform model,
 otherwise False.
 """
-return
+return tpc.operator_set is not None and opset_name in [x.name for x in tpc.operator_set]
 
-def get_opset_by_name(
+def get_opset_by_name(tpc: TargetPlatformCapabilities, opset_name: str) -> Optional[OperatorsSetBase]:
 """
 Get an OperatorsSet object from the model by its name.
 
 Args:
-
+tpc (TargetPlatformCapabilities): The target platform model containing the list of operator sets.
 opset_name (str): The name of the OperatorsSet to be retrieved.
 
 Returns:
@@ -99,7 +99,7 @@ def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optiona
 Raises:
 A critical log message if multiple operator sets with the same name are found.
 """
-opset_list = [x for x in
+opset_list = [x for x in tpc.operator_set if x.name == opset_name]
 if len(opset_list) > 1:
-Logger.critical(f"Found more than one OperatorsSet in
+Logger.critical(f"Found more than one OperatorsSet in TargetPlatformCapabilities with the name {opset_name}.") # pragma: no cover
 return opset_list[0] if opset_list else None
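After this change the helpers in `schema_functions.py` take the schema-level `TargetPlatformCapabilities` directly (the old hunk headers still show the previous `tp_model: TargetPlatformModel` signatures). A hedged usage sketch; the string arguments to `get_target_platform_capabilities` and the operator-set name "Conv" are illustrative assumptions:

```python
import model_compression_toolkit as mct
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import (
    get_config_options_by_operators_set,
    get_default_op_quantization_config,
    is_opset_in_model,
)

# A schema TargetPlatformCapabilities object (argument strings are assumptions).
tpc = mct.get_target_platform_capabilities("pytorch", "default")

default_cfg = get_default_op_quantization_config(tpc)  # tpc.default_qco.quantization_configurations[0]
if is_opset_in_model(tpc, "Conv"):
    conv_options = get_config_options_by_operators_set(tpc, "Conv")
```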
@@ -414,7 +414,7 @@ class QuantizationConfigOptions(BaseModel):
 
 class TargetPlatformModelComponent(BaseModel):
 """
-Component of
+Component of TargetPlatformCapabilities (Fusing, OperatorsSet, etc.).
 """
 class Config:
 frozen = True
@@ -433,7 +433,7 @@ class OperatorsSet(OperatorsSetBase):
 Set of operators that are represented by a unique label.
 
 Attributes:
-name (Union[str, OperatorSetNames]): The set's label (must be unique within a
+name (Union[str, OperatorSetNames]): The set's label (must be unique within a TargetPlatformCapabilities).
 qc_options (Optional[QuantizationConfigOptions]): Configuration options to use for this set of operations.
 If None, it represents a fusing set.
 type (Literal["OperatorsSet"]): Fixed type identifier.
@@ -457,7 +457,7 @@ class OperatorsSet(OperatorsSetBase):
 return {"name": self.name}
 
 
-class
+class OperatorSetGroup(OperatorsSetBase):
 """
 Concatenate a tuple of operator sets to treat them similarly in different places (like fusing).
 
@@ -469,7 +469,7 @@ class OperatorSetConcat(OperatorsSetBase):
 name: Optional[str] = None # Will be set in the validator if not given
 
 # Define a private attribute _type
-type: Literal["
+type: Literal["OperatorSetGroup"] = "OperatorSetGroup"
 
 class Config:
 frozen = True
@@ -518,11 +518,11 @@ class Fusing(TargetPlatformModelComponent):
 hence no quantization is applied between them.
 
 Attributes:
-operator_groups (Tuple[Union[OperatorsSet,
-each being either an
+operator_groups (Tuple[Union[OperatorsSet, OperatorSetGroup], ...]): A tuple of operator groups,
+each being either an OperatorSetGroup or an OperatorsSet.
 name (Optional[str]): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
 """
-operator_groups: Tuple[Annotated[Union[OperatorsSet,
+operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetGroup], Field(discriminator='type')], ...]
 name: Optional[str] = None # Will be set in the validator if not given.
 
 class Config:
@@ -591,7 +591,7 @@ class Fusing(TargetPlatformModelComponent):
 for i in range(len(self.operator_groups) - len(other.operator_groups) + 1):
 for j in range(len(other.operator_groups)):
 if self.operator_groups[i + j] != other.operator_groups[j] and not (
-isinstance(self.operator_groups[i + j],
+isinstance(self.operator_groups[i + j], OperatorSetGroup) and (
 other.operator_groups[j] in self.operator_groups[i + j].operators_set)):
 break
 else:
@@ -621,7 +621,7 @@ class Fusing(TargetPlatformModelComponent):
 for x in self.operator_groups
 ])
 
-class
+class TargetPlatformCapabilities(BaseModel):
 """
 Represents the hardware configuration used for quantized model inference.
 
@@ -644,7 +644,7 @@ class TargetPlatformModel(BaseModel):
 tpc_patch_version: Optional[int]
 tpc_platform_type: Optional[str]
 add_metadata: bool = True
-name: Optional[str] = "
+name: Optional[str] = "default_tpc"
 is_simd_padding: bool = False
 
 SCHEMA_VERSION: int = 1
@@ -682,10 +682,10 @@ class TargetPlatformModel(BaseModel):
 
 def get_info(self) -> Dict[str, Any]:
 """
-Get a dictionary summarizing the
+Get a dictionary summarizing the TargetPlatformCapabilities properties.
 
 Returns:
-Dict[str, Any]: Summary of the
+Dict[str, Any]: Summary of the TargetPlatformCapabilities properties.
 """
 return {
 "Model name": self.name,
@@ -695,6 +695,6 @@ class TargetPlatformModel(BaseModel):
 
 def show(self):
 """
-Display the
+Display the TargetPlatformCapabilities.
 """
 pprint.pprint(self.get_info(), sort_dicts=False)
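The renamed schema classes still compose the same way: `Fusing.operator_groups` accepts a mix of `OperatorsSet` and `OperatorSetGroup` members, discriminated by their `type` field. A minimal sketch assuming the field names shown in the hunks above (`name`, `operators_set`, `operator_groups`); required fields and validation details of the real pydantic models may differ:

```python
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import (
    OperatorsSet, OperatorSetGroup, Fusing)

conv = OperatorsSet(name="Conv")
relu = OperatorsSet(name="ReLU")
swish = OperatorsSet(name="Swish")

# OperatorSetGroup (formerly OperatorSetConcat) bundles sets that are treated alike, e.g. in fusing.
any_act = OperatorSetGroup(operators_set=(relu, swish))

# A Conv followed by any of the grouped activations forms a fusing pattern,
# so no quantization is applied between them.
conv_act_fusing = Fusing(operator_groups=(conv, any_act))
```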
@@ -13,13 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 
-from model_compression_toolkit.target_platform_capabilities.
-from model_compression_toolkit.target_platform_capabilities.
-from model_compression_toolkit.target_platform_capabilities.
-Eq, GreaterEq, NotEq, SmallerEq, Greater, Smaller
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.layer_filter_params import \
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import get_current_tpc
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import FrameworkQuantizationCapabilities
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.layer_filter_params import \
 LayerFilterParams
-from model_compression_toolkit.target_platform_capabilities.
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import \
+Eq, GreaterEq, NotEq, SmallerEq, Greater, Smaller
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.operations_to_layers import \
 OperationsToLayers, OperationsSetToLayers
 
 
@@ -1,12 +1,12 @@
 from typing import Dict, Optional
 
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities, \
 OperatorsSet
-from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, \
-OperationsSetToLayers
 
 from model_compression_toolkit.core.common.quantization.quantization_config import CustomOpsetLayers
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import \
+FrameworkQuantizationCapabilities, OperationsSetToLayers
 
 
 class AttachTpcToFramework:
@@ -19,25 +19,25 @@ class AttachTpcToFramework:
 # in the operation set are provided in the mapping, a DefaultDict should be supplied to handle missing entries.
 self._opset2attr_mapping = None # Mapping of operation sets to their corresponding framework-specific layers
 
-def attach(self, tpc_model:
+def attach(self, tpc_model: TargetPlatformCapabilities,
 custom_opset2layer: Optional[Dict[str, 'CustomOpsetLayers']] = None
-) ->
+) -> FrameworkQuantizationCapabilities:
 """
-Attaching a
+Attaching a TargetPlatformCapabilities which includes a platform capabilities description to specific
 framework's operators.
 
 Args:
-tpc_model: a
+tpc_model: a TargetPlatformCapabilities object.
 custom_opset2layer: optional set of custom operator sets which allows to add/override the built-in set
 of framework operator, to define a specific behavior for those operators. This dictionary should map
 an operator set unique name to a pair of: a list of framework operators and an optional
 operator's attributes names mapping.
 
-Returns: a
+Returns: a FrameworkQuantizationCapabilities object.
 
 """
 
-tpc =
+tpc = FrameworkQuantizationCapabilities(tpc_model)
 custom_opset2layer = custom_opset2layer if custom_opset2layer is not None else {}
 
 with tpc:
@@ -59,7 +59,7 @@ class AttachTpcToFramework:
 attr_mapping = self._opset2attr_mapping.get(opset.name)
 OperationsSetToLayers(opset.name, layers, attr_mapping=attr_mapping)
 else:
-Logger.critical(f'{opset.name} is defined in
+Logger.critical(f'{opset.name} is defined in TargetPlatformCapabilities, '
 f'but is not defined in the framework set of operators or in the provided '
 f'custom operator sets mapping.')
 
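`AttachTpcToFramework.attach` now accepts the schema-level `TargetPlatformCapabilities` and returns a `FrameworkQuantizationCapabilities`, which is what the facades pass to `core_runner` as `fqc`. A hedged sketch of that internal call path, mirroring the facade hunks earlier in this diff (an illustration, not a public recipe):

```python
import model_compression_toolkit as mct
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import (
    AttachTpcToPytorch)

# String arguments are assumptions; the facades build the equivalent DEFAULT_PYTORCH_TPC internally.
tpc = mct.get_target_platform_capabilities("pytorch", "default")

# Returns a FrameworkQuantizationCapabilities bound to PyTorch layers.
# The facades additionally forward core_config.quantization_config.custom_tpc_opset_to_layer
# as the optional custom_opset2layer argument.
fqc = AttachTpcToPytorch().attach(tpc)
```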
@@ -16,6 +16,9 @@
 import tensorflow as tf
 from packaging import version
 
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
+AttachTpcToFramework
 from model_compression_toolkit.verify_packages import FOUND_SONY_CUSTOM_LAYERS
 
 if FOUND_SONY_CUSTOM_LAYERS:
@@ -34,9 +37,6 @@ from model_compression_toolkit import DefaultDict
 from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS, \
 BIAS_ATTR, KERAS_KERNEL, KERAS_DEPTHWISE_KERNEL
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
-from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
-AttachTpcToFramework
 
 
 class AttachTpcToKeras(AttachTpcToFramework):
@@ -28,9 +28,10 @@ from model_compression_toolkit import DefaultDict
 from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, PYTORCH_KERNEL, BIAS, \
 BIAS_ATTR
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
-from model_compression_toolkit.target_platform_capabilities.
-from model_compression_toolkit.target_platform_capabilities.
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
 AttachTpcToFramework
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import Eq
 
 
 class AttachTpcToPytorch(AttachTpcToFramework):
|
@@ -18,7 +18,7 @@ from model_compression_toolkit.logger import Logger
|
|
18
18
|
def get_current_tpc():
|
19
19
|
"""
|
20
20
|
|
21
|
-
Returns: The current
|
21
|
+
Returns: The current FrameworkQuantizationCapabilities that is being used and accessed.
|
22
22
|
|
23
23
|
"""
|
24
24
|
return _current_tpc.get()
|
@@ -26,7 +26,7 @@ def get_current_tpc():
|
|
26
26
|
|
27
27
|
class _CurrentTPC(object):
|
28
28
|
"""
|
29
|
-
Wrapper of the current
|
29
|
+
Wrapper of the current FrameworkQuantizationCapabilities object that is being accessed and defined.
|
30
30
|
"""
|
31
31
|
def __init__(self):
|
32
32
|
super(_CurrentTPC, self).__init__()
|
@@ -35,28 +35,28 @@ class _CurrentTPC(object):
|
|
35
35
|
def get(self):
|
36
36
|
"""
|
37
37
|
|
38
|
-
Returns: The current
|
38
|
+
Returns: The current FrameworkQuantizationCapabilities that is being defined.
|
39
39
|
|
40
40
|
"""
|
41
41
|
if self.tpc is None:
|
42
|
-
Logger.critical("'
|
42
|
+
Logger.critical("'FrameworkQuantizationCapabilities' (TPC) instance is not initialized.")
|
43
43
|
return self.tpc
|
44
44
|
|
45
45
|
def reset(self):
|
46
46
|
"""
|
47
47
|
|
48
|
-
Reset the current
|
49
|
-
used as the current
|
48
|
+
Reset the current FrameworkQuantizationCapabilities so a new FrameworkQuantizationCapabilities can be wrapped and
|
49
|
+
used as the current FrameworkQuantizationCapabilities object.
|
50
50
|
|
51
51
|
"""
|
52
52
|
self.tpc = None
|
53
53
|
|
54
54
|
def set(self, target_platform_capabilities):
|
55
55
|
"""
|
56
|
-
Set and wrap a
|
56
|
+
Set and wrap a FrameworkQuantizationCapabilities as the current FrameworkQuantizationCapabilities.
|
57
57
|
|
58
58
|
Args:
|
59
|
-
target_platform_capabilities:
|
59
|
+
target_platform_capabilities: FrameworkQuantizationCapabilities to set as the current FrameworkQuantizationCapabilities
|
60
60
|
to access and use.
|
61
61
|
|
62
62
|
"""
|