mct-nightly 2.2.0.20250106.546__py3-none-any.whl → 2.2.0.20250107.134735__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.134735.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.134735.dist-info}/RECORD +44 -79
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/cut.py +5 -2
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +25 -25
- model_compression_toolkit/core/common/quantization/quantization_config.py +19 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -33
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py +2 -2
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +11 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/matmul_decomposition.py +499 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +3 -0
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +11 -3
- model_compression_toolkit/gptq/keras/gptq_loss.py +4 -3
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +10 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +8 -2
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -2
- model_compression_toolkit/ptq/keras/quantization_facade.py +10 -1
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +9 -1
- model_compression_toolkit/qat/__init__.py +5 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +9 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -1
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +63 -55
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py +29 -18
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py +78 -57
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py +69 -54
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +0 -10
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +93 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +46 -28
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +6 -5
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +51 -19
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +8 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +19 -9
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +7 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +46 -32
- model_compression_toolkit/xquant/keras/keras_report_utils.py +11 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +0 -98
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_keras.py +0 -129
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_pytorch.py +0 -108
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +0 -217
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_keras.py +0 -130
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_pytorch.py +0 -109
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +0 -215
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_keras.py +0 -130
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_pytorch.py +0 -110
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +0 -222
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +0 -132
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py +0 -110
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +0 -219
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +0 -132
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py +0 -109
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +0 -246
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +0 -135
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +0 -113
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +0 -230
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +0 -132
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +0 -110
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +0 -332
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +0 -140
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +0 -122
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +0 -55
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_keras.py +0 -89
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +0 -78
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +0 -55
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_keras.py +0 -118
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_pytorch.py +0 -100
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.134735.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.134735.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.134735.dist-info}/top_level.txt +0 -0
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py
CHANGED
```diff
@@ -23,12 +23,12 @@ if FOUND_SONY_CUSTOM_LAYERS:
 
 if version.parse(tf.__version__) >= version.parse("2.13"):
     from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
-        MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
-        Conv2DTranspose,
+        MaxPooling2D, AveragePooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
+        Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum, Softmax
 else:
     from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
-        MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
-        Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum
+        MaxPooling2D, AveragePooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
+        Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum, Softmax
 
 from model_compression_toolkit import DefaultDict
 from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS, \
@@ -36,72 +36,93 @@ from model_compression_toolkit.target_platform_capabilities.constants import KER
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
 from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
-
+    AttachTpcToFramework
 
 
-class
+class AttachTpcToKeras(AttachTpcToFramework):
     def __init__(self):
         super().__init__()
 
         self._opset2layer = {
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-
-
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
+            OperatorSetNames.CONV: [Conv2D, tf.nn.conv2d],
+            OperatorSetNames.DEPTHWISE_CONV: [DepthwiseConv2D, tf.nn.depthwise_conv2d],
+            OperatorSetNames.CONV_TRANSPOSE: [Conv2DTranspose, tf.nn.conv2d_transpose],
+            OperatorSetNames.FULLY_CONNECTED: [Dense],
+            OperatorSetNames.CONCATENATE: [tf.concat, Concatenate],
+            OperatorSetNames.STACK: [tf.stack],
+            OperatorSetNames.UNSTACK: [tf.unstack],
+            OperatorSetNames.GATHER: [tf.gather, tf.compat.v1.gather],
+            OperatorSetNames.EXPAND: [],
+            OperatorSetNames.BATCH_NORM: [BatchNormalization, tf.nn.batch_normalization],
+            OperatorSetNames.RELU: [tf.nn.relu, ReLU, LayerFilterParams(Activation, activation="relu")],
+            OperatorSetNames.RELU6: [tf.nn.relu6],
+            OperatorSetNames.LEAKY_RELU: [tf.nn.leaky_relu, LeakyReLU, LayerFilterParams(Activation, activation="leaky_relu")],
+            OperatorSetNames.HARD_TANH: [LayerFilterParams(Activation, activation="hard_tanh")],
+            OperatorSetNames.ADD: [tf.add, Add],
+            OperatorSetNames.SUB: [tf.subtract, Subtract],
+            OperatorSetNames.MUL: [tf.math.multiply, Multiply],
+            OperatorSetNames.DIV: [tf.math.divide, tf.math.truediv],
+            OperatorSetNames.MIN: [tf.math.minimum, Minimum],
+            OperatorSetNames.MAX: [tf.math.maximum, Maximum],
+            OperatorSetNames.PRELU: [PReLU],
+            OperatorSetNames.SWISH: [tf.nn.swish, LayerFilterParams(Activation, activation="swish")],
+            OperatorSetNames.HARDSWISH: [LayerFilterParams(Activation, activation="hard_swish")],
+            OperatorSetNames.SIGMOID: [tf.nn.sigmoid, LayerFilterParams(Activation, activation="sigmoid")],
+            OperatorSetNames.TANH: [tf.nn.tanh, LayerFilterParams(Activation, activation="tanh")],
+            OperatorSetNames.GELU: [tf.nn.gelu, LayerFilterParams(Activation, activation="gelu")],
+            OperatorSetNames.HARDSIGMOID: [tf.keras.activations.hard_sigmoid,
+                                           LayerFilterParams(Activation, activation="hard_sigmoid")],
+            OperatorSetNames.FLATTEN: [Flatten],
+            OperatorSetNames.GET_ITEM: [tf.__operators__.getitem],
+            OperatorSetNames.RESHAPE: [Reshape, tf.reshape],
+            OperatorSetNames.PERMUTE: [Permute],
+            OperatorSetNames.TRANSPOSE: [tf.transpose],
+            OperatorSetNames.UNSQUEEZE: [tf.expand_dims],
+            OperatorSetNames.SQUEEZE: [tf.squeeze],
+            OperatorSetNames.DROPOUT: [Dropout],
+            OperatorSetNames.SPLIT_CHUNK: [tf.split],
+            OperatorSetNames.MAXPOOL: [MaxPooling2D, tf.nn.avg_pool2d],
+            OperatorSetNames.AVGPOOL: [AveragePooling2D],
+            OperatorSetNames.SIZE: [tf.size],
+            OperatorSetNames.RESIZE: [tf.image.resize],
+            OperatorSetNames.PAD: [tf.pad, Cropping2D],
+            OperatorSetNames.FOLD: [tf.space_to_batch_nd],
+            OperatorSetNames.SHAPE: [tf.shape, tf.compat.v1.shape],
+            OperatorSetNames.EQUAL: [tf.math.equal],
+            OperatorSetNames.ARGMAX: [tf.math.argmax],
+            OperatorSetNames.TOPK: [tf.nn.top_k],
+            OperatorSetNames.FAKE_QUANT: [tf.quantization.fake_quant_with_min_max_vars],
+            OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [tf.image.combined_non_max_suppression],
+            OperatorSetNames.ZERO_PADDING2D: [ZeroPadding2D],
+            OperatorSetNames.CAST: [tf.cast],
+            OperatorSetNames.STRIDED_SLICE: [tf.strided_slice],
+            OperatorSetNames.ELU: [tf.nn.elu, LayerFilterParams(Activation, activation="elu")],
+            OperatorSetNames.SOFTMAX: [tf.nn.softmax, Softmax,
+                                       LayerFilterParams(Activation, activation="softmax")],
+            OperatorSetNames.LOG_SOFTMAX: [tf.nn.log_softmax],
+            OperatorSetNames.ADD_BIAS: [tf.nn.bias_add],
+            OperatorSetNames.L2NORM: [tf.math.l2_normalize],
         }
 
         if FOUND_SONY_CUSTOM_LAYERS:
-            self._opset2layer[OperatorSetNames.
+            self._opset2layer[OperatorSetNames.SSD_POST_PROCESS] = [SSDPostProcess]
+        else:
+            # If Custom layers is not installed then we don't want the user to fail, but just ignore custom layers
+            # in the initialized framework TPC
+            self._opset2layer[OperatorSetNames.SSD_POST_PROCESS] = []
 
-        self._opset2attr_mapping = {
-
-
-
+        self._opset2attr_mapping = {
+            OperatorSetNames.CONV: {
+                KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
+                BIAS_ATTR: DefaultDict(default_value=BIAS)},
+            OperatorSetNames.CONV_TRANSPOSE: {
+                KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
+                BIAS_ATTR: DefaultDict(default_value=BIAS)},
+            OperatorSetNames.DEPTHWISE_CONV: {
                 KERNEL_ATTR: DefaultDict({
                     DepthwiseConv2D: KERAS_DEPTHWISE_KERNEL,
                     tf.nn.depthwise_conv2d: KERAS_DEPTHWISE_KERNEL}, default_value=KERAS_KERNEL),
                 BIAS_ATTR: DefaultDict(default_value=BIAS)},
-            OperatorSetNames.
+            OperatorSetNames.FULLY_CONNECTED: {
                 KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
                 BIAS_ATTR: DefaultDict(default_value=BIAS)}}
```
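For orientation, here is a minimal usage sketch (not part of the diff) of the new Keras attach class. It assumes TensorFlow is installed and uses only the names visible in the hunks above; `_opset2layer` is an internal attribute, read here purely to illustrate which Keras layers each operator set now resolves to.

```python
# Illustrative sketch, not part of the package diff.
# Inspect the operator-set-to-Keras-layer mapping built by AttachTpcToKeras.
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
    AttachTpcToKeras

attach = AttachTpcToKeras()

# Per the mapping above, CONV should resolve to Conv2D and tf.nn.conv2d,
# and the newly added AVGPOOL entry to AveragePooling2D.
print(attach._opset2layer[OperatorSetNames.CONV])
print(attach._opset2layer[OperatorSetNames.AVGPOOL])
```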
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py
CHANGED
```diff
@@ -18,74 +18,89 @@ import operator
 import torch
 from torch import add, sub, mul, div, divide, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, \
     chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract, minimum, \
-    maximum
-from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d, BatchNorm2d
-
-from torch.nn import
+    maximum, softmax, fake_quantize_per_channel_affine
+from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d, BatchNorm2d, Dropout, Flatten, Hardtanh, ReLU, ReLU6, \
+    PReLU, SiLU, Sigmoid, Tanh, Hardswish, Hardsigmoid, LeakyReLU, GELU, LogSoftmax, Softmax, ELU, AvgPool2d, ZeroPad2d
+from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, hardsigmoid, leaky_relu, gelu, fold
 import torch.nn.functional as F
-from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, hardsigmoid, leaky_relu, gelu
 
 from model_compression_toolkit import DefaultDict
 from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, PYTORCH_KERNEL, BIAS, \
     BIAS_ATTR
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
-from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
+from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams, Eq
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
-
+    AttachTpcToFramework
 
 
-class
+class AttachTpcToPytorch(AttachTpcToFramework):
     def __init__(self):
         super().__init__()
 
         self._opset2layer = {
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-
-
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
-            OperatorSetNames.
+            OperatorSetNames.CONV: [Conv2d],
+            OperatorSetNames.DEPTHWISE_CONV: [],  # no specific operator for depthwise conv in pytorch
+            OperatorSetNames.CONV_TRANSPOSE: [ConvTranspose2d],
+            OperatorSetNames.FULLY_CONNECTED: [Linear],
+            OperatorSetNames.CONCATENATE: [torch.cat, torch.concat, torch.concatenate],
+            OperatorSetNames.STACK: [torch.stack],
+            OperatorSetNames.UNSTACK: [unbind],
+            OperatorSetNames.GATHER: [gather],
+            OperatorSetNames.EXPAND: [torch.Tensor.expand],
+            OperatorSetNames.BATCH_NORM: [BatchNorm2d],
+            OperatorSetNames.RELU: [torch.relu, ReLU, relu],
+            OperatorSetNames.RELU6: [ReLU6, relu6],
+            OperatorSetNames.LEAKY_RELU: [LeakyReLU, leaky_relu],
+            OperatorSetNames.HARD_TANH: [LayerFilterParams(Hardtanh, min_val=0),
+                                         LayerFilterParams(hardtanh, min_val=0)],
+            OperatorSetNames.ADD: [operator.add, add],
+            OperatorSetNames.SUB: [operator.sub, sub, subtract],
+            OperatorSetNames.MUL: [operator.mul, mul, multiply],
+            OperatorSetNames.DIV: [operator.truediv, div, divide],
+            OperatorSetNames.ADD_BIAS: [],  # no specific operator for bias_add in pytorch
+            OperatorSetNames.MIN: [minimum],
+            OperatorSetNames.MAX: [maximum],
+            OperatorSetNames.PRELU: [PReLU, prelu],
+            OperatorSetNames.SWISH: [SiLU, silu],
+            OperatorSetNames.SIGMOID: [Sigmoid, sigmoid, F.sigmoid],
+            OperatorSetNames.TANH: [Tanh, tanh, F.tanh],
+            OperatorSetNames.GELU: [GELU, gelu],
+            OperatorSetNames.HARDSIGMOID: [Hardsigmoid, hardsigmoid],
+            OperatorSetNames.HARDSWISH: [Hardswish, hardswish],
+            OperatorSetNames.FLATTEN: [Flatten, flatten],
+            OperatorSetNames.GET_ITEM: [operator.getitem],
+            OperatorSetNames.RESHAPE: [reshape],
+            OperatorSetNames.UNSQUEEZE: [unsqueeze],
+            OperatorSetNames.SQUEEZE: [squeeze],
+            OperatorSetNames.PERMUTE: [permute],
+            OperatorSetNames.TRANSPOSE: [transpose],
+            OperatorSetNames.DROPOUT: [Dropout, dropout],
+            OperatorSetNames.SPLIT_CHUNK: [split, chunk],
+            OperatorSetNames.MAXPOOL: [MaxPool2d, F.max_pool2d],
+            OperatorSetNames.AVGPOOL: [AvgPool2d, F.avg_pool2d],
+            OperatorSetNames.SIZE: [torch.Tensor.size],
+            OperatorSetNames.RESIZE: [torch.Tensor.resize],
+            OperatorSetNames.PAD: [F.pad],
+            OperatorSetNames.FOLD: [fold],
+            OperatorSetNames.SHAPE: [torch.Tensor.shape],
+            OperatorSetNames.EQUAL: [equal],
+            OperatorSetNames.ARGMAX: [argmax],
+            OperatorSetNames.TOPK: [topk],
+            OperatorSetNames.FAKE_QUANT: [fake_quantize_per_channel_affine],
+            OperatorSetNames.ZERO_PADDING2D: [ZeroPad2d],
+            OperatorSetNames.CAST: [torch.Tensor.type],
+            OperatorSetNames.STRIDED_SLICE: [],  # no such operator in pytorch, the equivalent is get_item which has a separate operator set
+            OperatorSetNames.ELU: [ELU, F.elu],
+            OperatorSetNames.SOFTMAX: [Softmax, softmax, F.softmax],
+            OperatorSetNames.LOG_SOFTMAX: [LogSoftmax],
+            OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
+                                                        Eq('p', 2) | Eq('p', None))],
+            OperatorSetNames.SSD_POST_PROCESS: [],  # no such operator in pytorch
+            OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: []  # no such operator in pytorch
         }
 
         pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
                                        BIAS_ATTR: DefaultDict(default_value=BIAS)}
-        self._opset2attr_mapping = {OperatorSetNames.
-                                    OperatorSetNames.
-                                    OperatorSetNames.
+        self._opset2attr_mapping = {OperatorSetNames.CONV: pytorch_linear_attr_mapping,
+                                    OperatorSetNames.CONV_TRANSPOSE: pytorch_linear_attr_mapping,
+                                    OperatorSetNames.FULLY_CONNECTED: pytorch_linear_attr_mapping}
```
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py
CHANGED
```diff
@@ -138,10 +138,8 @@ class OperationsToLayers:
                               OperationsSetToLayers), f'Operators set should be of type OperationsSetToLayers but it ' \
                                                       f'is of type {type(ops2layers)}'
 
-            # Assert that opset
-
-            assert opset_in_model, f'{ops2layers.name} is not defined in the target platform model that is associated with the target platform capabilities.'
-            assert not (ops2layers.name in existing_opset_names), f'OperationsSetToLayers names should be unique, but {ops2layers.name} appears to violate it.'
+            # Assert that opset has a unique name.
+            assert ops2layers.name not in existing_opset_names, f'OperationsSetToLayers names should be unique, but {ops2layers.name} appears to violate it.'
             existing_opset_names.append(ops2layers.name)
 
             # Assert that a layer does not appear in more than a single OperatorsSet in the TargetPlatformModel.
```
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py
CHANGED
```diff
@@ -156,7 +156,6 @@ class TargetPlatformCapabilities(ImmutableClass):
         if exc_value is not None:
             print(exc_value, exc_value.args)
             raise exc_value
-        self.raise_warnings()
         self.layer2qco, self.filterlayer2qco = self._get_config_options_mapping()
         _current_tpc.reset()
         self.initialized_done()
@@ -226,15 +225,6 @@
         if opset_to_remove in self.__tp_model_opsets_not_used:
             self.__tp_model_opsets_not_used.remove(opset_to_remove)
 
-    def raise_warnings(self):
-        """
-
-        Log warnings regards unused opsets.
-
-        """
-        for op in self.__tp_model_opsets_not_used:
-            Logger.warning(f'{op} is defined in TargetPlatformModel, but is not used in TargetPlatformCapabilities.')
-
     @property
     def is_simd_padding(self) -> bool:
         """
```
model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py
ADDED
```diff
@@ -0,0 +1,93 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from pathlib import Path
+from typing import Union
+
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+import json
+
+
+def load_target_platform_model(tp_model_or_path: Union[TargetPlatformModel, str]) -> TargetPlatformModel:
+    """
+    Parses the tp_model input, which can be either a TargetPlatformModel object
+    or a string path to a JSON file.
+
+    Parameters:
+        tp_model_or_path (Union[TargetPlatformModel, str]): Input target platform model or path to .JSON file.
+
+    Returns:
+        TargetPlatformModel: The parsed TargetPlatformModel.
+
+    Raises:
+        FileNotFoundError: If the JSON file does not exist.
+        ValueError: If the JSON content is invalid or cannot initialize the TargetPlatformModel.
+        TypeError: If the input is neither a TargetPlatformModel nor a valid JSON file path.
+    """
+    if isinstance(tp_model_or_path, TargetPlatformModel):
+        return tp_model_or_path
+
+    if isinstance(tp_model_or_path, str):
+        path = Path(tp_model_or_path)
+
+        if not path.exists() or not path.is_file():
+            raise FileNotFoundError(f"The path '{tp_model_or_path}' is not a valid file.")
+        # Verify that the file has a .json extension
+        if path.suffix.lower() != '.json':
+            raise ValueError(f"The file '{path}' does not have a '.json' extension.")
+        try:
+            with path.open('r', encoding='utf-8') as file:
+                data = file.read()
+        except OSError as e:
+            raise ValueError(f"Error reading the file '{tp_model_or_path}': {e.strerror}.") from e
+
+        try:
+            return TargetPlatformModel.parse_raw(data)
+        except ValueError as e:
+            raise ValueError(f"Invalid JSON for loading TargetPlatformModel in '{tp_model_or_path}': {e}.") from e
+        except Exception as e:
+            raise ValueError(f"Unexpected error while initializing TargetPlatformModel: {e}.") from e
+
+    raise TypeError(
+        f"tp_model_or_path must be either a TargetPlatformModel instance or a string path to a JSON file, "
+        f"but received type '{type(tp_model_or_path).__name__}'."
+    )
+
+
+def export_target_platform_model(model: TargetPlatformModel, export_path: Union[str, Path]) -> None:
+    """
+    Exports a TargetPlatformModel instance to a JSON file.
+
+    Parameters:
+        model (TargetPlatformModel): The TargetPlatformModel instance to export.
+        export_path (Union[str, Path]): The file path to export the model to.
+
+    Raises:
+        ValueError: If the model is not an instance of TargetPlatformModel.
+        OSError: If there is an issue writing to the file.
+    """
+    if not isinstance(model, TargetPlatformModel):
+        raise ValueError("The provided model is not a valid TargetPlatformModel instance.")
+
+    path = Path(export_path)
+    try:
+        # Ensure the parent directory exists
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Export the model to JSON and write to the file
+        with path.open('w', encoding='utf-8') as file:
+            file.write(model.json(indent=4))
+    except OSError as e:
+        raise OSError(f"Failed to write to file '{export_path}': {e.strerror}") from e
```
model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py
CHANGED
```diff
@@ -12,45 +12,63 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from model_compression_toolkit.constants import TENSORFLOW, PYTORCH
+from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, \
+    TFLITE_TP_MODEL, QNNPACK_TP_MODEL
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 
-from model_compression_toolkit.target_platform_capabilities.
+from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model as get_tp_model_imx500_v1
+from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import get_tp_model as get_tp_model_tflite_v1
+from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model as get_tp_model_qnnpack_v1
 
-from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.target_platform_capabilities import \
-    get_tpc_dict_by_fw as get_imx500_tpc
-from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.target_platform_capabilities import \
-    get_tpc_dict_by_fw as get_tflite_tpc
-from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.target_platform_capabilities import \
-    get_tpc_dict_by_fw as get_qnnpack_tpc
-from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, TFLITE_TP_MODEL, QNNPACK_TP_MODEL, LATEST
-
-tpc_dict = {DEFAULT_TP_MODEL: get_imx500_tpc,
-            IMX500_TP_MODEL: get_imx500_tpc,
-            TFLITE_TP_MODEL: get_tflite_tpc,
-            QNNPACK_TP_MODEL: get_qnnpack_tpc}
 
+# TODO: These methods need to be replaced once modifying the TPC API.
 
 def get_target_platform_capabilities(fw_name: str,
                                      target_platform_name: str,
-                                     target_platform_version: str = None) ->
+                                     target_platform_version: str = None) -> TargetPlatformModel:
     """
-
-
-    the target platform model can be 'default', 'imx500', 'tflite', or 'qnnpack'.
+    This is a degenerated function that only returns the MCT default TargetPlatformModel object, to comply with the
+    existing TPC API.
 
     Args:
         fw_name: Framework name of the TargetPlatformCapabilities.
        target_platform_name: Target platform model name the model will use for inference.
        target_platform_version: Target platform capabilities version.
+
     Returns:
-        A
-        a framework information to it.
+        A default TargetPlatformModel object.
     """
-
-
-
-    if
-
-
-
-
-
+
+    assert fw_name in [TENSORFLOW, PYTORCH], f"Unsupported framework {fw_name}."
+
+    if target_platform_name == DEFAULT_TP_MODEL:
+        return get_tp_model_imx500_v1()
+
+    assert target_platform_version == 'v1' or target_platform_version is None, \
+        "The usage of get_target_platform_capabilities API is supported only with the default TPC ('v1')."
+
+    if target_platform_name == IMX500_TP_MODEL:
+        return get_tp_model_imx500_v1()
+    elif target_platform_name == TFLITE_TP_MODEL:
+        return get_tp_model_tflite_v1()
+    elif target_platform_name == QNNPACK_TP_MODEL:
+        return get_tp_model_qnnpack_v1()
+
+    raise ValueError(f"Unsupported target platform name {target_platform_name}.")
+
+
+def get_tpc_model(name: str, tp_model: TargetPlatformModel):
+    """
+    This is a utility method that just returns the TargetPlatformModel that it receives, to support existing TPC API.
+
+    Args:
+        name: the name of the TargetPlatformModel (not used in this function).
+        tp_model: a TargetPlatformModel to return.
+
+    Returns:
+        The given TargetPlatformModel object.
+
+    """
+
+    return tp_model
```
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py
CHANGED
```diff
@@ -16,9 +16,10 @@ from model_compression_toolkit.verify_packages import FOUND_TORCH, FOUND_TF
 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model, generate_tp_model, \
     get_op_quantization_configs
 if FOUND_TF:
-    from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.
-    from model_compression_toolkit.target_platform_capabilities.tpc_models.
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model as get_keras_tpc_latest
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
+        get_tpc_model as generate_keras_tpc
 if FOUND_TORCH:
-    from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.
-
-
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model as get_pytorch_tpc_latest
+    from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
+        get_tpc_model as generate_pytorch_tpc
```