mct-nightly 2.1.0.20240813.141729__py3-none-any.whl → 2.1.0.20240815.452__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17):
  1. {mct_nightly-2.1.0.20240813.141729.dist-info → mct_nightly-2.1.0.20240815.452.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.1.0.20240813.141729.dist-info → mct_nightly-2.1.0.20240815.452.dist-info}/RECORD +17 -14
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +80 -0
  5. model_compression_toolkit/core/keras/constants.py +4 -1
  6. model_compression_toolkit/core/keras/default_framework_info.py +3 -2
  7. model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py +241 -0
  8. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +2 -2
  9. model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py +2 -0
  10. model_compression_toolkit/core/keras/graph_substitutions/substitutions/sigmoid_mul_to_swish.py +89 -0
  11. model_compression_toolkit/core/keras/keras_implementation.py +7 -1
  12. model_compression_toolkit/core/pytorch/default_framework_info.py +8 -3
  13. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +5 -3
  14. model_compression_toolkit/core/runner.py +3 -0
  15. {mct_nightly-2.1.0.20240813.141729.dist-info → mct_nightly-2.1.0.20240815.452.dist-info}/LICENSE.md +0 -0
  16. {mct_nightly-2.1.0.20240813.141729.dist-info → mct_nightly-2.1.0.20240815.452.dist-info}/WHEEL +0 -0
  17. {mct_nightly-2.1.0.20240813.141729.dist-info → mct_nightly-2.1.0.20240815.452.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.1.0.20240813.141729
3
+ Version: 2.1.0.20240815.452
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=4Ah8Ywj3HTi7gOBoXzHjZ7FBnpJqziSb4iuxGWzX9R8,1573
1
+ model_compression_toolkit/__init__.py,sha256=R0Zwbt0JpEgVMFa8F2SnrHQ0xhwmPSq0tvWkS53l3eI,1573
2
2
  model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -8,7 +8,7 @@ model_compression_toolkit/core/__init__.py,sha256=tnDtL9KmT0vsOU27SsJ19TKDEbIH-t
8
8
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
9
9
  model_compression_toolkit/core/graph_prep_runner.py,sha256=7-b7Jd5jBVaXOWg5nSqbEyzBtdaGDbCxs8aqMV6GZ6I,11287
10
10
  model_compression_toolkit/core/quantization_prep_runner.py,sha256=K9eJ7VbB_rpeyxX4yEnorOmSxFW3DkvofzxS6QI8Hp8,6454
11
- model_compression_toolkit/core/runner.py,sha256=XQDNJirZkVJ_FXP72d7tbVc_Tr3Jw0Eqm_kxNHW8kPs,13636
11
+ model_compression_toolkit/core/runner.py,sha256=kiNClmonlaqNI2U72bzGoJUzLxKHLh61iak9-HvsfQM,13880
12
12
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
13
13
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
14
14
  model_compression_toolkit/core/common/framework_implementation.py,sha256=kSg2f7wS7e2EyvX6y0eKfNTTFvVFVrB8lvldJvcPvN8,20724
@@ -63,6 +63,7 @@ model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py,sha256
63
63
  model_compression_toolkit/core/common/mixed_precision/configurable_quant_id.py,sha256=LLDguK7afsbN742ucLpmJr5TUfTyFpK1vbf2bpVr1v0,882
64
64
  model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py,sha256=7dKMi5S0zQZ16m8NWn1XIuoXsKuZUg64G4-uK8-j1PQ,5177
65
65
  model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=H8qYkJsk88OszUJo-Zde7vTmWiypLTg9KbbzIZ-hhvM,2812
66
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=klmaMQDeFc3IxRLf6YX4Dw1opFksbLyN10yFHdKAtLo,4875
66
67
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=rppRZJdSCQGiZsd93QxoUIhj51eETvQbuI5JiC2TUeA,4963
67
68
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=pk8HRoShDhiUprBC4m1AFQv1SacS4hOrj0MRdbq-5gY,7556
68
69
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=TTTux4YiOnQqt-2h7Y38959XaDwNZc0eufLMx_yws5U,37578
@@ -150,10 +151,10 @@ model_compression_toolkit/core/common/visualization/final_config_visualizer.py,s
150
151
  model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=HOq7AObkmEZiDSZXUMJDAEJzUY-fSXUT0AMgwiyH7dg,7388
151
152
  model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=1-OQu3RNKXA55qfKG1MPq4JxTzmFeVKFDWv5i3TktRw,23676
152
153
  model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
153
- model_compression_toolkit/core/keras/constants.py,sha256=Uv3c0UdW55pIVQNW_1HQlgl-dHXREkltOLyzp8G1mTQ,3163
154
+ model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiOcnLAq1v0MMBCJqMJzzFk,3225
154
155
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
155
- model_compression_toolkit/core/keras/default_framework_info.py,sha256=HcHplb7IcnOTyK2p6uhp3OVG4-RV3RDo9C_4evaIzkQ,4981
156
- model_compression_toolkit/core/keras/keras_implementation.py,sha256=hzNC6wz1gtL2EqmRCMCQYl8AqIDJPu6rdOX6nvPgjCM,30193
156
+ model_compression_toolkit/core/keras/default_framework_info.py,sha256=PYcER89eEXjKtR0T7-2Y4f7cckqoD5OQbpHePoRkMec,5030
157
+ model_compression_toolkit/core/keras/keras_implementation.py,sha256=uOTGpsgH4h9MBduVBp8v7mm2S8njbkC72qvXcrZUjeI,30604
157
158
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
158
159
  model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
159
160
  model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=s56UIgiPipUQRNd2sd1xW6GFfYNMBmrocRCNtvpYLbY,4977
@@ -172,10 +173,11 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm
172
173
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=GR1a3mCZpNUu4WxixJXF_aSm57phAdxaRoHecNx3hxw,3168
173
174
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=5df_xGfXkqNub4xVRnCWQvSohWqdv12axjJ6edVU2H0,2478
174
175
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py,sha256=Hl4LEQ_bw_Vpmf3ZqHujYUqVdvTNsPlEMvr9dZhwg2U,2806
176
+ model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py,sha256=K2svZ8xKK6LAnV86556AwIKnvIjcEqXjJicjp7KC-zY,11132
175
177
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py,sha256=R3U7cjc2E0zheMem16GHygp5jZFGSaomkNOTxTjcAgw,5794
176
178
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=V6hp67CkS_A3WqdsjLjs0ETtdZAOo4P9mhy4aT7W5FE,5940
177
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=i5kdo6-GJe5j4ZVoBp9irLLqqS_H24izrUvda17laf0,8177
178
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py,sha256=kjwlKtm5yhNgWVVcW6mN-hn7enwAnn_8-TUZvxZBiQs,4112
179
+ model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=AvquvVVVT8-ioeVn-gjqysK4L41L3I7TlNOEDfWjViY,8185
180
+ model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py,sha256=9MZJp4GNTLesWN5uQ5eOQyAHLzLYDAHAjRi-LpNppSc,4257
179
181
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=l9PUREBf4aRwWILiybdteveeUbh7js-i-hLt8Ma0e4c,26771
180
182
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py,sha256=IdKOg6AWZWMcmDbOuNdxetS5_zTarXIIffdYL7JTdvk,3872
181
183
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_identity.py,sha256=z2J2Xk7b_w_fEgJmK87lwwBmEoAZpGxPmsBrR24IkZs,2035
@@ -183,6 +185,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_
183
185
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py,sha256=ryes9y1ie-vjBGso2TeO4EXxVk69Ew3iSAhshPz1Ou4,5542
184
186
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/separableconv_decomposition.py,sha256=TEaHlIbXj_ZjIdT5TmAICD3WLD3u_7g0fLWQcNzTJuM,7941
185
187
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py,sha256=13ejpU2z7c5O2w0Iy_uz3HaBbXVYrsQpEqt0nKErVvg,11169
188
+ model_compression_toolkit/core/keras/graph_substitutions/substitutions/sigmoid_mul_to_swish.py,sha256=4Yf-sIj6oqYENdXs2FRxbvLCI1siDo29XpGb17mISBw,4062
186
189
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/softmax_shift.py,sha256=Qk5seDALj_th9dHJehY7ynZjvFjVfCv_mJ1enA5hX0c,1623
187
190
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=wH9ocMLL725-uUPU-zCxdd8NwT5nyd0ZShmI7iuTwF8,1462
188
191
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/weights_activation_split.py,sha256=rjIheZW7LbSPv9bzMSmC8wl6UUxaTkd4J2IHinObT-Y,1814
@@ -214,7 +217,7 @@ model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_c
214
217
  model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
215
218
  model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
216
219
  model_compression_toolkit/core/pytorch/constants.py,sha256=YwD_joIF0vK8UG2vW1NVvg36pCNWA0vHOXjAgy_XWn0,2794
217
- model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
220
+ model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=-Vls1P_8Ckm_18nnOsmQkZ71SmzHwtQLbQ383Z4Rb-U,4365
218
221
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
219
222
  model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=xmcJyU-rkIDX1a_X9LILzf2Ko2z_4I4xnlHkezKH-2w,27669
220
223
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
@@ -246,7 +249,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_
246
249
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py,sha256=hAZXzrEinHa-dJHLj39Hy_9Q-13QyO95rtYVSLrhvT8,4915
247
250
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py,sha256=DcJEIkGvBdIMOelNIwaJUZ5UsAHiGnDJPR20I464vWo,2929
248
251
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py,sha256=XFtU9yuBmoZlX0f0mS6otMPWMk-RcWs94XdvvTNhW8Y,3303
249
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=lOPl5zDU3FoR9WmlxO04Pfi65MimK0gmnuHzQJodQdY,10668
252
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=3WCLvPyx7tVkM0rwYhYq-gntCzW9R_DcImR1ucKlPac,10772
250
253
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/softmax_shift.py,sha256=05lV4pIL3hJkZl4JQPV4wk_EFD0eYLG5b8cdzvZk4P8,1588
251
254
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transform_function_call_method.py,sha256=EC9Dvp-_UlpDWnipnf8ds65wh_Y-T8pXAFIwRScWpiY,2044
252
255
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=WmEa8Xjji-_tIbthDxlLAGSr69nWk-YKcHNaVqLa7sg,1375
@@ -528,8 +531,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
528
531
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
529
532
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
530
533
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
531
- mct_nightly-2.1.0.20240813.141729.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
532
- mct_nightly-2.1.0.20240813.141729.dist-info/METADATA,sha256=rKrCpumEkmoyMZLUBKH25qud3SqUEXdhHnBzP9nqdrE,19721
533
- mct_nightly-2.1.0.20240813.141729.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
534
- mct_nightly-2.1.0.20240813.141729.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
535
- mct_nightly-2.1.0.20240813.141729.dist-info/RECORD,,
534
+ mct_nightly-2.1.0.20240815.452.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
535
+ mct_nightly-2.1.0.20240815.452.dist-info/METADATA,sha256=sRuvfW9Die83_at1NPFhuX1I9FZcyEEHNc38yC11mWg,19718
536
+ mct_nightly-2.1.0.20240815.452.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
537
+ mct_nightly-2.1.0.20240815.452.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
538
+ mct_nightly-2.1.0.20240815.452.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.1.0.20240813.141729"
30
+ __version__ = "2.1.0.20240815.000452"
@@ -0,0 +1,80 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import numpy as np
16
+
17
+ from model_compression_toolkit.core import ResourceUtilization, FrameworkInfo
18
+ from model_compression_toolkit.core.common import Graph
19
+ from model_compression_toolkit.logger import Logger
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
21
+
22
+
23
def filter_candidates_for_mixed_precision(graph: Graph,
                                          target_resource_utilization: ResourceUtilization,
                                          fw_info: FrameworkInfo,
                                          tpc: TargetPlatformCapabilities):
    """
    Filters out candidates in case of mixed precision search for only weights or activation compression.
    For instance, if running only weights compression - filters out candidates of activation configurable nodes
    such that only a single candidate would remain, with the bitwidth equal to the one defined in the matching layer's
    base config in the TPC.

    Note: This function modifies the graph inplace!

    Args:
        graph: A graph representation of the model to be quantized.
        target_resource_utilization: The resource utilization of the target device.
        fw_info: Information needed for quantization about the specific framework.
        tpc: TargetPlatformCapabilities object that describes the desired inference target platform.

    """

    # Filtering is only applied when neither the total-memory nor the BOPs target is restricted.
    no_total_restrictions = (target_resource_utilization.total_memory == np.inf and
                             target_resource_utilization.bops == np.inf)

    if target_resource_utilization.weights_memory < np.inf:
        if target_resource_utilization.activation_memory == np.inf and no_total_restrictions:
            # Running mixed precision for weights compression only -
            # filter out candidates of activation-only configurable nodes.
            # Built once as a set so the membership test below is O(1) per node.
            weights_conf = set(graph.get_weights_configurable_nodes(fw_info))
            for n in graph.get_activation_configurable_nodes():
                if n not in weights_conf:
                    base_cfg_nbits = n.get_qco(tpc).base_config.activation_n_bits
                    filtered_conf = [c for c in n.candidates_quantization_cfg if
                                     c.activation_quantization_cfg.enable_activation_quantization and
                                     c.activation_quantization_cfg.activation_n_bits == base_cfg_nbits]

                    if len(filtered_conf) != 1:
                        Logger.critical(f"Running weights only mixed precision failed on layer {n.name}: expected "
                                        f"exactly one activation quantization configuration with "
                                        f"{base_cfg_nbits} bits, found {len(filtered_conf)}.")  # pragma: no cover
                    n.candidates_quantization_cfg = filtered_conf

    elif target_resource_utilization.activation_memory < np.inf:
        if target_resource_utilization.weights_memory == np.inf and no_total_restrictions:
            # Running mixed precision for activation compression only -
            # filter out candidates of weights-only configurable nodes.
            activation_conf = set(graph.get_activation_configurable_nodes())
            for n in graph.get_weights_configurable_nodes(fw_info):
                if n not in activation_conf:
                    # Use the fw_info passed by the caller, as the weights-only branch above does.
                    kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
                    base_cfg_nbits = n.get_qco(tpc).base_config.attr_weights_configs_mapping[kernel_attr].weights_n_bits
                    filtered_conf = [c for c in n.candidates_quantization_cfg if
                                     c.weights_quantization_cfg.get_attr_config(
                                         kernel_attr).enable_weights_quantization and
                                     c.weights_quantization_cfg.get_attr_config(
                                         kernel_attr).weights_n_bits == base_cfg_nbits]
                    if len(filtered_conf) != 1:
                        Logger.critical(f"Running activation only mixed precision failed on layer {n.name}: expected "
                                        f"exactly one weights quantization configuration with "
                                        f"{base_cfg_nbits} bits, found {len(filtered_conf)}.")  # pragma: no cover
                    n.candidates_quantization_cfg = filtered_conf
@@ -31,7 +31,8 @@ KERNEL_SIZE = 'kernel_size'
31
31
  PADDING = 'padding'
32
32
  GROUPS = 'groups'
33
33
  STRIDES = 'strides'
34
- DILATIONS = 'dilation_rate'
34
+ DILATION_RATE = 'dilation_rate'
35
+ DILATIONS = 'dilations'
35
36
  DATA_FORMAT = 'data_format'
36
37
  LAYER_NAME = 'name'
37
38
  TRAINABLE = 'trainable'
@@ -62,6 +63,7 @@ DEPTHWISE_CONSTRAINT = 'depthwise_constraint'
62
63
  KERNEL_INITIALIZER = 'kernel_initializer'
63
64
  KERNEL_REGULARIZER = 'kernel_regularizer'
64
65
  KERNEL_CONSTRAINT = 'kernel_constraint'
66
+ RATE = 'rate'
65
67
 
66
68
  # functional nodes attributes
67
69
  FUNCTION = 'function'
@@ -71,6 +73,7 @@ F_MATMUL = 'matmul'
71
73
  F_STACK = 'stack'
72
74
  F_STRIDED_SLICE_BEGIN = 'begin_mask'
73
75
  F_STRIDED_SLICE_END = 'end_mask'
76
+ F_SWISH = 'nn.silu'
74
77
 
75
78
  # Layers variables names:
76
79
  KERNEL: str = 'kernel'
@@ -29,7 +29,7 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
29
29
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
30
30
  from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
31
31
  from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
32
- KERNEL, DEPTHWISE_KERNEL
32
+ KERNEL, DEPTHWISE_KERNEL, GELU
33
33
  from model_compression_toolkit.core.keras.quantizer.fake_quant_builder import power_of_two_quantization, symmetric_quantization, uniform_quantization
34
34
 
35
35
  """
@@ -75,7 +75,8 @@ ACTIVATION2MINMAX = {SOFTMAX: (0, SOFTMAX_THRESHOLD),
75
75
  TANH: (-1, 1),
76
76
  SWISH: (-0.279, None),
77
77
  RELU: (0, None),
78
- SELU: (None, None),
78
+ SELU: (-1.76, None),
79
+ GELU: (-0.17, None),
79
80
  }
80
81
 
81
82
  """
@@ -0,0 +1,241 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import numpy as np
17
+ import tensorflow as tf
18
+ from packaging import version
19
+ if version.parse(tf.__version__) >= version.parse("2.13"):
20
+ from keras.src.layers.core import TFOpLambda
21
+ from keras.src.layers import Conv2D, DepthwiseConv2D
22
+ else:
23
+ from keras.layers.core import TFOpLambda
24
+ from keras.layers import Conv2D, DepthwiseConv2D
25
+ from model_compression_toolkit.logger import Logger
26
+ from model_compression_toolkit.core import common
27
+ from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
28
+ from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
29
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
30
+ from model_compression_toolkit.constants import REUSE, REUSE_GROUP
31
+ from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, FILTERS, PADDING, \
32
+ KERNEL_SIZE, DEPTH_MULTIPLIER, STRIDES, DILATIONS, DILATION_RATE, DEPTHWISE_KERNEL, RATE
33
+
34
+
35
def extract_bias_node_data(_node: FunctionalNode, _graph: Graph) -> np.ndarray:
    """
    Check if a bias can be extracted from the node following a conv node.

    A bias is extracted only when the conv node has a single successor, that successor has the conv node
    as its only graph input, and it is a tf.add (with a constant 1-D operand) or a tf.nn.bias_add node.

    Args:
        _node: conv node to check for a subsequent add/bias_add node to extract bias from.
        _graph: model graph.

    Returns:
        The bias weight. None if couldn't extract bias.

    """
    b = None
    next_nodes = _graph.get_next_nodes(_node)
    if len(next_nodes) == 1 and len(_graph.get_prev_nodes(next_nodes[0])) == 1:
        # Found pattern in graph: conv_node->next_node. Check if next node is add/bias_add that can be absorbed as bias.
        next_node = next_nodes[0]
        if next_node.is_match_type(tf.add):
            # The constant operand may be stored at either input position.
            b = next_node.weights.get(0, next_node.weights.get(1))
            if b is not None and len(b.shape) != 1:
                # Constant input to Add node (bias) has irregular shape. Expecting a 1-D array.
                b = None  # pragma: no cover
        elif next_node.is_match_type(tf.nn.bias_add):
            # In bias_add, weight is always 1-D array. Extract weight from weights or kwargs.
            if 1 in next_node.weights:
                b = next_node.weights[1]
            elif BIAS in next_node.op_call_kwargs:
                # The bias kwarg belongs to the bias_add call, not to the conv call.
                b = np.array(next_node.op_call_kwargs[BIAS], dtype=np.float32)

    return b
64
+
65
+
66
def replace_conv_node(graph: Graph, new_node: BaseNode, old_node: FunctionalNode, remove_add_node: bool):
    """
    Swap, in place, a functional conv node (and optionally the node right after it) for a Conv layer node.

    Args:
        graph: model Graph, modified in place.
        new_node: Conv layer node to insert.
        old_node: conv function node to remove.
        remove_add_node: if True, the single node following old_node is removed too and new_node takes over
            its outgoing edges and output position; otherwise new_node takes over old_node's.
    """
    graph.add_node(new_node)

    # new_node inherits the inputs of the functional conv node in both cases.
    graph.reconnect_in_edges(old_node, new_node)

    if not remove_add_node:
        graph.reconnect_out_edges(old_node, new_node)
        graph.replace_output_node(current_node=old_node, new_node=new_node)
    else:
        absorbed_node = graph.get_next_nodes(old_node)[0]
        graph.reconnect_out_edges(absorbed_node, new_node)
        graph.replace_output_node(current_node=absorbed_node, new_node=new_node)
        graph.remove_edge(old_node, absorbed_node)
        graph.remove_node(absorbed_node)

    graph.remove_node(old_node)
89
+
90
+
91
class Conv2dFuncToConv2dLayer(common.BaseSubstitution):
    """
    Substitutes tf.nn.conv2d, tf.compat.v1.nn.conv2d, tf.nn.convolution, tf.compat.v1.nn.convolution functions with a Conv2D layer.
    """

    def __init__(self):
        """
        Initializes the Conv2dFuncToConv2dLayer substitution matcher instance.
        """
        conv2d_matcher = NodeOperationMatcher(tf.nn.conv2d) | NodeOperationMatcher(tf.compat.v1.nn.conv2d)
        convolution_matcher = NodeOperationMatcher(tf.nn.convolution) | NodeOperationMatcher(tf.compat.v1.nn.convolution)
        super().__init__(matcher_instance=conv2d_matcher | convolution_matcher)

    @staticmethod
    def _to_layer_spatial_arg(value):
        """
        Convert a conv function's strides/dilations argument to a value accepted by the Conv2D layer.

        The functional API accepts an int or a list of 1, 2 or 4 ints (4 meaning NHWC order), while the
        layer accepts an int or a pair of ints.

        Args:
            value: strides/dilations argument of the conv function (not None).

        Returns:
            A tuple (is_supported, layer_value). is_supported is False when a 4-element value has a
            non-unit batch or channel entry, which the layer cannot express.
        """
        if isinstance(value, (list, tuple, np.ndarray)):
            if len(value) == 4:
                if value[0] > 1 or value[3] > 1:
                    return False, None
                return True, value[1:3]
            if len(value) == 1:
                # A single value applies to both spatial dimensions.
                return True, value[0]
        return True, value

    def substitute(self,
                   graph: Graph,
                   conv_func_node: FunctionalNode) -> Graph:
        """
        Substitutes conv functions with a Conv2D layer.

        The graph is returned unchanged when the call can't be expressed by a Conv2D layer: non-constant or
        non-4D kernel, channels-first data format, non-unit batch/channel strides or dilations, or
        non-string padding.

        Args:
            graph: The graph on which the substitution is applied.
            conv_func_node: The functional node to be replaced.

        Returns:
            The modified graph after applying the substitution.
        """
        # Local import: only needed to detect channels-first calls.
        from model_compression_toolkit.core.keras.constants import DATA_FORMAT

        if 1 in conv_func_node.weights:
            k = conv_func_node.weights[1]
        elif FILTERS in conv_func_node.op_call_kwargs:
            k = np.array(conv_func_node.op_call_kwargs[FILTERS], dtype=np.float32)
        else:
            # Conv weight isn't a constant -> skip substitution.
            return graph  # pragma: no cover

        if len(k.shape) != 4:
            # Conv dimension doesn't match conv2d dimension (K1 x K2 x Cin x Cout) -> skip substitution.
            return graph  # pragma: no cover

        if len(conv_func_node.op_call_args) > 0:
            Logger.critical(f"node {conv_func_node.name} expected to have only kwargs but got args={conv_func_node.op_call_args}.")  # pragma: no cover
        call_kwargs = conv_func_node.op_call_kwargs

        data_format = call_kwargs.get(DATA_FORMAT)
        if isinstance(data_format, str) and data_format.upper().startswith('NC'):
            # Channels-first call: the Conv2D layer is built with its default channels-last layout and
            # would compute a different result -> skip substitution.
            return graph  # pragma: no cover

        # Create Conv2D layer attributes.
        conv_fw_attr = {FILTERS: k.shape[3], KERNEL_SIZE: k.shape[:2]}

        strides = call_kwargs.get(STRIDES)
        if strides is not None:
            is_supported, layer_strides = self._to_layer_spatial_arg(strides)
            if not is_supported:
                # Non-standard strides -> skip substitution.
                return graph  # pragma: no cover
            conv_fw_attr[STRIDES] = layer_strides

        if PADDING in call_kwargs:
            padding = call_kwargs[PADDING]
            if not isinstance(padding, str):
                # Non-standard padding, Layer only support either 'valid' or 'same' -> skip substitution.
                return graph  # pragma: no cover
            conv_fw_attr[PADDING] = padding

        dilations = call_kwargs.get(DILATIONS)
        if dilations is None:
            # tf.compat.v1.nn.convolution names this argument 'dilation_rate'.
            dilations = call_kwargs.get(DILATION_RATE)
        if dilations is not None:
            is_supported, layer_dilations = self._to_layer_spatial_arg(dilations)
            if not is_supported:
                # Non-standard dilations -> skip substitution.
                return graph  # pragma: no cover
            conv_fw_attr[DILATION_RATE] = layer_dilations

        # Check if can extract bias from next node.
        b = extract_bias_node_data(conv_func_node, graph)

        weights = {KERNEL: k}
        if b is None:
            conv_fw_attr[USE_BIAS] = False
        else:
            weights[BIAS] = b

        _reuse_params = {REUSE: conv_func_node.reuse, REUSE_GROUP: conv_func_node.reuse_group}
        conv_node = BaseNode(conv_func_node.name, conv_fw_attr, conv_func_node.input_shape, conv_func_node.output_shape,
                             weights, Conv2D, **_reuse_params)

        # When a bias was absorbed, the following add/bias_add node is removed together with the conv node.
        replace_conv_node(graph, conv_node, conv_func_node, remove_add_node=b is not None)
        return graph
173
+
174
+
175
class DwConv2dFuncToDwConv2dLayer(common.BaseSubstitution):
    """
    Substitutes tf.nn.depthwise_conv2d & tf.compat.v1.nn.depthwise_conv2d functions with a DepthwiseConv2D layer.
    """

    def __init__(self):
        """
        Initializes the DwConv2dFuncToDwConv2dLayer substitution matcher.
        """
        matcher = NodeOperationMatcher(tf.nn.depthwise_conv2d) | NodeOperationMatcher(tf.compat.v1.nn.depthwise_conv2d)
        super().__init__(matcher_instance=matcher)

    def substitute(self,
                   graph: Graph,
                   dwconv_func_node: FunctionalNode) -> Graph:
        """
        Substitutes dw-conv2d functions with a DepthwiseConv2D layer.

        The graph is returned unchanged when the call can't be expressed by a DepthwiseConv2D layer:
        non-constant or non-4D kernel, channels-first data format, non-unit batch/channel strides, or
        non-string padding.

        Args:
            graph: The graph on which the substitution is applied.
            dwconv_func_node: The functional depthwise conv node to be replaced.

        Returns:
            The modified graph after applying the substitution.
        """
        # Local import: only needed to detect channels-first calls.
        from model_compression_toolkit.core.keras.constants import DATA_FORMAT

        if 1 not in dwconv_func_node.weights:
            # Conv weight isn't a constant -> skip substitution.
            return graph  # pragma: no cover

        k = dwconv_func_node.weights[1]
        if len(k.shape) != 4:
            # Kernel dimension doesn't match dw-conv2d dimension (K1 x K2 x Cin x DepthMultiplier) -> skip substitution.
            return graph  # pragma: no cover

        if len(dwconv_func_node.op_call_args) > 0:
            Logger.critical(f"node {dwconv_func_node.name} expected to have only kwargs but got args={dwconv_func_node.op_call_args}.")  # pragma: no cover
        call_kwargs = dwconv_func_node.op_call_kwargs

        data_format = call_kwargs.get(DATA_FORMAT)
        if isinstance(data_format, str) and data_format.upper().startswith('NC'):
            # Channels-first call: the layer is built with its default channels-last layout and would
            # compute a different result -> skip substitution.
            return graph  # pragma: no cover

        k_shape = k.shape
        conv_fw_attr = {DEPTH_MULTIPLIER: k_shape[3], KERNEL_SIZE: k_shape[:2]}
        if STRIDES in call_kwargs:
            strides = call_kwargs[STRIDES]
            if strides[0] > 1 or strides[3] > 1:
                # Non-standard strides -> skip substitution.
                return graph  # pragma: no cover
            conv_fw_attr[STRIDES] = strides[1:3]
        if PADDING in call_kwargs:
            padding = call_kwargs[PADDING]
            if not isinstance(padding, str):
                # Non-standard padding, Layer only support either 'valid' or 'same' -> skip substitution.
                return graph  # pragma: no cover
            conv_fw_attr[PADDING] = padding
        # tf.compat.v1.nn.depthwise_conv2d uses 'rate'; tf.nn.depthwise_conv2d uses 'dilations'.
        if call_kwargs.get(RATE) is not None:
            conv_fw_attr[DILATION_RATE] = call_kwargs[RATE]
        elif call_kwargs.get(DILATIONS) is not None:
            conv_fw_attr[DILATION_RATE] = call_kwargs[DILATIONS]

        # Check if can extract bias from next node.
        b = extract_bias_node_data(dwconv_func_node, graph)

        weights = {DEPTHWISE_KERNEL: k}
        if b is None:
            conv_fw_attr[USE_BIAS] = False
        else:
            weights[BIAS] = b

        _reuse_params = {REUSE: dwconv_func_node.reuse, REUSE_GROUP: dwconv_func_node.reuse_group}
        conv_node = BaseNode(dwconv_func_node.name, conv_fw_attr, dwconv_func_node.input_shape, dwconv_func_node.output_shape,
                             weights, DepthwiseConv2D, **_reuse_params)

        # When a bias was absorbed, the following add/bias_add node is removed together with the conv node.
        replace_conv_node(graph, conv_node, dwconv_func_node, remove_add_node=b is not None)
        return graph
@@ -23,7 +23,7 @@ else:
23
23
  from model_compression_toolkit.core.common import BaseNode
24
24
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, NodeFrameworkAttrMatcher
25
25
  from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing, Op2DAddConstCollapsing
26
- from model_compression_toolkit.core.keras.constants import KERNEL, KERNEL_SIZE, STRIDES, DILATIONS, LINEAR, \
26
+ from model_compression_toolkit.core.keras.constants import KERNEL, KERNEL_SIZE, STRIDES, DILATION_RATE, LINEAR, \
27
27
  ACTIVATION, BIAS, USE_BIAS, LAYER_NAME, FILTERS, PADDING, GROUPS, DATA_FORMAT
28
28
  from model_compression_toolkit.logger import Logger
29
29
 
@@ -122,7 +122,7 @@ def keras_linear_collapsing() -> Conv2DCollapsing:
122
122
  USE_BIAS,
123
123
  STRIDES,
124
124
  PADDING,
125
- DILATIONS,
125
+ DILATION_RATE,
126
126
  GROUPS,
127
127
  FILTERS,
128
128
  data_format_str=DATA_FORMAT,
@@ -65,6 +65,8 @@ class MatmulToDenseSubstitution(common.BaseSubstitution):
65
65
 
66
66
  # read const from matmul inputs
67
67
  w = matmul_node.weights.get(1)
68
+ if w is None:
69
+ w = np.array(matmul_node.op_call_kwargs['b'], dtype=np.float32) if 'b' in matmul_node.op_call_kwargs else None
68
70
  if w is None:
69
71
  Logger.critical(f"Matmul substitution failed: Unable to locate weight for node {matmul_node.name}.") # pragma: no cover
70
72
 
@@ -0,0 +1,89 @@
1
+ # Copyright 2024 Sony Semiconductors Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Tuple, Union
17
+ import numpy as np
18
+ import tensorflow as tf
19
+ from packaging import version
20
+ if version.parse(tf.__version__) >= version.parse("2.13"):
21
+ from keras.src.layers.core import TFOpLambda
22
+ from keras.src.layers import Multiply, Activation
23
+ else:
24
+ from keras.layers.core import TFOpLambda
25
+ from keras.layers import Multiply, Activation
26
+ from model_compression_toolkit.core import common
27
+ from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
28
+ from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
29
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, \
30
+ EdgeMatcher, NodeFrameworkAttrMatcher
31
+ from model_compression_toolkit.constants import REUSE, REUSE_GROUP
32
+ from model_compression_toolkit.core.keras.constants import FUNCTION, F_SWISH, ACTIVATION, SIGMOID
33
+
34
+
35
class MulSigmoidToSwish(common.BaseSubstitution):
    """
    Substitutes mul(x, sigmoid(x)) with swish (a single ``tf.nn.silu`` functional node).

    The sigmoid may be a ``tf.sigmoid`` call or a Keras ``Activation`` layer whose activation is
    sigmoid; the multiplication may be a ``tf.math.multiply`` call or a Keras ``Multiply`` layer.
    """

    def __init__(self):
        """
        Initializes the MulSigmoidToSwish substitution matcher instance.

        The matcher selects edges going from a sigmoid node into a multiply node. Whether the
        surrounding structure really is mul(x, sigmoid(x)) is verified in ``substitute``.
        """
        mul_matcher = NodeOperationMatcher(tf.math.multiply) | NodeOperationMatcher(Multiply)
        activation_sigmoid = NodeOperationMatcher(Activation) & NodeFrameworkAttrMatcher(ACTIVATION, SIGMOID)
        sigmoid_matcher = NodeOperationMatcher(tf.sigmoid) | activation_sigmoid
        super().__init__(matcher_instance=EdgeMatcher(sigmoid_matcher, mul_matcher))

    def substitute(self,
                   graph: Graph,
                   sigmoid_mul_edge: Tuple[FunctionalNode, Union[FunctionalNode, BaseNode], int]) -> Graph:
        """
        Substitutes mul(x, sigmoid(x)) with swish.

        The graph is returned unchanged (substitution skipped) when:
          - the sigmoid node is one of the graph outputs,
          - the sigmoid output feeds any node other than the multiply node,
          - the multiply node does not have exactly two inputs (a Keras ``Multiply`` layer with
            more inputs computes x * sigmoid(x) * z, which is not swish), or
          - the multiply's other input is not the sigmoid's input node.

        Args:
            graph: The graph on which the substitution is applied.
            sigmoid_mul_edge: edge between sigmoid and multiply nodes

        Returns:
            The modified graph after applying the substitution.
        """

        sigmoid_node, mul_node, _ = sigmoid_mul_edge
        if sigmoid_node in [o.node for o in graph.output_nodes]:
            # Sigmoid node in outputs -> Skip substitution.
            return graph

        input_node = graph.get_prev_nodes(sigmoid_node)[0]
        mul_prev_nodes = graph.get_prev_nodes(mul_node)
        if len(graph.get_next_nodes(sigmoid_node)) > 1 \
                or len(mul_prev_nodes) != 2 \
                or input_node not in mul_prev_nodes:
            # Structure isn't mul(x, sigmoid(x)) -> Skip substitution.
            # The two-inputs check prevents collapsing e.g. Multiply()([x, sigmoid(x), z]) into
            # swish(x), which would drop z and leave its edge into the removed mul node dangling.
            return graph

        _reuse_params = {REUSE: mul_node.reuse, REUSE_GROUP: mul_node.reuse_group}
        swish_node = FunctionalNode(f'swish__{sigmoid_node.name}_{mul_node.name}', {FUNCTION: F_SWISH},
                                    sigmoid_node.input_shape, mul_node.output_shape, {}, TFOpLambda,
                                    op_call_args=[], op_call_kwargs={}, functional_op=tf.nn.silu, **_reuse_params)

        graph.add_node(swish_node)

        # Rewire: swish takes over the sigmoid's input and the multiply's consumers (and graph-output
        # slot), then the now-disconnected sigmoid and multiply nodes are removed.
        graph.reconnect_in_edges(sigmoid_node, swish_node)
        graph.reconnect_out_edges(mul_node, swish_node)
        graph.replace_output_node(current_node=mul_node, new_node=swish_node)
        graph.remove_edge(input_node, mul_node)
        graph.remove_edge(sigmoid_node, mul_node)
        graph.remove_node(sigmoid_node)
        graph.remove_node(mul_node)

        return graph
89
+
@@ -69,6 +69,9 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.acti
69
69
  ActivationDecomposition
70
70
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.matmul_substitution import \
71
71
  MatmulToDenseSubstitution
72
+ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.sigmoid_mul_to_swish import MulSigmoidToSwish
73
+ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.conv_funcs_to_layer import \
74
+ Conv2dFuncToConv2dLayer, DwConv2dFuncToDwConv2dLayer
72
75
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.softmax_shift import \
73
76
  keras_softmax_shift
74
77
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.batchnorm_folding import \
@@ -242,8 +245,11 @@ class KerasImplementation(FrameworkImplementation):
242
245
  Returns: A list of the framework substitutions used to prepare the graph.
243
246
 
244
247
  """
245
- return [SeparableConvDecomposition(),
248
+ return [MulSigmoidToSwish(),
249
+ SeparableConvDecomposition(),
246
250
  MatmulToDenseSubstitution(),
251
+ Conv2dFuncToConv2dLayer(),
252
+ DwConv2dFuncToDwConv2dLayer(),
247
253
  MultiHeadAttentionDecomposition(),
248
254
  ActivationDecomposition(),
249
255
  DwconvToConv(),
@@ -12,8 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid
16
- from torch.nn.functional import hardsigmoid, relu, relu6, softmax
15
+ from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid, GELU, SELU
16
+ from torch.nn.functional import hardsigmoid, relu, relu6, softmax, gelu, selu
17
17
  from torch.nn import Conv2d, ConvTranspose2d, Linear
18
18
  from torch import sigmoid
19
19
 
@@ -74,7 +74,12 @@ LAYER2MINMAX = {Softmax: (0, SOFTMAX_THRESHOLD),
74
74
  ReLU: (0, None),
75
75
  relu: (0, None),
76
76
  ReLU6: (0, None),
77
- relu6: (0, None)}
77
+ relu6: (0, None),
78
+ GELU: (-0.17, None),
79
+ gelu: (-0.17, None),
80
+ SELU: (-1.76, None),
81
+ selu: (-1.76, None),
82
+ }
78
83
 
79
84
  """
80
85
  Mapping from a QuantizationMethod to an activation quantizer function.
@@ -17,9 +17,9 @@ from typing import Tuple, Any, Callable
17
17
 
18
18
  import numpy as np
19
19
  import torch.nn.functional
20
- from torch.nn import Conv2d, Linear, PReLU, ELU, Hardswish, Dropout, ZeroPad2d, SiLU
20
+ from torch.nn import Conv2d, Linear, PReLU, ELU, Hardswish, Dropout, ZeroPad2d, SiLU, GELU
21
21
  from torch import reshape
22
- from torch.nn.functional import hardswish, silu, prelu, elu
22
+ from torch.nn.functional import hardswish, silu, prelu, elu, gelu
23
23
  from torch.nn.functional import avg_pool2d
24
24
 
25
25
  from model_compression_toolkit.core import CoreConfig, FrameworkInfo
@@ -68,7 +68,9 @@ def shift_negative_activation_node_matchers():
68
68
  NodeOperationMatcher(Hardswish) | \
69
69
  NodeOperationMatcher(hardswish) | \
70
70
  NodeOperationMatcher(SiLU) | \
71
- NodeOperationMatcher(silu)
71
+ NodeOperationMatcher(silu) | \
72
+ NodeOperationMatcher(GELU) | \
73
+ NodeOperationMatcher(gelu)
72
74
 
73
75
  # Match linear layers where we can add a correction.
74
76
  linear_node = NodeOperationMatcher(Conv2d) | \
@@ -27,6 +27,8 @@ from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_
27
27
  SchedulerInfo
28
28
  from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph
29
29
  from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
30
+ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_candidates_filter import \
31
+ filter_candidates_for_mixed_precision
30
32
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import \
31
33
  requires_mixed_precision
32
34
  from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
@@ -137,6 +139,7 @@ def core_runner(in_model: Any,
137
139
  if core_config.mixed_precision_enable:
138
140
  if core_config.mixed_precision_config.configuration_overwrite is None:
139
141
 
142
+ filter_candidates_for_mixed_precision(graph, target_resource_utilization, fw_info, tpc)
140
143
  bit_widths_config = search_bit_width(tg,
141
144
  fw_info,
142
145
  fw_impl,