mct-nightly 2.2.0.20241223.525__py3-none-any.whl → 2.2.0.20241224.532__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.2.0.20241223.525
3
+ Version: 2.2.0.20241224.532
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=QUv2SuXkm47u1PtvuiNSgpOZaPe-Fr6tU172aQNgLJs,1573
1
+ model_compression_toolkit/__init__.py,sha256=9suCm_ya-q7binwaiEyGExSDb8bJgOWwJ3wBnV_el2Y,1573
2
2
  model_compression_toolkit/constants.py,sha256=i_R6uXBfO1ph_X6DNJych2x59SUojfJbn7dNjs_mZnc,3846
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -433,9 +433,12 @@ model_compression_toolkit/target_platform_capabilities/immutable.py,sha256=YhROB
433
433
  model_compression_toolkit/target_platform_capabilities/schema/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
434
434
  model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=E6Zz8boibgfq8EVpZWyl0TOdFrv9qrwiVHUzYPIKVrQ,528
435
435
  model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=ZDFN2N4dRRP6qs0HxsHXEJbZCwYByo3JL9sBCJolDBs,4656
436
- model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=G00JebGnCBukLXLoYZEVqq0ArtC0F8GMrwvJ1BpUwbU,25207
436
+ model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=GsToWecNswbQCnLPBNsFBK5uX8B69jN9N-tXEmmU6WI,25223
437
437
  model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py,sha256=1FXmDVSqm-dr3xzH4vRo4NmAgyzBZjqHo5l63MUq4r0,1403
438
438
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/__init__.py,sha256=WCP1wfFZgM4eFm-pPeUinr5R_aSx5qwfSQqLZCXUNBA,1513
439
+ model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py,sha256=GK_eI9Oq-kgBdfXm0AwgXkYgGKL0FEthqrTd0X_XWg0,2872
440
+ model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py,sha256=mkTjg9JZEKMT_QHtlQPGUZvcuF0lofZUs1DU1h43JjM,6671
441
+ model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py,sha256=cn2sP9fnHJ8qfR1AMTwW58mXLOcdT6M7xyE3TLRjnY0,5443
439
442
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
440
443
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/current_tpc.py,sha256=fIheShGOnxWYKqT8saHpBJqOU5RG_1Hp9qHry7IviIw,2115
441
444
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/layer_filter_params.py,sha256=Cl6-mACpje2jM8RJkibbqE3hvTkFR3r26-lW021mIiA,4019
@@ -557,8 +560,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
557
560
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
558
561
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
559
562
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
560
- mct_nightly-2.2.0.20241223.525.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
561
- mct_nightly-2.2.0.20241223.525.dist-info/METADATA,sha256=PYunic67kp02_7_Cjbo7t9m6MLKtu8tKorpuxnz1feI,26453
562
- mct_nightly-2.2.0.20241223.525.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
563
- mct_nightly-2.2.0.20241223.525.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
564
- mct_nightly-2.2.0.20241223.525.dist-info/RECORD,,
563
+ mct_nightly-2.2.0.20241224.532.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
564
+ mct_nightly-2.2.0.20241224.532.dist-info/METADATA,sha256=TevkRWHqm2UgHf34bwK7NWHCKt4tIUfdOvpVaA4-CIU,26453
565
+ mct_nightly-2.2.0.20241224.532.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
566
+ mct_nightly-2.2.0.20241224.532.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
567
+ mct_nightly-2.2.0.20241224.532.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.2.0.20241223.000525"
30
+ __version__ = "2.2.0.20241224.000532"
@@ -25,7 +25,7 @@ from model_compression_toolkit.target_platform_capabilities.constants import OPS
25
25
  class OperatorSetNames(Enum):
26
26
  OPSET_CONV = "Conv"
27
27
  OPSET_DEPTHWISE_CONV = "DepthwiseConv2D"
28
- OPSET_CONV_TRANSPOSE = "ConvTraspose"
28
+ OPSET_CONV_TRANSPOSE = "ConvTranspose"
29
29
  OPSET_FULLY_CONNECTED = "FullyConnected"
30
30
  OPSET_CONCATENATE = "Concatenate"
31
31
  OPSET_STACK = "Stack"
@@ -41,7 +41,8 @@ class OperatorSetNames(Enum):
41
41
  OPSET_SUB = "Sub"
42
42
  OPSET_MUL = "Mul"
43
43
  OPSET_DIV = "Div"
44
- OPSET_MIN_MAX = "MinMax"
44
+ OPSET_MIN = "Min"
45
+ OPSET_MAX = "Max"
45
46
  OPSET_PRELU = "PReLU"
46
47
  OPSET_SWISH = "Swish"
47
48
  OPSET_SIGMOID = "Sigmoid"
@@ -0,0 +1,56 @@
1
+ from typing import Dict, Tuple, List, Any, Optional
2
+
3
+ from model_compression_toolkit import DefaultDict
4
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
5
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, \
6
+ OperationsSetToLayers
7
+
8
+
9
class AttachTpModelToFw:
    """
    Base class for attaching a TargetPlatformModel to the operators of a specific framework.

    Framework-specific subclasses populate ``self._opset2layer`` and ``self._opset2attr_mapping`` in their
    ``__init__``; this base class only initializes both to None and implements ``attach``.
    """

    def __init__(self):
        # Mapping of operator set names to the list of framework-specific layers/operators that implement them.
        self._opset2layer = None

        # A mapping that associates each layer type in the operation set (with weight attributes and a quantization
        # configuration in the target platform model) to its framework-specific attribute name. If not all layer types
        # in the operation set are provided in the mapping, a DefaultDict should be supplied to handle missing entries.
        self._opset2attr_mapping = None

    def attach(self, tpc_model: TargetPlatformModel,
               custom_opset2layer: Optional[Dict[str, Tuple[List[Any], Optional[Dict[str, DefaultDict]]]]] = None
               ) -> TargetPlatformCapabilities:
        """
        Attaching a TargetPlatformModel which includes a platform capabilities description to specific
        framework's operators.

        Args:
            tpc_model: a TargetPlatformModel object.
            custom_opset2layer: optional set of custom operator sets which allows to add/override the built-in set
                of framework operators, to define a specific behavior for those operators. This dictionary should
                map an operator set unique name to a sequence of one or two elements: a list of framework operators
                and an optional mapping of the operators' attribute names. A custom operator set whose name also
                appears in the built-in mapping replaces the built-in definition.

        Returns: a TargetPlatformCapabilities object.

        Raises:
            NotImplementedError: if the built-in operator set mapping was never set (i.e. the base class is used
                directly instead of a framework-specific subclass).
            ValueError: if an entry of ``custom_opset2layer`` does not contain one or two elements.
        """
        if self._opset2layer is None:
            raise NotImplementedError(f"{type(self).__name__} does not define a mapping of operator sets to "
                                      f"framework layers; use a framework-specific subclass.")

        custom_opset2layer = custom_opset2layer or {}
        attr_mappings = self._opset2attr_mapping or {}

        # Validate every custom entry before attaching anything, so a malformed entry is reported
        # before any operator set has been attached to the TPC.
        for operators in custom_opset2layer.values():
            if len(operators) not in (1, 2):
                raise ValueError(f"Custom operator set to layer mapping should include up to 2 elements - "
                                 f"a list of layers to attach to the operator and an optional mapping of "
                                 f"attributes names, but given a mapping contains {len(operators)} elements.")

        tpc = TargetPlatformCapabilities(tpc_model)

        with tpc:
            for opset_name, operators in self._opset2layer.items():
                if opset_name in custom_opset2layer:
                    # The user's custom definition overrides the built-in one; attaching both would create
                    # two OperationsSetToLayers with the same operator set name.
                    continue
                OperationsSetToLayers(opset_name, operators, attr_mapping=attr_mappings.get(opset_name))

            for opset_name, operators in custom_opset2layer.items():
                if len(operators) == 1:
                    OperationsSetToLayers(opset_name, operators[0])
                else:
                    OperationsSetToLayers(opset_name, operators[0], attr_mapping=operators[1])

        return tpc
56
+
@@ -0,0 +1,107 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import tensorflow as tf
17
+ from packaging import version
18
+
19
+ from model_compression_toolkit.verify_packages import FOUND_SONY_CUSTOM_LAYERS
20
+
21
+ if FOUND_SONY_CUSTOM_LAYERS:
22
+ from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
23
+
24
+ if version.parse(tf.__version__) >= version.parse("2.13"):
25
+ from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
26
+ MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
27
+ Conv2DTranspose, Identity, Concatenate, BatchNormalization, Minimum, Maximum
28
+ else:
29
+ from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
30
+ MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
31
+ Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum
32
+
33
+ from model_compression_toolkit import DefaultDict
34
+ from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS, \
35
+ BIAS_ATTR, KERAS_KERNEL, KERAS_DEPTHWISE_KERNEL
36
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
37
+ from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
38
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
39
+ AttachTpModelToFw
40
+
41
+
42
class AttachTpModelToKeras(AttachTpModelToFw):
    """
    Attaches a TargetPlatformModel to Keras layers and TensorFlow operators.

    ``self._opset2layer`` maps each operator set name (an ``OperatorSetNames`` value) to the Keras layers and
    TF functions that implement it. ``self._opset2attr_mapping`` maps the kernel and bias attributes of the
    convolution, depthwise-convolution and fully-connected operator sets to their Keras attribute names.
    """

    def __init__(self):
        super().__init__()

        self._opset2layer = {
            OperatorSetNames.OPSET_CONV.value: [Conv2D, tf.nn.conv2d],
            OperatorSetNames.OPSET_DEPTHWISE_CONV.value: [DepthwiseConv2D, tf.nn.depthwise_conv2d],
            OperatorSetNames.OPSET_CONV_TRANSPOSE.value: [Conv2DTranspose, tf.nn.conv2d_transpose],
            OperatorSetNames.OPSET_FULLY_CONNECTED.value: [Dense],
            OperatorSetNames.OPSET_CONCATENATE.value: [tf.concat, Concatenate],
            OperatorSetNames.OPSET_STACK.value: [tf.stack],
            OperatorSetNames.OPSET_UNSTACK.value: [tf.unstack],
            OperatorSetNames.OPSET_GATHER.value: [tf.gather, tf.compat.v1.gather],
            OperatorSetNames.OPSET_EXPAND.value: [],
            OperatorSetNames.OPSET_BATCH_NORM.value: [BatchNormalization],
            OperatorSetNames.OPSET_RELU.value: [tf.nn.relu, ReLU],
            OperatorSetNames.OPSET_RELU6.value: [tf.nn.relu6],
            OperatorSetNames.OPSET_LEAKY_RELU.value: [tf.nn.leaky_relu, LeakyReLU],
            OperatorSetNames.OPSET_HARD_TANH.value: [LayerFilterParams(Activation, activation="hard_tanh")],
            OperatorSetNames.OPSET_ADD.value: [tf.add, Add],
            OperatorSetNames.OPSET_SUB.value: [tf.subtract, Subtract],
            OperatorSetNames.OPSET_MUL.value: [tf.math.multiply, Multiply],
            OperatorSetNames.OPSET_DIV.value: [tf.math.divide, tf.math.truediv],
            OperatorSetNames.OPSET_MIN.value: [tf.math.minimum, Minimum],
            OperatorSetNames.OPSET_MAX.value: [tf.math.maximum, Maximum],
            OperatorSetNames.OPSET_PRELU.value: [PReLU],
            OperatorSetNames.OPSET_SWISH.value: [tf.nn.swish, LayerFilterParams(Activation, activation="swish")],
            OperatorSetNames.OPSET_SIGMOID.value: [tf.nn.sigmoid, LayerFilterParams(Activation, activation="sigmoid")],
            OperatorSetNames.OPSET_TANH.value: [tf.nn.tanh, LayerFilterParams(Activation, activation="tanh")],
            OperatorSetNames.OPSET_GELU.value: [tf.nn.gelu, LayerFilterParams(Activation, activation="gelu")],
            OperatorSetNames.OPSET_HARDSIGMOID.value: [tf.keras.activations.hard_sigmoid,
                                                       LayerFilterParams(Activation, activation="hard_sigmoid")],
            OperatorSetNames.OPSET_FLATTEN.value: [Flatten],
            OperatorSetNames.OPSET_GET_ITEM.value: [tf.__operators__.getitem],
            OperatorSetNames.OPSET_RESHAPE.value: [Reshape, tf.reshape],
            OperatorSetNames.OPSET_PERMUTE.value: [Permute],
            OperatorSetNames.OPSET_TRANSPOSE.value: [tf.transpose],
            OperatorSetNames.OPSET_DROPOUT.value: [Dropout],
            OperatorSetNames.OPSET_SPLIT.value: [tf.split],
            OperatorSetNames.OPSET_MAXPOOL.value: [MaxPooling2D],
            OperatorSetNames.OPSET_SHAPE.value: [tf.shape, tf.compat.v1.shape],
            OperatorSetNames.OPSET_EQUAL.value: [tf.math.equal],
            OperatorSetNames.OPSET_ARGMAX.value: [tf.math.argmax],
            OperatorSetNames.OPSET_TOPK.value: [tf.nn.top_k],
            OperatorSetNames.OPSET_FAKE_QUANT_WITH_MIN_MAX_VARS.value: [tf.quantization.fake_quant_with_min_max_vars],
            OperatorSetNames.OPSET_COMBINED_NON_MAX_SUPPRESSION.value: [tf.image.combined_non_max_suppression],
            OperatorSetNames.OPSET_CROPPING2D.value: [Cropping2D],
            OperatorSetNames.OPSET_ZERO_PADDING2d.value: [ZeroPadding2D],
            OperatorSetNames.OPSET_CAST.value: [tf.cast],
            OperatorSetNames.OPSET_STRIDED_SLICE.value: [tf.strided_slice]
        }

        if FOUND_SONY_CUSTOM_LAYERS:
            # Key by the operator set name (the enum's value) like every other entry above; OperatorSetNames is a
            # plain Enum, so the member itself would not equal the operator set's string name.
            self._opset2layer[OperatorSetNames.OPSET_POST_PROCESS.value] = [SSDPostProcess]

        # Depthwise layers store their kernel under a dedicated Keras attribute name; all other layers in these
        # operator sets fall back to the DefaultDict default.
        self._opset2attr_mapping = {OperatorSetNames.OPSET_CONV.value: {
            KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
            BIAS_ATTR: DefaultDict(default_value=BIAS)},
            OperatorSetNames.OPSET_DEPTHWISE_CONV.value: {
                KERNEL_ATTR: DefaultDict({
                    DepthwiseConv2D: KERAS_DEPTHWISE_KERNEL,
                    tf.nn.depthwise_conv2d: KERAS_DEPTHWISE_KERNEL}, default_value=KERAS_KERNEL),
                BIAS_ATTR: DefaultDict(default_value=BIAS)},
            OperatorSetNames.OPSET_FULLY_CONNECTED.value: {
                KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
                BIAS_ATTR: DefaultDict(default_value=BIAS)}}
@@ -0,0 +1,91 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import operator
17
+
18
+ import torch
19
+ from torch import add, sub, mul, div, divide, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, \
20
+ chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract, minimum, \
21
+ maximum
22
+ from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d, BatchNorm2d
23
+ from torch.nn import Dropout, Flatten, Hardtanh
24
+ from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, Hardsigmoid, LeakyReLU, GELU
25
+ import torch.nn.functional as F
26
+ from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, hardsigmoid, leaky_relu, gelu
27
+
28
+ from model_compression_toolkit import DefaultDict
29
+ from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, PYTORCH_KERNEL, BIAS, \
30
+ BIAS_ATTR
31
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
32
+ from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
33
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
34
+ AttachTpModelToFw
35
+
36
+
37
class AttachTpModelToPytorch(AttachTpModelToFw):
    """
    Attaches a TargetPlatformModel to PyTorch modules and functions.

    The layer table is written against ``OperatorSetNames`` members and then re-keyed by each member's value,
    so ``self._opset2layer`` maps operator set names to the torch modules/functions implementing them.
    The convolution, transposed-convolution and fully-connected operator sets all share one attribute mapping
    object for their kernel and bias attributes.
    """

    def __init__(self):
        super().__init__()

        ops = OperatorSetNames
        layers_by_opset = {
            ops.OPSET_CONV: [Conv2d],
            ops.OPSET_CONV_TRANSPOSE: [ConvTranspose2d],
            ops.OPSET_FULLY_CONNECTED: [Linear],
            ops.OPSET_CONCATENATE: [torch.cat, torch.concat, torch.concatenate],
            ops.OPSET_STACK: [torch.stack],
            ops.OPSET_UNSTACK: [unbind],
            ops.OPSET_GATHER: [gather],
            ops.OPSET_EXPAND: [torch.Tensor.expand],
            ops.OPSET_BATCH_NORM: [BatchNorm2d],
            ops.OPSET_RELU: [torch.relu, ReLU, relu],
            ops.OPSET_RELU6: [ReLU6, relu6],
            ops.OPSET_LEAKY_RELU: [LeakyReLU, leaky_relu],
            ops.OPSET_HARD_TANH: [LayerFilterParams(Hardtanh, min_val=0),
                                  LayerFilterParams(hardtanh, min_val=0)],
            ops.OPSET_ADD: [operator.add, add],
            ops.OPSET_SUB: [operator.sub, sub, subtract],
            ops.OPSET_MUL: [operator.mul, mul, multiply],
            ops.OPSET_DIV: [operator.truediv, div, divide],
            ops.OPSET_MIN: [minimum],
            ops.OPSET_MAX: [maximum],
            ops.OPSET_PRELU: [PReLU, prelu],
            ops.OPSET_SWISH: [SiLU, silu],
            ops.OPSET_SIGMOID: [Sigmoid, sigmoid, F.sigmoid],
            ops.OPSET_TANH: [Tanh, tanh, F.tanh],
            ops.OPSET_GELU: [GELU, gelu],
            ops.OPSET_HARDSIGMOID: [Hardsigmoid, hardsigmoid],
            ops.OPSET_HARDSWISH: [Hardswish, hardswish],
            ops.OPSET_FLATTEN: [Flatten, flatten],
            ops.OPSET_GET_ITEM: [operator.getitem],
            ops.OPSET_RESHAPE: [reshape],
            ops.OPSET_UNSQUEEZE: [unsqueeze],
            ops.OPSET_SQUEEZE: [squeeze],
            ops.OPSET_PERMUTE: [permute],
            ops.OPSET_TRANSPOSE: [transpose],
            ops.OPSET_DROPOUT: [Dropout, dropout],
            ops.OPSET_SPLIT: [split],
            ops.OPSET_CHUNK: [chunk],
            ops.OPSET_MAXPOOL: [MaxPool2d],
            ops.OPSET_SIZE: [torch.Tensor.size],
            ops.OPSET_SHAPE: [torch.Tensor.shape],
            ops.OPSET_EQUAL: [equal],
            ops.OPSET_ARGMAX: [argmax],
            ops.OPSET_TOPK: [topk],
        }
        # Re-key by operator set name string; insertion order is preserved.
        self._opset2layer = {opset.value: layers for opset, layers in layers_by_opset.items()}

        linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
                               BIAS_ATTR: DefaultDict(default_value=BIAS)}
        weighted_opsets = (ops.OPSET_CONV, ops.OPSET_CONV_TRANSPOSE, ops.OPSET_FULLY_CONNECTED)
        # Every weighted operator set points at the very same mapping object.
        self._opset2attr_mapping = {opset.value: linear_attr_mapping for opset in weighted_opsets}