mct-nightly 2.2.0.20250106.546__py3-none-any.whl → 2.2.0.20250107.15510__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/RECORD +43 -78
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/cut.py +5 -2
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +25 -25
- model_compression_toolkit/core/common/quantization/quantization_config.py +19 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -33
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py +2 -2
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +11 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/matmul_decomposition.py +499 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +3 -0
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +11 -3
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +10 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +8 -2
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -2
- model_compression_toolkit/ptq/keras/quantization_facade.py +10 -1
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +9 -1
- model_compression_toolkit/qat/__init__.py +5 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +9 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -1
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +63 -55
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py +29 -18
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py +78 -57
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py +69 -54
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +0 -10
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +93 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +46 -28
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +6 -5
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +51 -19
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +8 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +19 -9
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +7 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +46 -32
- model_compression_toolkit/xquant/keras/keras_report_utils.py +11 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +0 -98
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_keras.py +0 -129
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_pytorch.py +0 -108
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +0 -217
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_keras.py +0 -130
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_pytorch.py +0 -109
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +0 -215
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_keras.py +0 -130
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_pytorch.py +0 -110
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +0 -222
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +0 -132
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py +0 -110
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +0 -219
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +0 -132
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py +0 -109
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +0 -246
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +0 -135
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +0 -113
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +0 -230
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +0 -132
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +0 -110
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +0 -332
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +0 -140
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +0 -122
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +0 -55
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_keras.py +0 -89
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +0 -78
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +0 -55
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_keras.py +0 -118
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_pytorch.py +0 -100
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/top_level.txt +0 -0
model_compression_toolkit/pruning/keras/pruning_facade.py

@@ -17,12 +17,12 @@ from typing import Callable, Tuple
 
 from model_compression_toolkit import get_target_platform_capabilities
 from model_compression_toolkit.constants import TENSORFLOW
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.pruning.pruner import Pruner
 from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
 from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
-from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
 from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
 from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
 from model_compression_toolkit.logger import Logger
@@ -35,6 +35,8 @@ if FOUND_TF:
     from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
     from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from tensorflow.keras.models import Model
+    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
+        AttachTpcToKeras
 
     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
@@ -42,7 +44,7 @@ if FOUND_TF:
                                    target_resource_utilization: ResourceUtilization,
                                    representative_data_gen: Callable,
                                    pruning_config: PruningConfig = PruningConfig(),
-                                   target_platform_capabilities:
+                                   target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
         """
         Perform structured pruning on a Keras model to meet a specified target resource utilization.
         This function prunes the provided model according to the target resource utilization by grouping and pruning
@@ -111,6 +113,10 @@ if FOUND_TF:
         # Instantiate the Keras framework implementation.
         fw_impl = PruningKerasImplementation()
 
+        # Attach tpc model to framework
+        attach2keras = AttachTpcToKeras()
+        target_platform_capabilities = attach2keras.attach(target_platform_capabilities)
+
         # Convert the original Keras model to an internal graph representation.
         float_graph = read_model_to_graph(model,
                                           representative_data_gen,
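Every facade touched in this release follows the pattern above: the public entry point now accepts a framework-agnostic `TargetPlatformModel` (the default TPC) and attaches it to the framework internally via the new `AttachTpcToKeras`/`AttachTpcToPytorch` helpers. A minimal user-side sketch of the Keras pruning call under the new signature; the model, data generator, and memory budget below are placeholders, not part of the diff:

```python
# Sketch only: the user-facing call is unchanged, while DEFAULT_KERAS_TPC is now a
# TargetPlatformModel that keras_pruning_experimental attaches to Keras internally.
import numpy as np
import model_compression_toolkit as mct
from tensorflow.keras.applications import MobileNetV2  # placeholder model

def representative_data_gen():
    # Placeholder calibration data; replace with real samples.
    for _ in range(2):
        yield [np.random.randn(1, 224, 224, 3).astype(np.float32)]

target_ru = mct.core.ResourceUtilization(weights_memory=2_000_000)  # hypothetical budget (bytes)
pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(
    model=MobileNetV2(),
    target_resource_utilization=target_ru,
    representative_data_gen=representative_data_gen)
```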
model_compression_toolkit/pruning/pytorch/pruning_facade.py

@@ -16,12 +16,12 @@
 from typing import Callable, Tuple
 from model_compression_toolkit import get_target_platform_capabilities
 from model_compression_toolkit.constants import PYTORCH
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.verify_packages import FOUND_TORCH
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.pruning.pruner import Pruner
 from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
 from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
-from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
 from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
 from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
 from model_compression_toolkit.logger import Logger
@@ -38,6 +38,8 @@ if FOUND_TORCH:
         PruningPytorchImplementation
     from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
     from torch.nn import Module
+    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+        AttachTpcToPytorch
 
     # Set the default Target Platform Capabilities (TPC) for PyTorch.
     DEFAULT_PYOTRCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
@@ -46,7 +48,7 @@ if FOUND_TORCH:
                                      target_resource_utilization: ResourceUtilization,
                                      representative_data_gen: Callable,
                                      pruning_config: PruningConfig = PruningConfig(),
-                                     target_platform_capabilities:
+                                     target_platform_capabilities: TargetPlatformModel = DEFAULT_PYOTRCH_TPC) -> \
             Tuple[Module, PruningInfo]:
         """
         Perform structured pruning on a Pytorch model to meet a specified target resource utilization.
@@ -117,6 +119,10 @@ if FOUND_TORCH:
         # Instantiate the Pytorch framework implementation.
         fw_impl = PruningPytorchImplementation()
 
+        # Attach TPC to framework
+        attach2pytorch = AttachTpcToPytorch()
+        target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities)
+
         # Convert the original Pytorch model to an internal graph representation.
         float_graph = read_model_to_graph(model,
                                           representative_data_gen,
model_compression_toolkit/ptq/keras/quantization_facade.py

@@ -22,6 +22,7 @@ from model_compression_toolkit.core.common.quantization.quantize_graph_weights i
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
@@ -41,6 +42,9 @@ if FOUND_TF:
 
     from model_compression_toolkit import get_target_platform_capabilities
     from mct_quantizers.keras.metadata import add_metadata
+    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
+        AttachTpcToKeras
+
     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
 
@@ -48,7 +52,7 @@ if FOUND_TF:
                                         representative_data_gen: Callable,
                                         target_resource_utilization: ResourceUtilization = None,
                                         core_config: CoreConfig = CoreConfig(),
-                                        target_platform_capabilities:
+                                        target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC):
         """
         Quantize a trained Keras model using post-training quantization. The model is quantized using a
         symmetric constraint quantization thresholds (power of two).
@@ -134,6 +138,11 @@ if FOUND_TF:
 
         fw_impl = KerasImplementation()
 
+        attach2keras = AttachTpcToKeras()
+        target_platform_capabilities = attach2keras.attach(
+            target_platform_capabilities,
+            custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
+
         # Ignore returned hessian service as PTQ does not use it
         tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_model,
                                                                 representative_data_gen=representative_data_gen,
model_compression_toolkit/ptq/pytorch/quantization_facade.py

@@ -19,6 +19,7 @@ from typing import Callable
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import PYTORCH
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.verify_packages import FOUND_TORCH
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
@@ -39,6 +40,8 @@ if FOUND_TORCH:
     from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
     from model_compression_toolkit import get_target_platform_capabilities
     from mct_quantizers.pytorch.metadata import add_metadata
+    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+        AttachTpcToPytorch
 
     DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
 
@@ -46,7 +49,7 @@ if FOUND_TORCH:
                                           representative_data_gen: Callable,
                                           target_resource_utilization: ResourceUtilization = None,
                                           core_config: CoreConfig = CoreConfig(),
-                                          target_platform_capabilities:
+                                          target_platform_capabilities: TargetPlatformModel = DEFAULT_PYTORCH_TPC):
         """
         Quantize a trained Pytorch module using post-training quantization.
         By default, the module is quantized using a symmetric constraint quantization thresholds
@@ -107,6 +110,11 @@ if FOUND_TORCH:
 
         fw_impl = PytorchImplementation()
 
+        # Attach tpc model to framework
+        attach2pytorch = AttachTpcToPytorch()
+        target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
+                                                             core_config.quantization_config.custom_tpc_opset_to_layer)
+
         # Ignore hessian info service as it is not used here yet.
         tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_module,
                                                                 representative_data_gen=representative_data_gen,
model_compression_toolkit/qat/__init__.py

@@ -13,6 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 from model_compression_toolkit.qat.common.qat_config import QATConfig
+from model_compression_toolkit.verify_packages import FOUND_TF, FOUND_TORCH
 
-
-from model_compression_toolkit.qat.
+if FOUND_TF:
+    from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init_experimental, keras_quantization_aware_training_finalize_experimental
+if FOUND_TORCH:
+    from model_compression_toolkit.qat.pytorch.quantization_facade import pytorch_quantization_aware_training_init_experimental, pytorch_quantization_aware_training_finalize_experimental
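With this change, the QAT entry points are re-exported from `model_compression_toolkit.qat` only when the corresponding framework is installed. An illustrative check (not part of the package), using the names imported above:

```python
# Illustrative only: which QAT facades are exposed depends on FOUND_TF / FOUND_TORCH.
import model_compression_toolkit as mct

has_keras_qat = hasattr(mct.qat, "keras_quantization_aware_training_init_experimental")
has_torch_qat = hasattr(mct.qat, "pytorch_quantization_aware_training_init_experimental")
print(f"Keras QAT available: {has_keras_qat}, PyTorch QAT available: {has_torch_qat}")
```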
model_compression_toolkit/qat/keras/quantization_facade.py

@@ -19,6 +19,7 @@ from functools import partial
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
@@ -54,6 +55,8 @@ if FOUND_TF:
     from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
         get_activation_quantizer_holder
     from model_compression_toolkit.qat.common.qat_config import QATConfig
+    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
+        AttachTpcToKeras
 
     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
@@ -90,7 +93,7 @@ if FOUND_TF:
                                                             target_resource_utilization: ResourceUtilization = None,
                                                             core_config: CoreConfig = CoreConfig(),
                                                             qat_config: QATConfig = QATConfig(),
-                                                            target_platform_capabilities:
+                                                            target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC):
         """
         Prepare a trained Keras model for quantization aware training. First the model quantization is optimized
         with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -186,6 +189,11 @@ if FOUND_TF:
 
         fw_impl = KerasImplementation()
 
+        attach2keras = AttachTpcToKeras()
+        target_platform_capabilities = attach2keras.attach(
+            target_platform_capabilities,
+            custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
+
         # Ignore hessian service since is not used in QAT at the moment
         tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
                                                   representative_data_gen=representative_data_gen,
model_compression_toolkit/qat/pytorch/quantization_facade.py

@@ -17,6 +17,9 @@ from typing import Callable
 from functools import partial
 
 from model_compression_toolkit.constants import PYTORCH
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+    AttachTpcToPytorch
 from model_compression_toolkit.verify_packages import FOUND_TORCH
 
 from model_compression_toolkit.core import CoreConfig
@@ -79,7 +82,7 @@ if FOUND_TORCH:
                                                               target_resource_utilization: ResourceUtilization = None,
                                                               core_config: CoreConfig = CoreConfig(),
                                                               qat_config: QATConfig = QATConfig(),
-                                                              target_platform_capabilities:
+                                                              target_platform_capabilities: TargetPlatformModel = DEFAULT_PYTORCH_TPC):
         """
         Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized
         with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -154,6 +157,11 @@ if FOUND_TORCH:
         tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
         fw_impl = PytorchImplementation()
 
+        # Attach tpc model to framework
+        attach2pytorch = AttachTpcToPytorch()
+        target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
+                                                             core_config.quantization_config.custom_tpc_opset_to_layer)
+
         # Ignore hessian scores service as we do not use it here
         tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
                                                   representative_data_gen=representative_data_gen,
model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py

@@ -7,6 +7,6 @@ OpQuantizationConfig = schema.OpQuantizationConfig
 QuantizationConfigOptions = schema.QuantizationConfigOptions
 OperatorsSetBase = schema.OperatorsSetBase
 OperatorsSet = schema.OperatorsSet
-OperatorSetConcat= schema.OperatorSetConcat
+OperatorSetConcat = schema.OperatorSetConcat
 Fusing = schema.Fusing
 TargetPlatformModel = schema.TargetPlatformModel
model_compression_toolkit/target_platform_capabilities/schema/v1.py

@@ -13,66 +13,74 @@
 # limitations under the License.
 # ==============================================================================
 import pprint
-
 from enum import Enum
 from typing import Dict, Any, Union, Tuple, List, Optional, Literal, Annotated
+
+from pydantic import BaseModel, Field, root_validator, validator, PositiveInt
+
 from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.constants import FLOAT_BITWIDTH
 from model_compression_toolkit.logger import Logger
-… (old lines 22-75 removed; their contents are not shown in this diff view)
+
+
+class OperatorSetNames(str, Enum):
+    CONV = "Conv"
+    DEPTHWISE_CONV = "DepthwiseConv2D"
+    CONV_TRANSPOSE = "ConvTranspose"
+    FULLY_CONNECTED = "FullyConnected"
+    CONCATENATE = "Concatenate"
+    STACK = "Stack"
+    UNSTACK = "Unstack"
+    GATHER = "Gather"
+    EXPAND = "Expend"
+    BATCH_NORM = "BatchNorm"
+    L2NORM = "L2Norm"
+    RELU = "ReLU"
+    RELU6 = "ReLU6"
+    LEAKY_RELU = "LeakyReLU"
+    ELU = "Elu"
+    HARD_TANH = "HardTanh"
+    ADD = "Add"
+    SUB = "Sub"
+    MUL = "Mul"
+    DIV = "Div"
+    MIN = "Min"
+    MAX = "Max"
+    PRELU = "PReLU"
+    ADD_BIAS = "AddBias"
+    SWISH = "Swish"
+    SIGMOID = "Sigmoid"
+    SOFTMAX = "Softmax"
+    LOG_SOFTMAX = "LogSoftmax"
+    TANH = "Tanh"
+    GELU = "Gelu"
+    HARDSIGMOID = "HardSigmoid"
+    HARDSWISH = "HardSwish"
+    FLATTEN = "Flatten"
+    GET_ITEM = "GetItem"
+    RESHAPE = "Reshape"
+    UNSQUEEZE = "Unsqueeze"
+    SQUEEZE = "Squeeze"
+    PERMUTE = "Permute"
+    TRANSPOSE = "Transpose"
+    DROPOUT = "Dropout"
+    SPLIT_CHUNK = "SplitChunk"
+    MAXPOOL = "MaxPool"
+    AVGPOOL = "AvgPool"
+    SIZE = "Size"
+    SHAPE = "Shape"
+    EQUAL = "Equal"
+    ARGMAX = "ArgMax"
+    TOPK = "TopK"
+    FAKE_QUANT = "FakeQuant"
+    COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
+    ZERO_PADDING2D = "ZeroPadding2D"
+    CAST = "Cast"
+    RESIZE = "Resize"
+    PAD = "Pad"
+    FOLD = "Fold"
+    STRIDED_SLICE = "StridedSlice"
+    SSD_POST_PROCESS = "SSDPostProcess"
 
     @classmethod
     def get_values(cls):
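Because `OperatorSetNames` subclasses both `str` and `Enum`, its members compare equal to their string values, so a TPC can refer to operator sets by plain name. A small illustration (trimmed re-declaration for demonstration, not part of the diff):

```python
from enum import Enum

# Trimmed stand-in for the enum above (the real one defines ~57 members).
class OperatorSetNames(str, Enum):
    CONV = "Conv"
    FULLY_CONNECTED = "FullyConnected"

assert OperatorSetNames.CONV == "Conv"  # str-valued members equal their plain strings
assert OperatorSetNames("FullyConnected") is OperatorSetNames.FULLY_CONNECTED
print([m.value for m in OperatorSetNames])  # ['Conv', 'FullyConnected']
```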
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py

@@ -1,12 +1,15 @@
-from typing import Dict,
+from typing import Dict, Optional
 
-from model_compression_toolkit import
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, \
+    OperatorsSet
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, \
     OperationsSetToLayers
 
+from model_compression_toolkit.core.common.quantization.quantization_config import CustomOpsetLayers
 
-… (old line 9 removed; content not shown in this view)
+
+class AttachTpcToFramework:
 
     def __init__(self):
         self._opset2layer = None
@@ -17,7 +20,7 @@ class AttachTpModelToFw:
         self._opset2attr_mapping = None  # Mapping of operation sets to their corresponding framework-specific layers
 
     def attach(self, tpc_model: TargetPlatformModel,
-               custom_opset2layer:
+               custom_opset2layer: Optional[Dict[str, 'CustomOpsetLayers']] = None
                ) -> TargetPlatformCapabilities:
         """
         Attaching a TargetPlatformModel which includes a platform capabilities description to specific
@@ -35,22 +38,30 @@ class AttachTpModelToFw:
         """
 
         tpc = TargetPlatformCapabilities(tpc_model)
+        custom_opset2layer = custom_opset2layer if custom_opset2layer is not None else {}
 
         with tpc:
-            for
-… (old lines 41-47 removed; contents not shown in this view)
-            elif
-… (old line 49 removed; content not shown in this view)
+            for opset in tpc_model.operator_set:
+                if isinstance(opset, OperatorsSet):  # filter out OperatorsSetConcat
+                    if opset.name in custom_opset2layer:
+                        custom_opset_layers = custom_opset2layer[opset.name]
+                        OperationsSetToLayers(opset.name,
+                                              layers=custom_opset_layers.operators,
+                                              attr_mapping=custom_opset_layers.attr_mapping)
+
+                    elif opset.name in self._opset2layer:
+                        # Note that if the user provided a custom operator set with a name that exists in our
+                        # pre-defined set of operator sets, we prioritize the user's custom opset definition
+                        layers = self._opset2layer[opset.name]
+                        if len(layers) > 0:
+                            # If the framework does not define any matching operators to a given operator set name that
+                            # appears in the TPC, then we just skip it
+                            attr_mapping = self._opset2attr_mapping.get(opset.name)
+                            OperationsSetToLayers(opset.name, layers, attr_mapping=attr_mapping)
                     else:
-… (old lines 51-53 removed; contents not shown in this view)
+                        Logger.critical(f'{opset.name} is defined in TargetPlatformModel, '
+                                        f'but is not defined in the framework set of operators or in the provided '
+                                        f'custom operator sets mapping.')
 
         return tpc
 
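The rewritten `attach` walks the model's operator sets and prefers a user-supplied `custom_opset2layer` entry over the framework's built-in mapping. A hedged sketch of how a caller could reach this path through the quantization config; the `CustomOpsetLayers` keyword arguments are assumed from the `.operators`/`.attr_mapping` attributes used above and may differ from the actual constructor:

```python
# Sketch only: map an operator-set name from the TPC to concrete Keras callables.
import tensorflow as tf
import model_compression_toolkit as mct
from model_compression_toolkit.core.common.quantization.quantization_config import CustomOpsetLayers

core_config = mct.core.CoreConfig()
# custom_tpc_opset_to_layer is the QuantizationConfig field the facades forward to attach()
# as custom_opset2layer (see the PTQ/QAT hunks above); the kwargs here are assumptions.
core_config.quantization_config.custom_tpc_opset_to_layer = {
    "Swish": CustomOpsetLayers(operators=[tf.nn.silu]),
}
```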
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py

@@ -23,12 +23,12 @@ if FOUND_SONY_CUSTOM_LAYERS:
 
 if version.parse(tf.__version__) >= version.parse("2.13"):
     from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
-        MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
-        Conv2DTranspose,
+        MaxPooling2D, AveragePooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
+        Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum, Softmax
 else:
     from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
-        MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
-        Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum
+        MaxPooling2D, AveragePooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
+        Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum, Softmax
 
 from model_compression_toolkit import DefaultDict
 from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS, \
@@ -36,72 +36,93 @@ from model_compression_toolkit.target_platform_capabilities.constants import KER
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
 from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
-… (old line 39 removed; content not shown in this view)
+    AttachTpcToFramework
 
 
-class
+class AttachTpcToKeras(AttachTpcToFramework):
     def __init__(self):
         super().__init__()
 
         self._opset2layer = {
-… (old lines 47-91, the previous mapping entries, removed; contents not shown in this view)
+            OperatorSetNames.CONV: [Conv2D, tf.nn.conv2d],
+            OperatorSetNames.DEPTHWISE_CONV: [DepthwiseConv2D, tf.nn.depthwise_conv2d],
+            OperatorSetNames.CONV_TRANSPOSE: [Conv2DTranspose, tf.nn.conv2d_transpose],
+            OperatorSetNames.FULLY_CONNECTED: [Dense],
+            OperatorSetNames.CONCATENATE: [tf.concat, Concatenate],
+            OperatorSetNames.STACK: [tf.stack],
+            OperatorSetNames.UNSTACK: [tf.unstack],
+            OperatorSetNames.GATHER: [tf.gather, tf.compat.v1.gather],
+            OperatorSetNames.EXPAND: [],
+            OperatorSetNames.BATCH_NORM: [BatchNormalization, tf.nn.batch_normalization],
+            OperatorSetNames.RELU: [tf.nn.relu, ReLU, LayerFilterParams(Activation, activation="relu")],
+            OperatorSetNames.RELU6: [tf.nn.relu6],
+            OperatorSetNames.LEAKY_RELU: [tf.nn.leaky_relu, LeakyReLU, LayerFilterParams(Activation, activation="leaky_relu")],
+            OperatorSetNames.HARD_TANH: [LayerFilterParams(Activation, activation="hard_tanh")],
+            OperatorSetNames.ADD: [tf.add, Add],
+            OperatorSetNames.SUB: [tf.subtract, Subtract],
+            OperatorSetNames.MUL: [tf.math.multiply, Multiply],
+            OperatorSetNames.DIV: [tf.math.divide, tf.math.truediv],
+            OperatorSetNames.MIN: [tf.math.minimum, Minimum],
+            OperatorSetNames.MAX: [tf.math.maximum, Maximum],
+            OperatorSetNames.PRELU: [PReLU],
+            OperatorSetNames.SWISH: [tf.nn.swish, LayerFilterParams(Activation, activation="swish")],
+            OperatorSetNames.HARDSWISH: [LayerFilterParams(Activation, activation="hard_swish")],
+            OperatorSetNames.SIGMOID: [tf.nn.sigmoid, LayerFilterParams(Activation, activation="sigmoid")],
+            OperatorSetNames.TANH: [tf.nn.tanh, LayerFilterParams(Activation, activation="tanh")],
+            OperatorSetNames.GELU: [tf.nn.gelu, LayerFilterParams(Activation, activation="gelu")],
+            OperatorSetNames.HARDSIGMOID: [tf.keras.activations.hard_sigmoid,
+                                           LayerFilterParams(Activation, activation="hard_sigmoid")],
+            OperatorSetNames.FLATTEN: [Flatten],
+            OperatorSetNames.GET_ITEM: [tf.__operators__.getitem],
+            OperatorSetNames.RESHAPE: [Reshape, tf.reshape],
+            OperatorSetNames.PERMUTE: [Permute],
+            OperatorSetNames.TRANSPOSE: [tf.transpose],
+            OperatorSetNames.UNSQUEEZE: [tf.expand_dims],
+            OperatorSetNames.SQUEEZE: [tf.squeeze],
+            OperatorSetNames.DROPOUT: [Dropout],
+            OperatorSetNames.SPLIT_CHUNK: [tf.split],
+            OperatorSetNames.MAXPOOL: [MaxPooling2D, tf.nn.avg_pool2d],
+            OperatorSetNames.AVGPOOL: [AveragePooling2D],
+            OperatorSetNames.SIZE: [tf.size],
+            OperatorSetNames.RESIZE: [tf.image.resize],
+            OperatorSetNames.PAD: [tf.pad, Cropping2D],
+            OperatorSetNames.FOLD: [tf.space_to_batch_nd],
+            OperatorSetNames.SHAPE: [tf.shape, tf.compat.v1.shape],
+            OperatorSetNames.EQUAL: [tf.math.equal],
+            OperatorSetNames.ARGMAX: [tf.math.argmax],
+            OperatorSetNames.TOPK: [tf.nn.top_k],
+            OperatorSetNames.FAKE_QUANT: [tf.quantization.fake_quant_with_min_max_vars],
+            OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [tf.image.combined_non_max_suppression],
+            OperatorSetNames.ZERO_PADDING2D: [ZeroPadding2D],
+            OperatorSetNames.CAST: [tf.cast],
+            OperatorSetNames.STRIDED_SLICE: [tf.strided_slice],
+            OperatorSetNames.ELU: [tf.nn.elu, LayerFilterParams(Activation, activation="elu")],
+            OperatorSetNames.SOFTMAX: [tf.nn.softmax, Softmax,
+                                       LayerFilterParams(Activation, activation="softmax")],
+            OperatorSetNames.LOG_SOFTMAX: [tf.nn.log_softmax],
+            OperatorSetNames.ADD_BIAS: [tf.nn.bias_add],
+            OperatorSetNames.L2NORM: [tf.math.l2_normalize],
         }
 
         if FOUND_SONY_CUSTOM_LAYERS:
-            self._opset2layer[OperatorSetNames.
+            self._opset2layer[OperatorSetNames.SSD_POST_PROCESS] = [SSDPostProcess]
+        else:
+            # If Custom layers is not installed then we don't want the user to fail, but just ignore custom layers
+            # in the initialized framework TPC
+            self._opset2layer[OperatorSetNames.SSD_POST_PROCESS] = []
 
-        self._opset2attr_mapping = {
-… (old lines 98-100 removed; contents not shown in this view)
+        self._opset2attr_mapping = {
+            OperatorSetNames.CONV: {
+                KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
+                BIAS_ATTR: DefaultDict(default_value=BIAS)},
+            OperatorSetNames.CONV_TRANSPOSE: {
+                KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
+                BIAS_ATTR: DefaultDict(default_value=BIAS)},
+            OperatorSetNames.DEPTHWISE_CONV: {
                 KERNEL_ATTR: DefaultDict({
                     DepthwiseConv2D: KERAS_DEPTHWISE_KERNEL,
                     tf.nn.depthwise_conv2d: KERAS_DEPTHWISE_KERNEL}, default_value=KERAS_KERNEL),
                 BIAS_ATTR: DefaultDict(default_value=BIAS)},
-            OperatorSetNames.
+            OperatorSetNames.FULLY_CONNECTED: {
                 KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
                 BIAS_ATTR: DefaultDict(default_value=BIAS)}}