mct-nightly 2.2.0.20250106.546__py3-none-any.whl → 2.2.0.20250107.15510__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/RECORD +43 -78
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/cut.py +5 -2
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +25 -25
- model_compression_toolkit/core/common/quantization/quantization_config.py +19 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -33
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py +2 -2
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +11 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/matmul_decomposition.py +499 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +3 -0
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +11 -3
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +10 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +8 -2
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -2
- model_compression_toolkit/ptq/keras/quantization_facade.py +10 -1
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +9 -1
- model_compression_toolkit/qat/__init__.py +5 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +9 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -1
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +63 -55
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py +29 -18
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py +78 -57
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py +69 -54
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +0 -10
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +93 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +46 -28
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +6 -5
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +51 -19
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +8 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +19 -9
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +7 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +46 -32
- model_compression_toolkit/xquant/keras/keras_report_utils.py +11 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +0 -98
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_keras.py +0 -129
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_pytorch.py +0 -108
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +0 -217
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_keras.py +0 -130
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_pytorch.py +0 -109
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +0 -215
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_keras.py +0 -130
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_pytorch.py +0 -110
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +0 -222
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +0 -132
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py +0 -110
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +0 -219
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +0 -132
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py +0 -109
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +0 -246
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +0 -135
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +0 -113
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +0 -230
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +0 -132
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +0 -110
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +0 -332
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +0 -140
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +0 -122
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +0 -55
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_keras.py +0 -89
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +0 -78
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +0 -55
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_keras.py +0 -118
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_pytorch.py +0 -100
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/top_level.txt +0 -0
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py
CHANGED
@@ -18,74 +18,89 @@ import operator
 import torch
 from torch import add, sub, mul, div, divide, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, \
     chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract, minimum, \
-    maximum
-from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d, BatchNorm2d
-
-from torch.nn import
+    maximum, softmax, fake_quantize_per_channel_affine
+from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d, BatchNorm2d, Dropout, Flatten, Hardtanh, ReLU, ReLU6, \
+    PReLU, SiLU, Sigmoid, Tanh, Hardswish, Hardsigmoid, LeakyReLU, GELU, LogSoftmax, Softmax, ELU, AvgPool2d, ZeroPad2d
+from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, hardsigmoid, leaky_relu, gelu, fold
 import torch.nn.functional as F
-from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, hardsigmoid, leaky_relu, gelu
 
 from model_compression_toolkit import DefaultDict
 from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, PYTORCH_KERNEL, BIAS, \
     BIAS_ATTR
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
-from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
+from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams, Eq
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
-
+    AttachTpcToFramework
 
 
-class
+class AttachTpcToPytorch(AttachTpcToFramework):
     def __init__(self):
         super().__init__()
 
         self._opset2layer = {
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-
-
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
+            OperatorSetNames.CONV: [Conv2d],
+            OperatorSetNames.DEPTHWISE_CONV: [],  # no specific operator for depthwise conv in pytorch
+            OperatorSetNames.CONV_TRANSPOSE: [ConvTranspose2d],
+            OperatorSetNames.FULLY_CONNECTED: [Linear],
+            OperatorSetNames.CONCATENATE: [torch.cat, torch.concat, torch.concatenate],
+            OperatorSetNames.STACK: [torch.stack],
+            OperatorSetNames.UNSTACK: [unbind],
+            OperatorSetNames.GATHER: [gather],
+            OperatorSetNames.EXPAND: [torch.Tensor.expand],
+            OperatorSetNames.BATCH_NORM: [BatchNorm2d],
+            OperatorSetNames.RELU: [torch.relu, ReLU, relu],
+            OperatorSetNames.RELU6: [ReLU6, relu6],
+            OperatorSetNames.LEAKY_RELU: [LeakyReLU, leaky_relu],
+            OperatorSetNames.HARD_TANH: [LayerFilterParams(Hardtanh, min_val=0),
+                                         LayerFilterParams(hardtanh, min_val=0)],
+            OperatorSetNames.ADD: [operator.add, add],
+            OperatorSetNames.SUB: [operator.sub, sub, subtract],
+            OperatorSetNames.MUL: [operator.mul, mul, multiply],
+            OperatorSetNames.DIV: [operator.truediv, div, divide],
+            OperatorSetNames.ADD_BIAS: [],  # no specific operator for bias_add in pytorch
+            OperatorSetNames.MIN: [minimum],
+            OperatorSetNames.MAX: [maximum],
+            OperatorSetNames.PRELU: [PReLU, prelu],
+            OperatorSetNames.SWISH: [SiLU, silu],
+            OperatorSetNames.SIGMOID: [Sigmoid, sigmoid, F.sigmoid],
+            OperatorSetNames.TANH: [Tanh, tanh, F.tanh],
+            OperatorSetNames.GELU: [GELU, gelu],
+            OperatorSetNames.HARDSIGMOID: [Hardsigmoid, hardsigmoid],
+            OperatorSetNames.HARDSWISH: [Hardswish, hardswish],
+            OperatorSetNames.FLATTEN: [Flatten, flatten],
+            OperatorSetNames.GET_ITEM: [operator.getitem],
+            OperatorSetNames.RESHAPE: [reshape],
+            OperatorSetNames.UNSQUEEZE: [unsqueeze],
+            OperatorSetNames.SQUEEZE: [squeeze],
+            OperatorSetNames.PERMUTE: [permute],
+            OperatorSetNames.TRANSPOSE: [transpose],
+            OperatorSetNames.DROPOUT: [Dropout, dropout],
+            OperatorSetNames.SPLIT_CHUNK: [split, chunk],
+            OperatorSetNames.MAXPOOL: [MaxPool2d, F.max_pool2d],
+            OperatorSetNames.AVGPOOL: [AvgPool2d, F.avg_pool2d],
+            OperatorSetNames.SIZE: [torch.Tensor.size],
+            OperatorSetNames.RESIZE: [torch.Tensor.resize],
+            OperatorSetNames.PAD: [F.pad],
+            OperatorSetNames.FOLD: [fold],
+            OperatorSetNames.SHAPE: [torch.Tensor.shape],
+            OperatorSetNames.EQUAL: [equal],
+            OperatorSetNames.ARGMAX: [argmax],
+            OperatorSetNames.TOPK: [topk],
+            OperatorSetNames.FAKE_QUANT: [fake_quantize_per_channel_affine],
+            OperatorSetNames.ZERO_PADDING2D: [ZeroPad2d],
+            OperatorSetNames.CAST: [torch.Tensor.type],
+            OperatorSetNames.STRIDED_SLICE: [],  # no such operator in pytorch, the equivalent is get_item which has a separate operator set
+            OperatorSetNames.ELU: [ELU, F.elu],
+            OperatorSetNames.SOFTMAX: [Softmax, softmax, F.softmax],
+            OperatorSetNames.LOG_SOFTMAX: [LogSoftmax],
+            OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
+                                                        Eq('p', 2) | Eq('p', None))],
+            OperatorSetNames.SSD_POST_PROCESS: [],  # no such operator in pytorch
+            OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: []  # no such operator in pytorch
         }
 
         pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
                                        BIAS_ATTR: DefaultDict(default_value=BIAS)}
-        self._opset2attr_mapping = {OperatorSetNames.
-                                    OperatorSetNames.
-                                    OperatorSetNames.
+        self._opset2attr_mapping = {OperatorSetNames.CONV: pytorch_linear_attr_mapping,
+                                    OperatorSetNames.CONV_TRANSPOSE: pytorch_linear_attr_mapping,
+                                    OperatorSetNames.FULLY_CONNECTED: pytorch_linear_attr_mapping}
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py
CHANGED
@@ -138,10 +138,8 @@ class OperationsToLayers:
                               OperationsSetToLayers), f'Operators set should be of type OperationsSetToLayers but it ' \
                                                       f'is of type {type(ops2layers)}'
 
-            # Assert that opset
-
-            assert opset_in_model, f'{ops2layers.name} is not defined in the target platform model that is associated with the target platform capabilities.'
-            assert not (ops2layers.name in existing_opset_names), f'OperationsSetToLayers names should be unique, but {ops2layers.name} appears to violate it.'
+            # Assert that opset has a unique name.
+            assert ops2layers.name not in existing_opset_names, f'OperationsSetToLayers names should be unique, but {ops2layers.name} appears to violate it.'
             existing_opset_names.append(ops2layers.name)
 
             # Assert that a layer does not appear in more than a single OperatorsSet in the TargetPlatformModel.
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py
CHANGED
@@ -156,7 +156,6 @@ class TargetPlatformCapabilities(ImmutableClass):
         if exc_value is not None:
             print(exc_value, exc_value.args)
             raise exc_value
-        self.raise_warnings()
         self.layer2qco, self.filterlayer2qco = self._get_config_options_mapping()
         _current_tpc.reset()
         self.initialized_done()
@@ -226,15 +225,6 @@ class TargetPlatformCapabilities(ImmutableClass):
         if opset_to_remove in self.__tp_model_opsets_not_used:
             self.__tp_model_opsets_not_used.remove(opset_to_remove)
 
-    def raise_warnings(self):
-        """
-
-        Log warnings regards unused opsets.
-
-        """
-        for op in self.__tp_model_opsets_not_used:
-            Logger.warning(f'{op} is defined in TargetPlatformModel, but is not used in TargetPlatformCapabilities.')
-
     @property
     def is_simd_padding(self) -> bool:
         """
model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py
ADDED
@@ -0,0 +1,93 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from pathlib import Path
+from typing import Union
+
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+import json
+
+
+def load_target_platform_model(tp_model_or_path: Union[TargetPlatformModel, str]) -> TargetPlatformModel:
+    """
+    Parses the tp_model input, which can be either a TargetPlatformModel object
+    or a string path to a JSON file.
+
+    Parameters:
+        tp_model_or_path (Union[TargetPlatformModel, str]): Input target platform model or path to .JSON file.
+
+    Returns:
+        TargetPlatformModel: The parsed TargetPlatformModel.
+
+    Raises:
+        FileNotFoundError: If the JSON file does not exist.
+        ValueError: If the JSON content is invalid or cannot initialize the TargetPlatformModel.
+        TypeError: If the input is neither a TargetPlatformModel nor a valid JSON file path.
+    """
+    if isinstance(tp_model_or_path, TargetPlatformModel):
+        return tp_model_or_path
+
+    if isinstance(tp_model_or_path, str):
+        path = Path(tp_model_or_path)
+
+        if not path.exists() or not path.is_file():
+            raise FileNotFoundError(f"The path '{tp_model_or_path}' is not a valid file.")
+        # Verify that the file has a .json extension
+        if path.suffix.lower() != '.json':
+            raise ValueError(f"The file '{path}' does not have a '.json' extension.")
+        try:
+            with path.open('r', encoding='utf-8') as file:
+                data = file.read()
+        except OSError as e:
+            raise ValueError(f"Error reading the file '{tp_model_or_path}': {e.strerror}.") from e
+
+        try:
+            return TargetPlatformModel.parse_raw(data)
+        except ValueError as e:
+            raise ValueError(f"Invalid JSON for loading TargetPlatformModel in '{tp_model_or_path}': {e}.") from e
+        except Exception as e:
+            raise ValueError(f"Unexpected error while initializing TargetPlatformModel: {e}.") from e
+
+    raise TypeError(
+        f"tp_model_or_path must be either a TargetPlatformModel instance or a string path to a JSON file, "
+        f"but received type '{type(tp_model_or_path).__name__}'."
+    )
+
+
+def export_target_platform_model(model: TargetPlatformModel, export_path: Union[str, Path]) -> None:
+    """
+    Exports a TargetPlatformModel instance to a JSON file.
+
+    Parameters:
+        model (TargetPlatformModel): The TargetPlatformModel instance to export.
+        export_path (Union[str, Path]): The file path to export the model to.
+
+    Raises:
+        ValueError: If the model is not an instance of TargetPlatformModel.
+        OSError: If there is an issue writing to the file.
+    """
+    if not isinstance(model, TargetPlatformModel):
+        raise ValueError("The provided model is not a valid TargetPlatformModel instance.")
+
+    path = Path(export_path)
+    try:
+        # Ensure the parent directory exists
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Export the model to JSON and write to the file
+        with path.open('w', encoding='utf-8') as file:
+            file.write(model.json(indent=4))
+    except OSError as e:
+        raise OSError(f"Failed to write to file '{export_path}': {e.strerror}") from e
model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py
CHANGED
@@ -12,45 +12,63 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from model_compression_toolkit.constants import TENSORFLOW, PYTORCH
+from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, \
+    TFLITE_TP_MODEL, QNNPACK_TP_MODEL
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 
-from model_compression_toolkit.target_platform_capabilities.
+from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model as get_tp_model_imx500_v1
+from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import get_tp_model as get_tp_model_tflite_v1
+from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model as get_tp_model_qnnpack_v1
 
-from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.target_platform_capabilities import \
-    get_tpc_dict_by_fw as get_imx500_tpc
-from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.target_platform_capabilities import \
-    get_tpc_dict_by_fw as get_tflite_tpc
-from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.target_platform_capabilities import \
-    get_tpc_dict_by_fw as get_qnnpack_tpc
-from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, TFLITE_TP_MODEL, QNNPACK_TP_MODEL, LATEST
-
-tpc_dict = {DEFAULT_TP_MODEL: get_imx500_tpc,
-            IMX500_TP_MODEL: get_imx500_tpc,
-            TFLITE_TP_MODEL: get_tflite_tpc,
-            QNNPACK_TP_MODEL: get_qnnpack_tpc}
 
+# TODO: These methods need to be replaced once modifying the TPC API.
 
 def get_target_platform_capabilities(fw_name: str,
                                      target_platform_name: str,
-                                     target_platform_version: str = None) ->
+                                     target_platform_version: str = None) -> TargetPlatformModel:
     """
-
-
-    the target platform model can be 'default', 'imx500', 'tflite', or 'qnnpack'.
+    This is a degenerated function that only returns the MCT default TargetPlatformModel object, to comply with the
+    existing TPC API.
 
     Args:
         fw_name: Framework name of the TargetPlatformCapabilities.
         target_platform_name: Target platform model name the model will use for inference.
         target_platform_version: Target platform capabilities version.
+
     Returns:
-        A
-        a framework information to it.
+        A default TargetPlatformModel object.
     """
-
-
-
-    if
-
-
-
-
-
+
+    assert fw_name in [TENSORFLOW, PYTORCH], f"Unsupported framework {fw_name}."
+
+    if target_platform_name == DEFAULT_TP_MODEL:
+        return get_tp_model_imx500_v1()
+
+    assert target_platform_version == 'v1' or target_platform_version is None, \
+        "The usage of get_target_platform_capabilities API is supported only with the default TPC ('v1')."
+
+    if target_platform_name == IMX500_TP_MODEL:
+        return get_tp_model_imx500_v1()
+    elif target_platform_name == TFLITE_TP_MODEL:
+        return get_tp_model_tflite_v1()
+    elif target_platform_name == QNNPACK_TP_MODEL:
+        return get_tp_model_qnnpack_v1()
+
+    raise ValueError(f"Unsupported target platform name {target_platform_name}.")
+
+
+def get_tpc_model(name: str, tp_model: TargetPlatformModel):
+    """
+    This is a utility method that just returns the TargetPlatformModel that it receives, to support existing TPC API.
+
+    Args:
+        name: the name of the TargetPlatformModel (not used in this function).
+        tp_model: a TargetPlatformModel to return.
+
+    Returns:
+        The given TargetPlatformModel object.
+
+    """
+
+    return tp_model
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py
CHANGED
@@ -16,9 +16,10 @@ from model_compression_toolkit.verify_packages import FOUND_TORCH, FOUND_TF
 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model, generate_tp_model, \
     get_op_quantization_configs
 if FOUND_TF:
-    from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.
-    from model_compression_toolkit.target_platform_capabilities.tpc_models.
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model as get_keras_tpc_latest
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
+        get_tpc_model as generate_keras_tpc
 if FOUND_TORCH:
-    from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.
-
-
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model as get_pytorch_tpc_latest
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
+        get_tpc_model as generate_pytorch_tpc
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py
CHANGED
@@ -167,38 +167,70 @@ def generate_tp_model(default_config: OpQuantizationConfig,
     operator_set = []
     fusing_patterns = []
 
-
-
-
+    no_quantization_config = (default_configuration_options.clone_and_edit(enable_activation_quantization=False)
+                              .clone_and_edit_weight_attribute(enable_weights_quantization=False))
+
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.STACK, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.UNSTACK, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.DROPOUT, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.FLATTEN, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.SPLIT_CHUNK, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.GET_ITEM, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.RESHAPE, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.UNSQUEEZE, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.SIZE, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.PERMUTE, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.TRANSPOSE, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.EQUAL, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.ARGMAX, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.GATHER, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.TOPK, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.SQUEEZE, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.MAXPOOL, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.PAD, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.ZERO_PADDING2D, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.CAST, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.FAKE_QUANT, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.SSD_POST_PROCESS, qc_options=no_quantization_config))
 
     # Define operator sets that use mixed_precision_configuration_options:
-    conv = schema.OperatorsSet(name=
-
+    conv = schema.OperatorsSet(name=schema.OperatorSetNames.CONV, qc_options=mixed_precision_configuration_options)
+    conv_transpose = schema.OperatorsSet(name=schema.OperatorSetNames.CONV_TRANSPOSE, qc_options=mixed_precision_configuration_options)
+    depthwise_conv = schema.OperatorsSet(name=schema.OperatorSetNames.DEPTHWISE_CONV, qc_options=mixed_precision_configuration_options)
+    fc = schema.OperatorsSet(name=schema.OperatorSetNames.FULLY_CONNECTED, qc_options=mixed_precision_configuration_options)
 
     # Define operations sets without quantization configuration
     # options (useful for creating fusing patterns, for example):
-
-
-
-
-
-
-
-
-
-
-
+    relu = schema.OperatorsSet(name=schema.OperatorSetNames.RELU)
+    relu6 = schema.OperatorsSet(name=schema.OperatorSetNames.RELU6)
+    leaky_relu = schema.OperatorsSet(name=schema.OperatorSetNames.LEAKY_RELU)
+    prelu = schema.OperatorsSet(name=schema.OperatorSetNames.PRELU)
+    add = schema.OperatorsSet(name=schema.OperatorSetNames.ADD)
+    sub = schema.OperatorsSet(name=schema.OperatorSetNames.SUB)
+    mul = schema.OperatorsSet(name=schema.OperatorSetNames.MUL)
+    div = schema.OperatorsSet(name=schema.OperatorSetNames.DIV)
+    swish = schema.OperatorsSet(name=schema.OperatorSetNames.SWISH)
+    hard_swish = schema.OperatorsSet(name=schema.OperatorSetNames.HARDSWISH)
+    sigmoid = schema.OperatorsSet(name=schema.OperatorSetNames.SIGMOID)
+    tanh = schema.OperatorsSet(name=schema.OperatorSetNames.TANH)
+    hard_tanh = schema.OperatorsSet(name=schema.OperatorSetNames.HARD_TANH)
+
+    operator_set.extend([conv, conv_transpose, depthwise_conv, fc, relu, relu6, leaky_relu, add, sub, mul, div, prelu, swish,
+                         hard_swish, sigmoid, tanh, hard_tanh])
+    any_relu = schema.OperatorSetConcat(operators_set=[relu, relu6, leaky_relu, hard_tanh])
     # Combine multiple operators into a single operator to avoid quantization between
     # them. To do this we define fusing patterns using the OperatorsSets that were created.
     # To group multiple sets with regard to fusing, an OperatorSetConcat can be created
-    activations_after_conv_to_fuse = schema.OperatorSetConcat(operators_set=[
-
+    activations_after_conv_to_fuse = schema.OperatorSetConcat(operators_set=[relu, relu6, leaky_relu, hard_tanh, swish, hard_swish, prelu, sigmoid, tanh])
+    conv_types = schema.OperatorSetConcat(operators_set=[conv, conv_transpose, depthwise_conv])
+    activations_after_fc_to_fuse = schema.OperatorSetConcat(operators_set=[relu, relu6, leaky_relu, hard_tanh, swish, hard_swish, sigmoid])
     any_binary = schema.OperatorSetConcat(operators_set=[add, sub, mul, div])
 
     # ------------------- #
     # Fusions
     # ------------------- #
-    fusing_patterns.append(schema.Fusing(operator_groups=(
+    fusing_patterns.append(schema.Fusing(operator_groups=(conv_types, activations_after_conv_to_fuse)))
     fusing_patterns.append(schema.Fusing(operator_groups=(fc, activations_after_fc_to_fuse)))
     fusing_patterns.append(schema.Fusing(operator_groups=(any_binary, any_relu)))
 
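To see the effect of the rewritten operator-set construction, one can build the model and list what it contains. This sketch assumes the generated TargetPlatformModel exposes the collected sets and fusing patterns on `operator_set` and `fusing_patterns` fields, which is not shown in the hunk above:

```python
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model

tp_model = get_tp_model()

# Assumed field names: operator_set / fusing_patterns hold the objects collected in generate_tp_model.
print([op_set.name for op_set in tp_model.operator_set])
print(f'{len(tp_model.fusing_patterns)} fusing patterns defined')
```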
model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py
CHANGED
@@ -15,8 +15,12 @@
 from model_compression_toolkit.verify_packages import FOUND_TORCH, FOUND_TF
 from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model, generate_tp_model, get_op_quantization_configs
 if FOUND_TF:
-    from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.
-
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model as \
+        get_keras_tpc_latest
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
+        get_tpc_model as generate_keras_tpc, get_tpc_model as generate_keras_tpc
 if FOUND_TORCH:
-    from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.
-
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model as \
+        get_pytorch_tpc_latest
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
+        get_tpc_model as generate_pytorch_tpc, get_tpc_model as generate_pytorch_tpc
model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py
CHANGED
@@ -148,19 +148,29 @@ def generate_tp_model(default_config: OpQuantizationConfig,
     operator_set = []
     fusing_patterns = []
 
-    conv = schema.OperatorsSet(name=
-
-
-
+    conv = schema.OperatorsSet(name=schema.OperatorSetNames.CONV)
+    conv_depthwise = schema.OperatorsSet(name=schema.OperatorSetNames.DEPTHWISE_CONV)
+    conv_transpose = schema.OperatorsSet(name=schema.OperatorSetNames.CONV_TRANSPOSE)
+    batchnorm = schema.OperatorsSet(name=schema.OperatorSetNames.BATCH_NORM)
+    relu = schema.OperatorsSet(name=schema.OperatorSetNames.RELU)
+    relu6 = schema.OperatorsSet(name=schema.OperatorSetNames.RELU6)
+
+    hard_tanh = schema.OperatorsSet(name=schema.OperatorSetNames.HARD_TANH)
+    linear = schema.OperatorsSet(name=schema.OperatorSetNames.FULLY_CONNECTED)
+
+    operator_set.extend([conv, conv_depthwise, conv_transpose, batchnorm, relu, relu6, hard_tanh, linear])
+
+    conv_opset_concat = schema.OperatorSetConcat(operators_set=[conv, conv_transpose])
+    relu_opset_concat = schema.OperatorSetConcat(operators_set=[relu, relu6, hard_tanh])
 
-    operator_set.extend([conv, batchnorm, relu, linear])
     # ------------------- #
     # Fusions
     # ------------------- #
-    fusing_patterns.append(schema.Fusing(operator_groups=(
-    fusing_patterns.append(schema.Fusing(operator_groups=(
-    fusing_patterns.append(schema.Fusing(operator_groups=(
-    fusing_patterns.append(schema.Fusing(operator_groups=(linear,
+    fusing_patterns.append(schema.Fusing(operator_groups=(conv_opset_concat, batchnorm, relu_opset_concat)))
+    fusing_patterns.append(schema.Fusing(operator_groups=(conv_opset_concat, batchnorm)))
+    fusing_patterns.append(schema.Fusing(operator_groups=(conv_opset_concat, relu_opset_concat)))
+    fusing_patterns.append(schema.Fusing(operator_groups=(linear, relu_opset_concat)))
+
     # Create a TargetPlatformModel and set its default quantization config.
     # This default configuration will be used for all operations
     # unless specified otherwise (see OperatorsSet, for example):
model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py
CHANGED
@@ -15,8 +15,11 @@
 from model_compression_toolkit.verify_packages import FOUND_TORCH, FOUND_TF
 from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import get_tp_model, generate_tp_model, get_op_quantization_configs
 if FOUND_TF:
-    from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.
-    from model_compression_toolkit.target_platform_capabilities.tpc_models.
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import get_keras_tpc as get_keras_tpc_latest
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
+        get_tpc_model as generate_keras_tpc
 if FOUND_TORCH:
-    from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.
-
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import \
+        get_tp_model as get_pytorch_tpc_latest
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
+        get_tpc_model as generate_pytorch_tpc