mct-nightly 2.2.0.20241119.516__py3-none-any.whl → 2.2.0.20241121.524__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.2.0.20241119.516
3
+ Version: 2.2.0.20241121.524
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=LtqmejEQcQLXceDuGvAmjezjSZB8czuS4kfbMbUzF7k,1573
1
+ model_compression_toolkit/__init__.py,sha256=X-Df1ZQGO1n9pF09XDXYplM7M5oQzCDKFyOASqQtRVA,1573
2
2
  model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -223,7 +223,7 @@ model_compression_toolkit/core/pytorch/constants.py,sha256=YwD_joIF0vK8UG2vW1NVv
223
223
  model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
224
224
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=-Vls1P_8Ckm_18nnOsmQkZ71SmzHwtQLbQ383Z4Rb-U,4365
225
225
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
226
- model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=4uzO-lXfuitlC3NHx5-k2Fjm8VHa1T7ox9c8DSxYs9M,29437
226
+ model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=SgxmSdzAQOPI9YHt4Q9-OeDi8fzAdgASHQ4nZ5maPsg,29599
227
227
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
228
228
  model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=xpKj99OZKT9NT0vKIl_cOe8d89d2gef1gKoNT6PFElE,4989
229
229
  model_compression_toolkit/core/pytorch/utils.py,sha256=7VbgcLwtQvdEEc_AJgSOQ3U3KRKCICFPaBirN1fIQxg,3940
@@ -246,6 +246,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_
246
246
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py,sha256=sw3jIOUSvfWUeD8l3rGcUOtC6QuzpMIQm8V3RQAM53Q,4741
247
247
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py,sha256=7GZY7lU3LUUaO5iiccHkUP62PB0QeGAGOZdUSGMkFBY,4450
248
248
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_layer_norm.py,sha256=XhiLVcnCc_gF-6mjxbf9C4bYg5YL_GCvDJmcdLkBNAg,4151
249
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_linear.py,sha256=3-OHYPun5Rt7GITqV3ZekJk59tsuY9ZYSpRpxKsNEVA,3450
249
250
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py,sha256=CXSMASpc_Zed3BJ2CsER69zKxE6ncFvvKQWDO1JxKYI,5849
250
251
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=VNg-VgzCxSyqy2J3neEPl6U0SPO8UIVU_T47bGhz4FE,38459
251
252
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py,sha256=q1a3HieQtaOmWG2WGXp6GHYAvxa3CZ9dJUx9dqMAsS8,5695
@@ -477,9 +478,9 @@ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_
477
478
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py,sha256=XM6qBLIvzsmdFf-AZq5WOlORK2GXC_X-gulReNxHb9E,6601
478
479
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py,sha256=nP05jqvh6uaj30a3W7zEkJfKtqfP0Nz5bobwRqbYrdM,5807
479
480
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py,sha256=tHTUvsaerSfbe22pU0kIDauPpFD7Pq5EmZytVIDkHz4,717
480
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py,sha256=Ee7M3YVymdv6HYsm7coB8N0dyTOhlAhLdxfSLJXCuoU,15665
481
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py,sha256=u8qD1XkHwU4LIoNbmC5mtZd8lZ8gZ4XFihZmoYwAulc,7641
482
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py,sha256=GCghKkkZOKNTAzwyoZZPid9alGiufNUBzDj2yE7YUSU,6709
481
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py,sha256=b1GAFZMtp0Hf1Ybq8gDLUk90m1HFD00LwtEsFpoN5mY,17240
482
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py,sha256=P6gOd5PNMlIp6bcPPfTIX-hTO0AgT9XswrBdvqm-oJ0,8271
483
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py,sha256=4b3WTV3IDqoqDYx37ba-lxF56K-P5FYyPfIM_TWttQ4,7247
483
484
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
484
485
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py,sha256=is00rNrDmmirYsyMtMkWz0DwOA92-x7hAJwpd6z1n2E,2806
485
486
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py,sha256=CXC-HQolSDu7j8V-Xm-SWGCd74gXB3XnAkEhI_TVbIQ,1516
@@ -558,8 +559,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
558
559
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
559
560
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
560
561
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
561
- mct_nightly-2.2.0.20241119.516.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
562
- mct_nightly-2.2.0.20241119.516.dist-info/METADATA,sha256=-P-wCECFkhsiwVycbvgMYHIWF8OEJLWA_watkr62kFs,26472
563
- mct_nightly-2.2.0.20241119.516.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
564
- mct_nightly-2.2.0.20241119.516.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
565
- mct_nightly-2.2.0.20241119.516.dist-info/RECORD,,
562
+ mct_nightly-2.2.0.20241121.524.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
563
+ mct_nightly-2.2.0.20241121.524.dist-info/METADATA,sha256=TpyybcMe4-UNnUKIqO_9NdHs7z0vYVj_7Y7wOs3VaYw,26472
564
+ mct_nightly-2.2.0.20241121.524.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
565
+ mct_nightly-2.2.0.20241121.524.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
566
+ mct_nightly-2.2.0.20241121.524.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.2.0.20241119.000516"
30
+ __version__ = "2.2.0.20241121.000524"
@@ -0,0 +1,83 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+
18
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
19
+ from model_compression_toolkit.core.common import BaseNode, Graph, BaseSubstitution
20
+ from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
21
+ from model_compression_toolkit.core.pytorch.constants import *
22
+ from model_compression_toolkit.logger import Logger
23
+
24
+
25
class FunctionalLinear(BaseSubstitution):
    """
    Substitution that swaps a functional linear node (F.linear) for an
    equivalent node whose layer class is nn.Linear.
    """

    def __init__(self):
        """
        Build the substitution with a matcher for F.linear operation nodes.
        """
        linear_matcher = NodeOperationMatcher(F.linear)
        super().__init__(matcher_instance=linear_matcher)

    def substitute(self,
                   graph: Graph,
                   func_node: FunctionalNode) -> Graph:
        """
        Replace a matched functional linear node with an nn.Linear node.

        The kernel and the (optional) bias are read from the weights of the
        functional node and stored in the new node under the KERNEL and BIAS keys.

        Args:
            graph: Graph we apply the substitution on.
            func_node: Functional linear node that matched the pattern defined in the substitution init.

        Returns:
            Graph after applying the substitution.
        """
        node_weights = func_node.weights
        input_allocs = func_node.tensor_input_allocs

        # A kernel is expected at positional index 1 of the functional call.
        if 1 not in node_weights:
            Logger.critical(f'Weight input missing for node {func_node.name}.')  # pragma: no cover

        # Kernel/bias that were passed as kwargs are located through tensor_input_allocs;
        # otherwise their fixed positional indices (1 for the kernel, 2 for the bias) are used.
        if KERNEL in input_allocs:
            kernel_idx = input_allocs.index(KERNEL)
        else:
            kernel_idx = 1
        if BIAS in input_allocs:
            bias_idx = input_allocs.index(BIAS)
        else:
            bias_idx = 2

        if kernel_idx not in node_weights:
            Logger.critical(f'Mismatch between tensor_input_allocs and weight index in node {func_node.name}.')  # pragma: no cover

        kernel = node_weights[kernel_idx]
        # The bias is optional: a missing entry yields None.
        bias_tensor = node_weights.get(bias_idx)
        has_bias = bias_tensor is not None

        # Feature sizes are taken from the last dimension of the first input/output shape.
        linear_attrs = {
            IN_FEATURES: func_node.input_shape[0][-1],
            OUT_FEATURES: func_node.output_shape[0][-1],
            BIAS: has_bias,
        }

        linear_weights = {KERNEL: kernel}
        if has_bias:
            linear_weights[BIAS] = bias_tensor

        # The new node keeps the name, activation flag and reuse info of the functional node.
        linear_node = BaseNode(name=func_node.name,
                               framework_attr=linear_attrs,
                               input_shape=func_node.input_shape[0],
                               output_shape=func_node.output_shape,
                               weights=linear_weights,
                               layer_class=nn.Linear,
                               has_activation=func_node.has_activation,
                               reuse=func_node.reuse,
                               reuse_group=func_node.reuse_group)

        graph.replace_node(func_node, linear_node)
        return graph
@@ -50,6 +50,8 @@ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.fu
50
50
  FunctionalBatchNorm
51
51
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_layer_norm import \
52
52
  FunctionalLayerNorm
53
+ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_linear import \
54
+ FunctionalLinear
53
55
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.linear_collapsing import \
54
56
  pytorch_linear_collapsing
55
57
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \
@@ -266,6 +268,7 @@ class PytorchImplementation(FrameworkImplementation):
266
268
  FunctionalConvSubstitution(fw_info),
267
269
  FunctionalBatchNorm(),
268
270
  FunctionalLayerNorm(),
271
+ FunctionalLinear(),
269
272
  RemoveIdentity()]
270
273
 
271
274
  def get_substitutions_pre_statistics_collection(self,
@@ -31,15 +31,20 @@ OPSET_DIMENSION_MANIPULATION_OPS = "DimensionManipulationOps"
31
31
  OPSET_MERGE_OPS = "MergeOps"
32
32
  OPSET_CONV = "Conv"
33
33
  OPSET_FULLY_CONNECTED = "FullyConnected"
34
+ OPSET_BATCH_NORM = "BatchNorm"
34
35
  OPSET_ANY_RELU = "AnyReLU"
35
36
  OPSET_ADD = "Add"
36
37
  OPSET_SUB = "Sub"
37
38
  OPSET_MUL = "Mul"
38
39
  OPSET_DIV = "Div"
40
+ OPSET_MIN_MAX = "MinMax"
39
41
  OPSET_PRELU = "PReLU"
40
42
  OPSET_SWISH = "Swish"
41
43
  OPSET_SIGMOID = "Sigmoid"
42
44
  OPSET_TANH = "Tanh"
45
+ OPSET_GELU = "Gelu"
46
+ OPSET_HARDSIGMOID = "HardSigmoid"
47
+ OPSET_HARDSWISH = "HardSwish"
43
48
 
44
49
 
45
50
  def get_tp_model() -> TargetPlatformModel:
@@ -172,6 +177,11 @@ def generate_tp_model(default_config: OpQuantizationConfig,
172
177
  # If the QuantizationConfigOptions contains only one configuration,
173
178
  # this configuration will be used for the operation quantization:
174
179
  default_configuration_options = tp.QuantizationConfigOptions([default_config])
180
+ default_config_input16 = default_config.clone_and_edit(supported_input_activation_n_bits=(8, 16))
181
+ default_config_options_16bit = tp.QuantizationConfigOptions([default_config_input16,
182
+ default_config_input16.clone_and_edit(activation_n_bits=16,
183
+ signedness=Signedness.SIGNED)],
184
+ base_config=default_config_input16)
175
185
 
176
186
  # Create a QuantizationConfigOptions for quantizing constants in functional ops.
177
187
  # Constant configuration is similar to the default eight bit configuration except for PoT
@@ -212,6 +222,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
212
222
  weights_per_channel_threshold=False))
213
223
  qpreserving_const_config_options = tp.QuantizationConfigOptions([qpreserving_const_config])
214
224
 
225
+ mp_cfg_list_16bit = [mp_cfg.clone_and_edit(activation_n_bits=16, signedness=Signedness.SIGNED)
226
+ for mp_cfg in mixed_precision_cfg_list]
227
+
215
228
  # Create a TargetPlatformModel and set its default quantization config.
216
229
  # This default configuration will be used for all operations
217
230
  # unless specified otherwise (see OperatorsSet, for example):
@@ -246,30 +259,37 @@ def generate_tp_model(default_config: OpQuantizationConfig,
246
259
  tp.OperatorsSet(OPSET_MERGE_OPS, const_configuration_options_inout16_per_tensor)
247
260
 
248
261
  # Create Mixed-Precision quantization configuration options from the given list of OpQuantizationConfig objects
249
- mixed_precision_configuration_options = tp.QuantizationConfigOptions(mixed_precision_cfg_list,
262
+ mixed_precision_configuration_options = tp.QuantizationConfigOptions(mixed_precision_cfg_list + mp_cfg_list_16bit,
250
263
  base_config=base_config)
251
264
 
252
265
  # Define operator sets that use mixed_precision_configuration_options:
253
266
  conv = tp.OperatorsSet(OPSET_CONV, mixed_precision_configuration_options)
254
267
  fc = tp.OperatorsSet(OPSET_FULLY_CONNECTED, mixed_precision_configuration_options)
255
268
 
256
- # Define operations sets without quantization configuration
257
- # options (useful for creating fusing patterns, for example):
258
- any_relu = tp.OperatorsSet(OPSET_ANY_RELU)
269
+ tp.OperatorsSet(OPSET_BATCH_NORM, default_config_options_16bit)
270
+
271
+ # Note: Operations sets without quantization configuration are useful for creating fusing patterns
272
+ any_relu = tp.OperatorsSet(OPSET_ANY_RELU, default_config_options_16bit)
259
273
  add = tp.OperatorsSet(OPSET_ADD, const_configuration_options_inout16)
260
274
  sub = tp.OperatorsSet(OPSET_SUB, const_configuration_options_inout16)
261
275
  mul = tp.OperatorsSet(OPSET_MUL, const_configuration_options_inout16)
262
276
  div = tp.OperatorsSet(OPSET_DIV, const_configuration_options)
263
- prelu = tp.OperatorsSet(OPSET_PRELU)
264
- swish = tp.OperatorsSet(OPSET_SWISH)
265
- sigmoid = tp.OperatorsSet(OPSET_SIGMOID)
266
- tanh = tp.OperatorsSet(OPSET_TANH)
277
+ tp.OperatorsSet(OPSET_MIN_MAX, const_configuration_options_inout16)
278
+ prelu = tp.OperatorsSet(OPSET_PRELU, default_config_options_16bit)
279
+ swish = tp.OperatorsSet(OPSET_SWISH, default_config_options_16bit)
280
+ sigmoid = tp.OperatorsSet(OPSET_SIGMOID, default_config_options_16bit)
281
+ tanh = tp.OperatorsSet(OPSET_TANH, default_config_options_16bit)
282
+ gelu = tp.OperatorsSet(OPSET_GELU, default_config_options_16bit)
283
+ hardsigmoid = tp.OperatorsSet(OPSET_HARDSIGMOID, default_config_options_16bit)
284
+ hardswish = tp.OperatorsSet(OPSET_HARDSWISH, default_config_options_16bit)
267
285
 
268
286
  # Combine multiple operators into a single operator to avoid quantization between
269
287
  # them. To do this we define fusing patterns using the OperatorsSets that were created.
270
288
  # To group multiple sets with regard to fusing, an OperatorSetConcat can be created
271
- activations_after_conv_to_fuse = tp.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
272
- activations_after_fc_to_fuse = tp.OperatorSetConcat(any_relu, swish, sigmoid)
289
+ activations_after_conv_to_fuse = tp.OperatorSetConcat(any_relu, swish, prelu, sigmoid,
290
+ tanh, gelu, hardswish, hardsigmoid)
291
+ activations_after_fc_to_fuse = tp.OperatorSetConcat(any_relu, swish, sigmoid, tanh, gelu,
292
+ hardswish, hardsigmoid)
273
293
  any_binary = tp.OperatorSetConcat(add, sub, mul, div)
274
294
 
275
295
  # ------------------- #
@@ -26,11 +26,11 @@ if FOUND_SONY_CUSTOM_LAYERS:
26
26
  if version.parse(tf.__version__) >= version.parse("2.13"):
27
27
  from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
28
28
  MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
29
- Conv2DTranspose, Identity, Concatenate
29
+ Conv2DTranspose, Identity, Concatenate, BatchNormalization, Minimum, Maximum
30
30
  else:
31
31
  from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
32
32
  MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
33
- Conv2DTranspose, Identity, Concatenate
33
+ Conv2DTranspose, Identity, Concatenate, BatchNormalization, Minimum, Maximum
34
34
 
35
35
  from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import get_tp_model
36
36
  import model_compression_toolkit as mct
@@ -38,7 +38,7 @@ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tp
38
38
  from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import OPSET_NO_QUANTIZATION, \
39
39
  OPSET_QUANTIZATION_PRESERVING, OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, OPSET_DIMENSION_MANIPULATION_OPS, \
40
40
  OPSET_MERGE_OPS, OPSET_CONV, OPSET_FULLY_CONNECTED, OPSET_ANY_RELU, OPSET_ADD, OPSET_SUB, OPSET_MUL, OPSET_DIV, \
41
- OPSET_PRELU, OPSET_SWISH, OPSET_SIGMOID, OPSET_TANH
41
+ OPSET_PRELU, OPSET_SWISH, OPSET_SIGMOID, OPSET_TANH, OPSET_GELU, OPSET_BATCH_NORM, OPSET_MIN_MAX, OPSET_HARDSIGMOID
42
42
 
43
43
  tp = mct.target_platform
44
44
 
@@ -117,6 +117,7 @@ def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel):
117
117
  tp.OperationsSetToLayers(OPSET_FULLY_CONNECTED, [Dense],
118
118
  attr_mapping={KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
119
119
  BIAS_ATTR: DefaultDict(default_value=BIAS)})
120
+ tp.OperationsSetToLayers(OPSET_BATCH_NORM, [BatchNormalization])
120
121
  tp.OperationsSetToLayers(OPSET_ANY_RELU, [tf.nn.relu,
121
122
  tf.nn.relu6,
122
123
  tf.nn.leaky_relu,
@@ -128,9 +129,13 @@ def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel):
128
129
  tp.OperationsSetToLayers(OPSET_SUB, [tf.subtract, Subtract])
129
130
  tp.OperationsSetToLayers(OPSET_MUL, [tf.math.multiply, Multiply])
130
131
  tp.OperationsSetToLayers(OPSET_DIV, [tf.math.divide, tf.math.truediv])
132
+ tp.OperationsSetToLayers(OPSET_MIN_MAX, [tf.math.minimum, tf.math.maximum, Minimum, Maximum])
131
133
  tp.OperationsSetToLayers(OPSET_PRELU, [PReLU])
132
134
  tp.OperationsSetToLayers(OPSET_SWISH, [tf.nn.swish, tp.LayerFilterParams(Activation, activation="swish")])
133
135
  tp.OperationsSetToLayers(OPSET_SIGMOID, [tf.nn.sigmoid, tp.LayerFilterParams(Activation, activation="sigmoid")])
134
136
  tp.OperationsSetToLayers(OPSET_TANH, [tf.nn.tanh, tp.LayerFilterParams(Activation, activation="tanh")])
137
+ tp.OperationsSetToLayers(OPSET_GELU, [tf.nn.gelu, tp.LayerFilterParams(Activation, activation="gelu")])
138
+ tp.OperationsSetToLayers(OPSET_HARDSIGMOID, [tf.keras.activations.hard_sigmoid,
139
+ tp.LayerFilterParams(Activation, activation="hard_sigmoid")])
135
140
 
136
141
  return keras_tpc
@@ -17,11 +17,13 @@ import operator
17
17
 
18
18
  import torch
19
19
  from torch import add, sub, mul, div, divide, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, \
20
- chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract
21
- from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d
20
+ chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract, minimum, \
21
+ maximum
22
+ from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d, BatchNorm2d
22
23
  from torch.nn import Dropout, Flatten, Hardtanh
23
- from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, LeakyReLU
24
- from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, leaky_relu
24
+ from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, Hardsigmoid, LeakyReLU, GELU
25
+ import torch.nn.functional as F
26
+ from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, hardsigmoid, leaky_relu, gelu
25
27
 
26
28
  from model_compression_toolkit.defaultdict import DefaultDict
27
29
  from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, PYTORCH_KERNEL, \
@@ -32,7 +34,8 @@ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tp
32
34
  from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import OPSET_NO_QUANTIZATION, \
33
35
  OPSET_QUANTIZATION_PRESERVING, OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, OPSET_DIMENSION_MANIPULATION_OPS, \
34
36
  OPSET_MERGE_OPS, OPSET_CONV, OPSET_FULLY_CONNECTED, OPSET_ANY_RELU, OPSET_ADD, OPSET_SUB, OPSET_MUL, OPSET_DIV, \
35
- OPSET_PRELU, OPSET_SWISH, OPSET_SIGMOID, OPSET_TANH
37
+ OPSET_PRELU, OPSET_SWISH, OPSET_SIGMOID, OPSET_TANH, OPSET_GELU, OPSET_BATCH_NORM, OPSET_MIN_MAX, OPSET_HARDSIGMOID, \
38
+ OPSET_HARDSWISH
36
39
 
37
40
  tp = mct.target_platform
38
41
 
@@ -95,6 +98,7 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
95
98
  attr_mapping=pytorch_linear_attr_mapping)
96
99
  tp.OperationsSetToLayers(OPSET_FULLY_CONNECTED, [Linear],
97
100
  attr_mapping=pytorch_linear_attr_mapping)
101
+ tp.OperationsSetToLayers(OPSET_BATCH_NORM, [BatchNorm2d])
98
102
  tp.OperationsSetToLayers(OPSET_ANY_RELU, [torch.relu,
99
103
  ReLU,
100
104
  ReLU6,
@@ -109,9 +113,13 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
109
113
  tp.OperationsSetToLayers(OPSET_SUB, [operator.sub, sub, subtract])
110
114
  tp.OperationsSetToLayers(OPSET_MUL, [operator.mul, mul, multiply])
111
115
  tp.OperationsSetToLayers(OPSET_DIV, [operator.truediv, div, divide])
116
+ tp.OperationsSetToLayers(OPSET_MIN_MAX, [minimum, maximum])
112
117
  tp.OperationsSetToLayers(OPSET_PRELU, [PReLU, prelu])
113
- tp.OperationsSetToLayers(OPSET_SWISH, [SiLU, silu, Hardswish, hardswish])
114
- tp.OperationsSetToLayers(OPSET_SIGMOID, [Sigmoid, sigmoid])
115
- tp.OperationsSetToLayers(OPSET_TANH, [Tanh, tanh])
118
+ tp.OperationsSetToLayers(OPSET_SWISH, [SiLU, silu])
119
+ tp.OperationsSetToLayers(OPSET_SIGMOID, [Sigmoid, sigmoid, F.sigmoid])
120
+ tp.OperationsSetToLayers(OPSET_TANH, [Tanh, tanh, F.tanh])
121
+ tp.OperationsSetToLayers(OPSET_GELU, [GELU, gelu])
122
+ tp.OperationsSetToLayers(OPSET_HARDSIGMOID, [Hardsigmoid, hardsigmoid])
123
+ tp.OperationsSetToLayers(OPSET_HARDSWISH, [Hardswish, hardswish])
116
124
 
117
125
  return pytorch_tpc