mct-nightly 2.2.0.20241223.525__py3-none-any.whl → 2.2.0.20241224.532__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241223.525.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241223.525.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/RECORD +10 -7
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +3 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py +56 -0
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py +107 -0
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py +91 -0
- {mct_nightly-2.2.0.20241223.525.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241223.525.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241223.525.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/top_level.txt +0 -0
{mct_nightly-2.2.0.20241223.525.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=9suCm_ya-q7binwaiEyGExSDb8bJgOWwJ3wBnV_el2Y,1573
|
2
2
|
model_compression_toolkit/constants.py,sha256=i_R6uXBfO1ph_X6DNJych2x59SUojfJbn7dNjs_mZnc,3846
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
@@ -433,9 +433,12 @@ model_compression_toolkit/target_platform_capabilities/immutable.py,sha256=YhROB
|
|
433
433
|
model_compression_toolkit/target_platform_capabilities/schema/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
|
434
434
|
model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=E6Zz8boibgfq8EVpZWyl0TOdFrv9qrwiVHUzYPIKVrQ,528
|
435
435
|
model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=ZDFN2N4dRRP6qs0HxsHXEJbZCwYByo3JL9sBCJolDBs,4656
|
436
|
-
model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=
|
436
|
+
model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=GsToWecNswbQCnLPBNsFBK5uX8B69jN9N-tXEmmU6WI,25223
|
437
437
|
model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py,sha256=1FXmDVSqm-dr3xzH4vRo4NmAgyzBZjqHo5l63MUq4r0,1403
|
438
438
|
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/__init__.py,sha256=WCP1wfFZgM4eFm-pPeUinr5R_aSx5qwfSQqLZCXUNBA,1513
|
439
|
+
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py,sha256=GK_eI9Oq-kgBdfXm0AwgXkYgGKL0FEthqrTd0X_XWg0,2872
|
440
|
+
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py,sha256=mkTjg9JZEKMT_QHtlQPGUZvcuF0lofZUs1DU1h43JjM,6671
|
441
|
+
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py,sha256=cn2sP9fnHJ8qfR1AMTwW58mXLOcdT6M7xyE3TLRjnY0,5443
|
439
442
|
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
|
440
443
|
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/current_tpc.py,sha256=fIheShGOnxWYKqT8saHpBJqOU5RG_1Hp9qHry7IviIw,2115
|
441
444
|
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/layer_filter_params.py,sha256=Cl6-mACpje2jM8RJkibbqE3hvTkFR3r26-lW021mIiA,4019
|
@@ -557,8 +560,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
557
560
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
|
558
561
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
559
562
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
560
|
-
mct_nightly-2.2.0.
|
561
|
-
mct_nightly-2.2.0.
|
562
|
-
mct_nightly-2.2.0.
|
563
|
-
mct_nightly-2.2.0.
|
564
|
-
mct_nightly-2.2.0.
|
563
|
+
mct_nightly-2.2.0.20241224.532.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
564
|
+
mct_nightly-2.2.0.20241224.532.dist-info/METADATA,sha256=TevkRWHqm2UgHf34bwK7NWHCKt4tIUfdOvpVaA4-CIU,26453
|
565
|
+
mct_nightly-2.2.0.20241224.532.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
566
|
+
mct_nightly-2.2.0.20241224.532.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
567
|
+
mct_nightly-2.2.0.20241224.532.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.2.0.
|
30
|
+
__version__ = "2.2.0.20241224.000532"
|
@@ -25,7 +25,7 @@ from model_compression_toolkit.target_platform_capabilities.constants import OPS
|
|
25
25
|
class OperatorSetNames(Enum):
|
26
26
|
OPSET_CONV = "Conv"
|
27
27
|
OPSET_DEPTHWISE_CONV = "DepthwiseConv2D"
|
28
|
-
OPSET_CONV_TRANSPOSE = "
|
28
|
+
OPSET_CONV_TRANSPOSE = "ConvTranspose"
|
29
29
|
OPSET_FULLY_CONNECTED = "FullyConnected"
|
30
30
|
OPSET_CONCATENATE = "Concatenate"
|
31
31
|
OPSET_STACK = "Stack"
|
@@ -41,7 +41,8 @@ class OperatorSetNames(Enum):
|
|
41
41
|
OPSET_SUB = "Sub"
|
42
42
|
OPSET_MUL = "Mul"
|
43
43
|
OPSET_DIV = "Div"
|
44
|
-
|
44
|
+
OPSET_MIN = "Min"
|
45
|
+
OPSET_MAX = "Max"
|
45
46
|
OPSET_PRELU = "PReLU"
|
46
47
|
OPSET_SWISH = "Swish"
|
47
48
|
OPSET_SIGMOID = "Sigmoid"
|
@@ -0,0 +1,56 @@
|
|
1
|
+
from typing import Dict, Tuple, List, Any, Optional
|
2
|
+
|
3
|
+
from model_compression_toolkit import DefaultDict
|
4
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
|
5
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, \
|
6
|
+
OperationsSetToLayers
|
7
|
+
|
8
|
+
|
9
|
+
class AttachTpModelToFw:
    """
    Base class for attaching a TargetPlatformModel to the operators of a specific framework.

    The constructor only declares the two mappings used by `attach` and leaves them unset (None).
    The framework-specific subclasses in this package (Keras, PyTorch) fill them in their own constructors.
    """

    def __init__(self):
        # Mapping of operator set names to their corresponding framework-specific layers.
        self._opset2layer = None

        # A mapping that associates each layer type in the operation set (with weight attributes and a quantization
        # configuration in the target platform model) to its framework-specific attribute name. If not all layer types
        # in the operation set are provided in the mapping, a DefaultDict should be supplied to handle missing entries.
        self._opset2attr_mapping = None

    def attach(self, tpc_model: TargetPlatformModel,
               custom_opset2layer: Optional[Dict[str, Tuple[List[Any], Optional[Dict[str, DefaultDict]]]]] = None
               ) -> TargetPlatformCapabilities:
        """
        Attaching a TargetPlatformModel which includes a platform capabilities description to specific
        framework's operators.

        Args:
            tpc_model: a TargetPlatformModel object.
            custom_opset2layer: optional set of custom operator sets which allows to add/override the built-in set
                of framework operator, to define a specific behavior for those operators. This dictionary should map
                an operator set unique name to a pair of: a list of framework operators and an optional
                operator's attributes names mapping.

        Returns: a TargetPlatformCapabilities object.

        Raises:
            ValueError: if an entry of custom_opset2layer does not contain exactly 1 or 2 elements.

        """

        # The constructor leaves both mappings as None. Treat an unset mapping as empty so that a subclass which
        # defines no attribute mapping (or no built-in operators at all) does not fail with an AttributeError.
        opset2layer = self._opset2layer if self._opset2layer is not None else {}
        opset2attr_mapping = self._opset2attr_mapping if self._opset2attr_mapping is not None else {}

        tpc = TargetPlatformCapabilities(tpc_model)

        with tpc:
            # Built-in operator sets: the attribute mapping is optional, so a missing entry yields None.
            for opset_name, operators in opset2layer.items():
                attr_mapping = opset2attr_mapping.get(opset_name)
                OperationsSetToLayers(opset_name, operators, attr_mapping=attr_mapping)

            # Custom operator sets are registered after the built-in ones.
            if custom_opset2layer is not None:
                for opset_name, operators in custom_opset2layer.items():
                    if len(operators) == 1:
                        OperationsSetToLayers(opset_name, operators[0])
                    elif len(operators) == 2:
                        OperationsSetToLayers(opset_name, operators[0], attr_mapping=operators[1])
                    else:
                        raise ValueError(f"Custom operator set to layer mapping should include up to 2 elements - "
                                         f"a list of layers to attach to the operator and an optional mapping of "
                                         f"attributes names, but given a mapping contains {len(operators)} elements.")

        return tpc
|
56
|
+
|
@@ -0,0 +1,107 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import tensorflow as tf
|
17
|
+
from packaging import version
|
18
|
+
|
19
|
+
from model_compression_toolkit.verify_packages import FOUND_SONY_CUSTOM_LAYERS
|
20
|
+
|
21
|
+
if FOUND_SONY_CUSTOM_LAYERS:
|
22
|
+
from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
|
23
|
+
|
24
|
+
if version.parse(tf.__version__) >= version.parse("2.13"):
|
25
|
+
from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
|
26
|
+
MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
|
27
|
+
Conv2DTranspose, Identity, Concatenate, BatchNormalization, Minimum, Maximum
|
28
|
+
else:
|
29
|
+
from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
|
30
|
+
MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
|
31
|
+
Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum
|
32
|
+
|
33
|
+
from model_compression_toolkit import DefaultDict
|
34
|
+
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS, \
|
35
|
+
BIAS_ATTR, KERAS_KERNEL, KERAS_DEPTHWISE_KERNEL
|
36
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
|
37
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
|
38
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
|
39
|
+
AttachTpModelToFw
|
40
|
+
|
41
|
+
|
42
|
+
class AttachTpModelToKeras(AttachTpModelToFw):
    """
    Keras/TensorFlow implementation of AttachTpModelToFw.

    Fills the operator-set-to-layers mapping with Keras layers and TensorFlow functions, and the
    attribute-name mapping for the operator sets listed in `_opset2attr_mapping` (Conv, DepthwiseConv,
    FullyConnected).
    """

    def __init__(self):
        super().__init__()

        # Operator set name (string value of the OperatorSetNames member) -> Keras layers / TF functions.
        self._opset2layer = {
            OperatorSetNames.OPSET_CONV.value: [Conv2D, tf.nn.conv2d],
            OperatorSetNames.OPSET_DEPTHWISE_CONV.value: [DepthwiseConv2D, tf.nn.depthwise_conv2d],
            OperatorSetNames.OPSET_CONV_TRANSPOSE.value: [Conv2DTranspose, tf.nn.conv2d_transpose],
            OperatorSetNames.OPSET_FULLY_CONNECTED.value: [Dense],
            OperatorSetNames.OPSET_CONCATENATE.value: [tf.concat, Concatenate],
            OperatorSetNames.OPSET_STACK.value: [tf.stack],
            OperatorSetNames.OPSET_UNSTACK.value: [tf.unstack],
            OperatorSetNames.OPSET_GATHER.value: [tf.gather, tf.compat.v1.gather],
            OperatorSetNames.OPSET_EXPAND.value: [],
            OperatorSetNames.OPSET_BATCH_NORM.value: [BatchNormalization],
            OperatorSetNames.OPSET_RELU.value: [tf.nn.relu, ReLU],
            OperatorSetNames.OPSET_RELU6.value: [tf.nn.relu6],
            OperatorSetNames.OPSET_LEAKY_RELU.value: [tf.nn.leaky_relu, LeakyReLU],
            OperatorSetNames.OPSET_HARD_TANH.value: [LayerFilterParams(Activation, activation="hard_tanh")],
            OperatorSetNames.OPSET_ADD.value: [tf.add, Add],
            OperatorSetNames.OPSET_SUB.value: [tf.subtract, Subtract],
            OperatorSetNames.OPSET_MUL.value: [tf.math.multiply, Multiply],
            OperatorSetNames.OPSET_DIV.value: [tf.math.divide, tf.math.truediv],
            OperatorSetNames.OPSET_MIN.value: [tf.math.minimum, Minimum],
            OperatorSetNames.OPSET_MAX.value: [tf.math.maximum, Maximum],
            OperatorSetNames.OPSET_PRELU.value: [PReLU],
            OperatorSetNames.OPSET_SWISH.value: [tf.nn.swish, LayerFilterParams(Activation, activation="swish")],
            OperatorSetNames.OPSET_SIGMOID.value: [tf.nn.sigmoid, LayerFilterParams(Activation, activation="sigmoid")],
            OperatorSetNames.OPSET_TANH.value: [tf.nn.tanh, LayerFilterParams(Activation, activation="tanh")],
            OperatorSetNames.OPSET_GELU.value: [tf.nn.gelu, LayerFilterParams(Activation, activation="gelu")],
            OperatorSetNames.OPSET_HARDSIGMOID.value: [tf.keras.activations.hard_sigmoid,
                                                       LayerFilterParams(Activation, activation="hard_sigmoid")],
            OperatorSetNames.OPSET_FLATTEN.value: [Flatten],
            OperatorSetNames.OPSET_GET_ITEM.value: [tf.__operators__.getitem],
            OperatorSetNames.OPSET_RESHAPE.value: [Reshape, tf.reshape],
            OperatorSetNames.OPSET_PERMUTE.value: [Permute],
            OperatorSetNames.OPSET_TRANSPOSE.value: [tf.transpose],
            OperatorSetNames.OPSET_DROPOUT.value: [Dropout],
            OperatorSetNames.OPSET_SPLIT.value: [tf.split],
            OperatorSetNames.OPSET_MAXPOOL.value: [MaxPooling2D],
            OperatorSetNames.OPSET_SHAPE.value: [tf.shape, tf.compat.v1.shape],
            OperatorSetNames.OPSET_EQUAL.value: [tf.math.equal],
            OperatorSetNames.OPSET_ARGMAX.value: [tf.math.argmax],
            OperatorSetNames.OPSET_TOPK.value: [tf.nn.top_k],
            OperatorSetNames.OPSET_FAKE_QUANT_WITH_MIN_MAX_VARS.value: [tf.quantization.fake_quant_with_min_max_vars],
            OperatorSetNames.OPSET_COMBINED_NON_MAX_SUPPRESSION.value: [tf.image.combined_non_max_suppression],
            OperatorSetNames.OPSET_CROPPING2D.value: [Cropping2D],
            OperatorSetNames.OPSET_ZERO_PADDING2d.value: [ZeroPadding2D],
            OperatorSetNames.OPSET_CAST.value: [tf.cast],
            OperatorSetNames.OPSET_STRIDED_SLICE.value: [tf.strided_slice]
        }

        # SSDPostProcess is imported only when the sony_custom_layers package was found.
        if FOUND_SONY_CUSTOM_LAYERS:
            # Key by the member's string value, like every other entry, so that the operator set name handed to
            # OperationsSetToLayers is a string rather than an Enum member.
            self._opset2layer[OperatorSetNames.OPSET_POST_PROCESS.value] = [SSDPostProcess]

        # Operator set name -> {general attribute name -> per-layer framework attribute name}.
        # The depthwise kernel has its own attribute name for DepthwiseConv2D / tf.nn.depthwise_conv2d;
        # any other layer type falls back to the regular Keras kernel attribute name.
        self._opset2attr_mapping = {OperatorSetNames.OPSET_CONV.value: {
            KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
            BIAS_ATTR: DefaultDict(default_value=BIAS)},
            OperatorSetNames.OPSET_DEPTHWISE_CONV.value: {
                KERNEL_ATTR: DefaultDict({
                    DepthwiseConv2D: KERAS_DEPTHWISE_KERNEL,
                    tf.nn.depthwise_conv2d: KERAS_DEPTHWISE_KERNEL}, default_value=KERAS_KERNEL),
                BIAS_ATTR: DefaultDict(default_value=BIAS)},
            OperatorSetNames.OPSET_FULLY_CONNECTED.value: {
                KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
                BIAS_ATTR: DefaultDict(default_value=BIAS)}}
|
@@ -0,0 +1,91 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import operator
|
17
|
+
|
18
|
+
import torch
|
19
|
+
from torch import add, sub, mul, div, divide, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, \
|
20
|
+
chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract, minimum, \
|
21
|
+
maximum
|
22
|
+
from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d, BatchNorm2d
|
23
|
+
from torch.nn import Dropout, Flatten, Hardtanh
|
24
|
+
from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, Hardsigmoid, LeakyReLU, GELU
|
25
|
+
import torch.nn.functional as F
|
26
|
+
from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, hardsigmoid, leaky_relu, gelu
|
27
|
+
|
28
|
+
from model_compression_toolkit import DefaultDict
|
29
|
+
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, PYTORCH_KERNEL, BIAS, \
|
30
|
+
BIAS_ATTR
|
31
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
|
32
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
|
33
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
|
34
|
+
AttachTpModelToFw
|
35
|
+
|
36
|
+
|
37
|
+
class AttachTpModelToPytorch(AttachTpModelToFw):
    """
    PyTorch implementation of AttachTpModelToFw.

    Fills the operator-set-to-layers mapping with torch modules, torch functions and Python operators,
    and the attribute-name mapping for the operator sets listed in `_opset2attr_mapping`
    (Conv, ConvTranspose, FullyConnected).
    """

    def __init__(self):
        super().__init__()

        # Short alias for the operator set names enum; keys are the members' string values.
        ops = OperatorSetNames

        self._opset2layer = {
            ops.OPSET_CONV.value: [Conv2d],
            ops.OPSET_CONV_TRANSPOSE.value: [ConvTranspose2d],
            ops.OPSET_FULLY_CONNECTED.value: [Linear],
            ops.OPSET_CONCATENATE.value: [torch.cat, torch.concat, torch.concatenate],
            ops.OPSET_STACK.value: [torch.stack],
            ops.OPSET_UNSTACK.value: [unbind],
            ops.OPSET_GATHER.value: [gather],
            ops.OPSET_EXPAND.value: [torch.Tensor.expand],
            ops.OPSET_BATCH_NORM.value: [BatchNorm2d],
            ops.OPSET_RELU.value: [torch.relu, ReLU, relu],
            ops.OPSET_RELU6.value: [ReLU6, relu6],
            ops.OPSET_LEAKY_RELU.value: [LeakyReLU, leaky_relu],
            ops.OPSET_HARD_TANH.value: [
                LayerFilterParams(Hardtanh, min_val=0),
                LayerFilterParams(hardtanh, min_val=0),
            ],
            ops.OPSET_ADD.value: [operator.add, add],
            ops.OPSET_SUB.value: [operator.sub, sub, subtract],
            ops.OPSET_MUL.value: [operator.mul, mul, multiply],
            ops.OPSET_DIV.value: [operator.truediv, div, divide],
            ops.OPSET_MIN.value: [minimum],
            ops.OPSET_MAX.value: [maximum],
            ops.OPSET_PRELU.value: [PReLU, prelu],
            ops.OPSET_SWISH.value: [SiLU, silu],
            ops.OPSET_SIGMOID.value: [Sigmoid, sigmoid, F.sigmoid],
            ops.OPSET_TANH.value: [Tanh, tanh, F.tanh],
            ops.OPSET_GELU.value: [GELU, gelu],
            ops.OPSET_HARDSIGMOID.value: [Hardsigmoid, hardsigmoid],
            ops.OPSET_HARDSWISH.value: [Hardswish, hardswish],
            ops.OPSET_FLATTEN.value: [Flatten, flatten],
            ops.OPSET_GET_ITEM.value: [operator.getitem],
            ops.OPSET_RESHAPE.value: [reshape],
            ops.OPSET_UNSQUEEZE.value: [unsqueeze],
            ops.OPSET_SQUEEZE.value: [squeeze],
            ops.OPSET_PERMUTE.value: [permute],
            ops.OPSET_TRANSPOSE.value: [transpose],
            ops.OPSET_DROPOUT.value: [Dropout, dropout],
            ops.OPSET_SPLIT.value: [split],
            ops.OPSET_CHUNK.value: [chunk],
            ops.OPSET_MAXPOOL.value: [MaxPool2d],
            ops.OPSET_SIZE.value: [torch.Tensor.size],
            ops.OPSET_SHAPE.value: [torch.Tensor.shape],
            ops.OPSET_EQUAL.value: [equal],
            ops.OPSET_ARGMAX.value: [argmax],
            ops.OPSET_TOPK.value: [topk],
        }

        # A single kernel/bias attribute-name mapping object is shared by all three operator sets below.
        weights_attr_mapping = {
            KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
            BIAS_ATTR: DefaultDict(default_value=BIAS),
        }
        weighted_opsets = (ops.OPSET_CONV, ops.OPSET_CONV_TRANSPOSE, ops.OPSET_FULLY_CONNECTED)
        self._opset2attr_mapping = {opset.value: weights_attr_mapping for opset in weighted_opsets}
|
{mct_nightly-2.2.0.20241223.525.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/LICENSE.md
RENAMED
File without changes
|
File without changes
|
{mct_nightly-2.2.0.20241223.525.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/top_level.txt
RENAMED
File without changes
|