mct-nightly 2.1.0.20240814.458__py3-none-any.whl → 2.1.0.20240815.452__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240814.458.dist-info → mct_nightly-2.1.0.20240815.452.dist-info}/METADATA +1 -1
- {mct_nightly-2.1.0.20240814.458.dist-info → mct_nightly-2.1.0.20240815.452.dist-info}/RECORD +20 -17
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +80 -0
- model_compression_toolkit/core/keras/constants.py +4 -1
- model_compression_toolkit/core/keras/default_framework_info.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py +241 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +2 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py +2 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/sigmoid_mul_to_swish.py +89 -0
- model_compression_toolkit/core/keras/keras_implementation.py +7 -1
- model_compression_toolkit/core/pytorch/default_framework_info.py +8 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +5 -3
- model_compression_toolkit/core/runner.py +3 -0
- model_compression_toolkit/xquant/common/tensorboard_utils.py +6 -4
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +2 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +2 -1
- {mct_nightly-2.1.0.20240814.458.dist-info → mct_nightly-2.1.0.20240815.452.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240814.458.dist-info → mct_nightly-2.1.0.20240815.452.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240814.458.dist-info → mct_nightly-2.1.0.20240815.452.dist-info}/top_level.txt +0 -0
{mct_nightly-2.1.0.20240814.458.dist-info → mct_nightly-2.1.0.20240815.452.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=R0Zwbt0JpEgVMFa8F2SnrHQ0xhwmPSq0tvWkS53l3eI,1573
|
2
2
|
model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
@@ -8,7 +8,7 @@ model_compression_toolkit/core/__init__.py,sha256=tnDtL9KmT0vsOU27SsJ19TKDEbIH-t
|
|
8
8
|
model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
|
9
9
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=7-b7Jd5jBVaXOWg5nSqbEyzBtdaGDbCxs8aqMV6GZ6I,11287
|
10
10
|
model_compression_toolkit/core/quantization_prep_runner.py,sha256=K9eJ7VbB_rpeyxX4yEnorOmSxFW3DkvofzxS6QI8Hp8,6454
|
11
|
-
model_compression_toolkit/core/runner.py,sha256=
|
11
|
+
model_compression_toolkit/core/runner.py,sha256=kiNClmonlaqNI2U72bzGoJUzLxKHLh61iak9-HvsfQM,13880
|
12
12
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
13
13
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
14
14
|
model_compression_toolkit/core/common/framework_implementation.py,sha256=kSg2f7wS7e2EyvX6y0eKfNTTFvVFVrB8lvldJvcPvN8,20724
|
@@ -63,6 +63,7 @@ model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py,sha256
|
|
63
63
|
model_compression_toolkit/core/common/mixed_precision/configurable_quant_id.py,sha256=LLDguK7afsbN742ucLpmJr5TUfTyFpK1vbf2bpVr1v0,882
|
64
64
|
model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py,sha256=7dKMi5S0zQZ16m8NWn1XIuoXsKuZUg64G4-uK8-j1PQ,5177
|
65
65
|
model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=H8qYkJsk88OszUJo-Zde7vTmWiypLTg9KbbzIZ-hhvM,2812
|
66
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=klmaMQDeFc3IxRLf6YX4Dw1opFksbLyN10yFHdKAtLo,4875
|
66
67
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=rppRZJdSCQGiZsd93QxoUIhj51eETvQbuI5JiC2TUeA,4963
|
67
68
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=pk8HRoShDhiUprBC4m1AFQv1SacS4hOrj0MRdbq-5gY,7556
|
68
69
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=TTTux4YiOnQqt-2h7Y38959XaDwNZc0eufLMx_yws5U,37578
|
@@ -150,10 +151,10 @@ model_compression_toolkit/core/common/visualization/final_config_visualizer.py,s
|
|
150
151
|
model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=HOq7AObkmEZiDSZXUMJDAEJzUY-fSXUT0AMgwiyH7dg,7388
|
151
152
|
model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=1-OQu3RNKXA55qfKG1MPq4JxTzmFeVKFDWv5i3TktRw,23676
|
152
153
|
model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
153
|
-
model_compression_toolkit/core/keras/constants.py,sha256=
|
154
|
+
model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiOcnLAq1v0MMBCJqMJzzFk,3225
|
154
155
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
155
|
-
model_compression_toolkit/core/keras/default_framework_info.py,sha256=
|
156
|
-
model_compression_toolkit/core/keras/keras_implementation.py,sha256=
|
156
|
+
model_compression_toolkit/core/keras/default_framework_info.py,sha256=PYcER89eEXjKtR0T7-2Y4f7cckqoD5OQbpHePoRkMec,5030
|
157
|
+
model_compression_toolkit/core/keras/keras_implementation.py,sha256=uOTGpsgH4h9MBduVBp8v7mm2S8njbkC72qvXcrZUjeI,30604
|
157
158
|
model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
|
158
159
|
model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
|
159
160
|
model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=s56UIgiPipUQRNd2sd1xW6GFfYNMBmrocRCNtvpYLbY,4977
|
@@ -172,10 +173,11 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm
|
|
172
173
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=GR1a3mCZpNUu4WxixJXF_aSm57phAdxaRoHecNx3hxw,3168
|
173
174
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=5df_xGfXkqNub4xVRnCWQvSohWqdv12axjJ6edVU2H0,2478
|
174
175
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py,sha256=Hl4LEQ_bw_Vpmf3ZqHujYUqVdvTNsPlEMvr9dZhwg2U,2806
|
176
|
+
model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py,sha256=K2svZ8xKK6LAnV86556AwIKnvIjcEqXjJicjp7KC-zY,11132
|
175
177
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py,sha256=R3U7cjc2E0zheMem16GHygp5jZFGSaomkNOTxTjcAgw,5794
|
176
178
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=V6hp67CkS_A3WqdsjLjs0ETtdZAOo4P9mhy4aT7W5FE,5940
|
177
|
-
model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=
|
178
|
-
model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py,sha256=
|
179
|
+
model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=AvquvVVVT8-ioeVn-gjqysK4L41L3I7TlNOEDfWjViY,8185
|
180
|
+
model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py,sha256=9MZJp4GNTLesWN5uQ5eOQyAHLzLYDAHAjRi-LpNppSc,4257
|
179
181
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=l9PUREBf4aRwWILiybdteveeUbh7js-i-hLt8Ma0e4c,26771
|
180
182
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py,sha256=IdKOg6AWZWMcmDbOuNdxetS5_zTarXIIffdYL7JTdvk,3872
|
181
183
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_identity.py,sha256=z2J2Xk7b_w_fEgJmK87lwwBmEoAZpGxPmsBrR24IkZs,2035
|
@@ -183,6 +185,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_
|
|
183
185
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py,sha256=ryes9y1ie-vjBGso2TeO4EXxVk69Ew3iSAhshPz1Ou4,5542
|
184
186
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/separableconv_decomposition.py,sha256=TEaHlIbXj_ZjIdT5TmAICD3WLD3u_7g0fLWQcNzTJuM,7941
|
185
187
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py,sha256=13ejpU2z7c5O2w0Iy_uz3HaBbXVYrsQpEqt0nKErVvg,11169
|
188
|
+
model_compression_toolkit/core/keras/graph_substitutions/substitutions/sigmoid_mul_to_swish.py,sha256=4Yf-sIj6oqYENdXs2FRxbvLCI1siDo29XpGb17mISBw,4062
|
186
189
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/softmax_shift.py,sha256=Qk5seDALj_th9dHJehY7ynZjvFjVfCv_mJ1enA5hX0c,1623
|
187
190
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=wH9ocMLL725-uUPU-zCxdd8NwT5nyd0ZShmI7iuTwF8,1462
|
188
191
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/weights_activation_split.py,sha256=rjIheZW7LbSPv9bzMSmC8wl6UUxaTkd4J2IHinObT-Y,1814
|
@@ -214,7 +217,7 @@ model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_c
|
|
214
217
|
model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
215
218
|
model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
216
219
|
model_compression_toolkit/core/pytorch/constants.py,sha256=YwD_joIF0vK8UG2vW1NVvg36pCNWA0vHOXjAgy_XWn0,2794
|
217
|
-
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256
|
220
|
+
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=-Vls1P_8Ckm_18nnOsmQkZ71SmzHwtQLbQ383Z4Rb-U,4365
|
218
221
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
|
219
222
|
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=xmcJyU-rkIDX1a_X9LILzf2Ko2z_4I4xnlHkezKH-2w,27669
|
220
223
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
|
@@ -246,7 +249,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_
|
|
246
249
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py,sha256=hAZXzrEinHa-dJHLj39Hy_9Q-13QyO95rtYVSLrhvT8,4915
|
247
250
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py,sha256=DcJEIkGvBdIMOelNIwaJUZ5UsAHiGnDJPR20I464vWo,2929
|
248
251
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py,sha256=XFtU9yuBmoZlX0f0mS6otMPWMk-RcWs94XdvvTNhW8Y,3303
|
249
|
-
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=
|
252
|
+
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=3WCLvPyx7tVkM0rwYhYq-gntCzW9R_DcImR1ucKlPac,10772
|
250
253
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/softmax_shift.py,sha256=05lV4pIL3hJkZl4JQPV4wk_EFD0eYLG5b8cdzvZk4P8,1588
|
251
254
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transform_function_call_method.py,sha256=EC9Dvp-_UlpDWnipnf8ds65wh_Y-T8pXAFIwRScWpiY,2044
|
252
255
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=WmEa8Xjji-_tIbthDxlLAGSr69nWk-YKcHNaVqLa7sg,1375
|
@@ -512,7 +515,7 @@ model_compression_toolkit/xquant/common/model_analyzer.py,sha256=T_8OetIQNqR0nkf
|
|
512
515
|
model_compression_toolkit/xquant/common/model_folding_utils.py,sha256=7XMNmsngJgCPVjsuMNt6g4hzhkviB45qUmNRe9jQE7g,4815
|
513
516
|
model_compression_toolkit/xquant/common/similarity_calculator.py,sha256=yCs_vlOThLzq7z-u2PkcEErLj7N7qCBPpRa6_5h34J8,10460
|
514
517
|
model_compression_toolkit/xquant/common/similarity_functions.py,sha256=Atah1otdX9oUUch2JK-p-e291QHtkP_c4DfLG9WWo1Y,2935
|
515
|
-
model_compression_toolkit/xquant/common/tensorboard_utils.py,sha256=
|
518
|
+
model_compression_toolkit/xquant/common/tensorboard_utils.py,sha256=6ZDbGHnCzSxJicWoS60GBd5HTfZuBBw1HkM7rj3Ki5w,6610
|
516
519
|
model_compression_toolkit/xquant/common/xquant_config.py,sha256=Qt56cra2tU1PeHlLx_Cqztf5q-ED8MPelhb8coSumFw,1675
|
517
520
|
model_compression_toolkit/xquant/keras/__init__.py,sha256=zbtceCVRsi-Gvl_pOmq5laqVqu55vAU1ie2FR2RK1Po,709
|
518
521
|
model_compression_toolkit/xquant/keras/dataset_utils.py,sha256=quvVymhvpcPIOneCu5J6K_QAqBHOCIj8IxZxSN2fItA,2258
|
@@ -520,16 +523,16 @@ model_compression_toolkit/xquant/keras/facade_xquant_report.py,sha256=7pf3PUMAj7
|
|
520
523
|
model_compression_toolkit/xquant/keras/keras_report_utils.py,sha256=zUvhqehKKRHEkk6y8g1xQH47b6fTMuPy6stGEZ6mI24,3081
|
521
524
|
model_compression_toolkit/xquant/keras/model_analyzer.py,sha256=WXi9BPI9_TzRWn50lM1i-6cwPPRW0p43Shg_xpHFclU,6521
|
522
525
|
model_compression_toolkit/xquant/keras/similarity_functions.py,sha256=P2qMJAo94Sz_BCao-bnhEeewKtjeLLDDH2r9luDXJ04,2710
|
523
|
-
model_compression_toolkit/xquant/keras/tensorboard_utils.py,sha256=
|
526
|
+
model_compression_toolkit/xquant/keras/tensorboard_utils.py,sha256=h67lf_agZwOuzF37or1YSF1dbTCdw-b3UyvTeRXhTp8,9225
|
524
527
|
model_compression_toolkit/xquant/pytorch/__init__.py,sha256=ycb1Xt7PtixY2Uabr94JGSwBMcct66O8ZMVf3Qa3ud8,719
|
525
528
|
model_compression_toolkit/xquant/pytorch/dataset_utils.py,sha256=KFKiFkhIPpEr1ZH5jekZFrgs20VzzKVxSV9YMgH68yI,2894
|
526
529
|
model_compression_toolkit/xquant/pytorch/facade_xquant_report.py,sha256=sr_7TkmkRE0FhdJ7BwXGLFELmR4l_nK7IlTys6oYgoU,3179
|
527
530
|
model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-ihJBLy5Cic-MQiUM_ZGV6SCXoNdscE,5549
|
528
531
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
|
529
532
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
530
|
-
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=
|
531
|
-
mct_nightly-2.1.0.
|
532
|
-
mct_nightly-2.1.0.
|
533
|
-
mct_nightly-2.1.0.
|
534
|
-
mct_nightly-2.1.0.
|
535
|
-
mct_nightly-2.1.0.
|
533
|
+
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
534
|
+
mct_nightly-2.1.0.20240815.452.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
535
|
+
mct_nightly-2.1.0.20240815.452.dist-info/METADATA,sha256=sRuvfW9Die83_at1NPFhuX1I9FZcyEEHNc38yC11mWg,19718
|
536
|
+
mct_nightly-2.1.0.20240815.452.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
537
|
+
mct_nightly-2.1.0.20240815.452.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
538
|
+
mct_nightly-2.1.0.20240815.452.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.1.0.
|
30
|
+
__version__ = "2.1.0.20240815.000452"
|
@@ -0,0 +1,80 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
import numpy as np
|
16
|
+
|
17
|
+
from model_compression_toolkit.core import ResourceUtilization, FrameworkInfo
|
18
|
+
from model_compression_toolkit.core.common import Graph
|
19
|
+
from model_compression_toolkit.logger import Logger
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
21
|
+
|
22
|
+
|
23
|
+
def filter_candidates_for_mixed_precision(graph: Graph,
                                          target_resource_utilization: ResourceUtilization,
                                          fw_info: FrameworkInfo,
                                          tpc: TargetPlatformCapabilities):
    """
    Narrow down quantization candidates when mixed precision restricts a single kind of memory.

    If only the weights memory is restricted, every node that is activation-configurable but not
    weights-configurable keeps only the candidates whose activation quantization is enabled and
    whose activation bit-width equals the one in the node's base config in the TPC.
    If only the activation memory is restricted, the mirrored filtering is applied to nodes that
    are weights-configurable only, comparing the kernel attribute's weights bit-width.
    In both cases exactly one candidate is expected to remain; otherwise a critical error is logged.
    In any other combination of restrictions the graph is left untouched.

    Note: This function modifies the graph in place!

    Args:
        graph: A graph representation of the model to be quantized.
        target_resource_utilization: The resource utilization of the target device.
        fw_info: Information needed for quantization about the specific framework.
        tpc: TargetPlatformCapabilities object that describes the desired inference target platform.

    """
    ru = target_resource_utilization
    totals_unbounded = ru.total_memory == np.inf and ru.bops == np.inf

    if ru.weights_memory < np.inf:
        if not (ru.activation_memory == np.inf and totals_unbounded):
            # Weights are not the only restricted resource -> nothing to filter.
            return

        # Weights-only compression: pin activation-only configurable nodes to their base bit-width.
        weights_nodes = graph.get_weights_configurable_nodes(fw_info)
        for node in graph.get_activation_configurable_nodes():
            if node in weights_nodes:
                continue

            target_nbits = node.get_qco(tpc).base_config.activation_n_bits
            kept = []
            for cand in node.candidates_quantization_cfg:
                act_cfg = cand.activation_quantization_cfg
                if act_cfg.enable_activation_quantization and act_cfg.activation_n_bits == target_nbits:
                    kept.append(cand)

            if len(kept) != 1:
                Logger.critical(f"Running weights only mixed precision failed on layer {node.name} with multiple "
                                f"activation quantization configurations.")  # pragma: no cover
            node.candidates_quantization_cfg = kept
        return

    if not ru.activation_memory < np.inf:
        # Neither weights nor activation memory is restricted on its own.
        return

    if not (ru.weights_memory == np.inf and totals_unbounded):
        return

    # Activation-only compression: pin weights-only configurable nodes to their base bit-width.
    activation_nodes = graph.get_activation_configurable_nodes()
    for node in graph.get_weights_configurable_nodes(fw_info):
        if node in activation_nodes:
            continue

        kernel_attr = graph.fw_info.get_kernel_op_attributes(node.type)[0]
        target_nbits = node.get_qco(tpc).base_config.attr_weights_configs_mapping[kernel_attr].weights_n_bits
        kept = []
        for cand in node.candidates_quantization_cfg:
            kernel_cfg = cand.weights_quantization_cfg.get_attr_config(kernel_attr)
            if kernel_cfg.enable_weights_quantization and kernel_cfg.weights_n_bits == target_nbits:
                kept.append(cand)

        if len(kept) != 1:
            Logger.critical(f"Running activation only mixed precision failed on layer {node.name} with multiple "
                            f"weights quantization configurations.")  # pragma: no cover
        node.candidates_quantization_cfg = kept
|
@@ -31,7 +31,8 @@ KERNEL_SIZE = 'kernel_size'
|
|
31
31
|
PADDING = 'padding'
|
32
32
|
GROUPS = 'groups'
|
33
33
|
STRIDES = 'strides'
|
34
|
-
|
34
|
+
DILATION_RATE = 'dilation_rate'
|
35
|
+
DILATIONS = 'dilations'
|
35
36
|
DATA_FORMAT = 'data_format'
|
36
37
|
LAYER_NAME = 'name'
|
37
38
|
TRAINABLE = 'trainable'
|
@@ -62,6 +63,7 @@ DEPTHWISE_CONSTRAINT = 'depthwise_constraint'
|
|
62
63
|
KERNEL_INITIALIZER = 'kernel_initializer'
|
63
64
|
KERNEL_REGULARIZER = 'kernel_regularizer'
|
64
65
|
KERNEL_CONSTRAINT = 'kernel_constraint'
|
66
|
+
RATE = 'rate'
|
65
67
|
|
66
68
|
# functional nodes attributes
|
67
69
|
FUNCTION = 'function'
|
@@ -71,6 +73,7 @@ F_MATMUL = 'matmul'
|
|
71
73
|
F_STACK = 'stack'
|
72
74
|
F_STRIDED_SLICE_BEGIN = 'begin_mask'
|
73
75
|
F_STRIDED_SLICE_END = 'end_mask'
|
76
|
+
F_SWISH = 'nn.silu'
|
74
77
|
|
75
78
|
# Layers variables names:
|
76
79
|
KERNEL: str = 'kernel'
|
@@ -29,7 +29,7 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
29
29
|
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
30
30
|
from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
|
31
31
|
from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
|
32
|
-
KERNEL, DEPTHWISE_KERNEL
|
32
|
+
KERNEL, DEPTHWISE_KERNEL, GELU
|
33
33
|
from model_compression_toolkit.core.keras.quantizer.fake_quant_builder import power_of_two_quantization, symmetric_quantization, uniform_quantization
|
34
34
|
|
35
35
|
"""
|
@@ -75,7 +75,8 @@ ACTIVATION2MINMAX = {SOFTMAX: (0, SOFTMAX_THRESHOLD),
|
|
75
75
|
TANH: (-1, 1),
|
76
76
|
SWISH: (-0.279, None),
|
77
77
|
RELU: (0, None),
|
78
|
-
SELU: (
|
78
|
+
SELU: (-1.76, None),
|
79
|
+
GELU: (-0.17, None),
|
79
80
|
}
|
80
81
|
|
81
82
|
"""
|
@@ -0,0 +1,241 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import tensorflow as tf
|
18
|
+
from packaging import version
|
19
|
+
if version.parse(tf.__version__) >= version.parse("2.13"):
|
20
|
+
from keras.src.layers.core import TFOpLambda
|
21
|
+
from keras.src.layers import Conv2D, DepthwiseConv2D
|
22
|
+
else:
|
23
|
+
from keras.layers.core import TFOpLambda
|
24
|
+
from keras.layers import Conv2D, DepthwiseConv2D
|
25
|
+
from model_compression_toolkit.logger import Logger
|
26
|
+
from model_compression_toolkit.core import common
|
27
|
+
from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
|
28
|
+
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
29
|
+
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
30
|
+
from model_compression_toolkit.constants import REUSE, REUSE_GROUP
|
31
|
+
from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, FILTERS, PADDING, \
|
32
|
+
KERNEL_SIZE, DEPTH_MULTIPLIER, STRIDES, DILATIONS, DILATION_RATE, DEPTHWISE_KERNEL, RATE
|
33
|
+
|
34
|
+
|
35
|
+
def extract_bias_node_data(_node: FunctionalNode, _graph: Graph) -> np.ndarray:
    """
    Check if a bias can be extracted from the node that follows a functional conv node.

    The bias is extracted only for the pattern conv_node->next_node, where next_node is the single
    consumer of the conv node, the conv node is its single producer, and next_node is either
    tf.add or tf.nn.bias_add.

    Args:
        _node: conv node to check for subsequent add\\bias_add node to extract bias from.
        _graph: model graph.

    Returns:
        The bias weight. None if couldn't extract bias.

    """
    next_nodes = _graph.get_next_nodes(_node)
    if len(next_nodes) != 1 or len(_graph.get_prev_nodes(next_nodes[0])) != 1:
        # The conv output isn't consumed by exactly one single-input node -> nothing to absorb.
        return None

    # Found pattern in graph: conv_node->bias_node. Check if bias_node is add\bias_add that can be absorbed as bias.
    bias_node = next_nodes[0]
    b = None
    if bias_node.is_match_type(tf.add):
        # The constant operand of the Add may be stored at either input position.
        b = bias_node.weights.get(0, bias_node.weights.get(1))
        if b is not None and len(b.shape) != 1:
            # Constant input to Add node (bias) has irregular shape. Expecting a 1-D array.
            b = None  # pragma: no cover
    elif bias_node.is_match_type(tf.nn.bias_add):
        # In bias_add, weight is always 1-D array. Extract weight from weights or kwargs.
        if 1 in bias_node.weights:
            b = bias_node.weights[1]
        elif BIAS in bias_node.op_call_kwargs:
            # The 'bias' keyword argument belongs to the bias_add call itself, not to the preceding
            # conv call, so it must be looked up on the bias_add node.
            b = np.array(bias_node.op_call_kwargs[BIAS], dtype=np.float32)

    return b
|
64
|
+
|
65
|
+
|
66
|
+
def replace_conv_node(graph: Graph, new_node: BaseNode, old_node: FunctionalNode, remove_add_node: bool):
    """
    Swap a functional conv node (and, optionally, the add node right after it) for a Conv layer node.

    The graph is modified in place: the new node takes over the incoming edges of the functional
    conv node and the outgoing edges (and graph-output role) of the last node being replaced.

    Args:
        graph: model Graph.
        new_node: Conv layer node.
        old_node: conv function node.
        remove_add_node: whether to remove subsequent add node or not.
    """
    graph.add_node(new_node)

    # The new layer node consumes whatever the functional conv node consumed.
    graph.reconnect_in_edges(old_node, new_node)

    if not remove_add_node:
        # Only the conv node is replaced: the new node inherits its outputs directly.
        graph.reconnect_out_edges(old_node, new_node)
        graph.replace_output_node(current_node=old_node, new_node=new_node)
    else:
        # The following add node is absorbed as well: the new node inherits the add node's outputs,
        # then the conv->add edge and the add node itself are dropped.
        absorbed_add = graph.get_next_nodes(old_node)[0]
        graph.reconnect_out_edges(absorbed_add, new_node)
        graph.replace_output_node(current_node=absorbed_add, new_node=new_node)
        graph.remove_edge(old_node, absorbed_add)
        graph.remove_node(absorbed_add)

    graph.remove_node(old_node)
|
89
|
+
|
90
|
+
|
91
|
+
class Conv2dFuncToConv2dLayer(common.BaseSubstitution):
    """
    Substitutes tf.nn.conv2d, tf.compat.v1.nn.conv2d, tf.nn.convolution, tf.compat.v1.nn.convolution functions with a Conv2D layer.
    """

    def __init__(self):
        """
        Initializes the Conv2dFuncToConv2dLayer substitution matcher instance.
        """
        conv2d_matcher = NodeOperationMatcher(tf.nn.conv2d) | NodeOperationMatcher(tf.compat.v1.nn.conv2d)
        convolution_matcher = NodeOperationMatcher(tf.nn.convolution) | NodeOperationMatcher(tf.compat.v1.nn.convolution)
        super().__init__(matcher_instance=conv2d_matcher | convolution_matcher)

    @staticmethod
    def _to_layer_spatial_arg(value):
        """
        Convert a strides\\dilations argument of a conv function to the form passed to the Conv2D layer.

        Args:
            value: A non-None strides\\dilations value: a scalar or a sequence of length 1, 2 or 4.

        Returns:
            A tuple (supported, layer_value). supported is False when a 4-element value has a
            first or last entry greater than 1, which the layer can't express. Otherwise
            layer_value is: the scalar itself, the single entry of a 1-element sequence, the two
            middle entries of a 4-element sequence, or the value unchanged for any other length.
        """
        if not hasattr(value, '__len__'):
            # Scalar: the same value is used for both spatial dimensions.
            return True, value
        if len(value) == 4:
            if value[0] > 1 or value[3] > 1:
                return False, None
            return True, value[1:3]
        if len(value) == 1:
            return True, value[0]
        return True, value

    def substitute(self,
                   graph: Graph,
                   conv_func_node: FunctionalNode) -> Graph:
        """
        Substitutes conv functions with a Conv2D layer.

        The substitution is skipped (the graph is returned unchanged) when the kernel isn't a
        constant, isn't 4-dimensional, or when strides\\padding\\dilations can't be expressed by
        a Conv2D layer. A subsequent add\\bias_add node is absorbed as the layer's bias when possible.

        Args:
            graph: The graph on which the substitution is applied.
            conv_func_node: The functional node to be replaced.

        Returns:
            The modified graph after applying the substitution.
        """
        call_kwargs = conv_func_node.op_call_kwargs

        if 1 in conv_func_node.weights:
            k = conv_func_node.weights[1]
        elif FILTERS in call_kwargs:
            k = np.array(call_kwargs[FILTERS], dtype=np.float32)
        else:
            # Conv weight isn't a constant -> skip substitution.
            return graph  # pragma: no cover

        if len(k.shape) != 4:
            # Conv dimension doesn't match conv2d dimension (K1 x K2 x Cin x Cout) -> skip substitution.
            return graph  # pragma: no cover

        # Check if can extract bias from next node.
        b = extract_bias_node_data(conv_func_node, graph)

        weights = {KERNEL: k}
        # Create Conv2D layer attributes.
        conv_fw_attr = {FILTERS: k.shape[3], KERNEL_SIZE: k.shape[:2]}
        if len(conv_func_node.op_call_args) > 0:
            Logger.critical(f"node {conv_func_node.name} expected to have only kwargs but got args={conv_func_node.op_call_args}.")  # pragma: no cover

        # A None value means "use the default", so the layer attribute is left unset.
        # Strides are guarded exactly like dilations: previously a None or scalar strides value
        # raised a TypeError on len().
        if STRIDES in call_kwargs and call_kwargs[STRIDES] is not None:
            supported, strides = self._to_layer_spatial_arg(call_kwargs[STRIDES])
            if not supported:
                # Non-standard strides -> skip substitution.
                return graph  # pragma: no cover
            conv_fw_attr[STRIDES] = strides

        if PADDING in call_kwargs:
            padding = call_kwargs[PADDING]
            if not isinstance(padding, str):
                # Non-standard padding, Layer only support either 'valid' or 'same' -> skip substitution.
                return graph  # pragma: no cover
            conv_fw_attr[PADDING] = padding

        if DILATIONS in call_kwargs and call_kwargs[DILATIONS] is not None:
            supported, dilations = self._to_layer_spatial_arg(call_kwargs[DILATIONS])
            if not supported:
                # Non-standard dilations -> skip substitution.
                return graph  # pragma: no cover
            conv_fw_attr[DILATION_RATE] = dilations

        if b is None:
            conv_fw_attr[USE_BIAS] = False
        else:
            weights[BIAS] = b

        _reuse_params = {REUSE: conv_func_node.reuse, REUSE_GROUP: conv_func_node.reuse_group}
        conv_node = BaseNode(conv_func_node.name, conv_fw_attr, conv_func_node.input_shape, conv_func_node.output_shape,
                             weights, Conv2D, **_reuse_params)

        # When a bias was extracted, the node it came from is removed together with the conv function node.
        replace_conv_node(graph, conv_node, conv_func_node, remove_add_node=b is not None)
        return graph
|
173
|
+
|
174
|
+
|
175
|
+
class DwConv2dFuncToDwConv2dLayer(common.BaseSubstitution):
    """
    Graph substitution that replaces a functional depthwise convolution
    (tf.nn.depthwise_conv2d or tf.compat.v1.nn.depthwise_conv2d) with a DepthwiseConv2D layer node.
    """

    def __init__(self):
        """
        Builds a matcher that fires on either of the two depthwise_conv2d function variants.
        """
        v2_matcher = NodeOperationMatcher(tf.nn.depthwise_conv2d)
        v1_matcher = NodeOperationMatcher(tf.compat.v1.nn.depthwise_conv2d)
        super().__init__(matcher_instance=v2_matcher | v1_matcher)

    def substitute(self,
                   graph: Graph,
                   dwconv_func_node: FunctionalNode) -> Graph:
        """
        Replaces a matched functional depthwise-conv node with a DepthwiseConv2D layer node.

        The graph is returned untouched when the kernel is not stored as a constant weight of the
        node, when the strides are non-standard, or when the padding is not given as a string.

        Args:
            graph: The graph on which the substitution is applied.
            dwconv_func_node: The functional depthwise-conv node to be replaced.

        Returns:
            The graph after the substitution (or the unmodified graph if it was skipped).
        """

        node_weights = dwconv_func_node.weights
        if 1 not in node_weights:
            # The kernel (positional input 1) isn't a constant -> nothing to substitute.
            return graph  # pragma: no cover

        kernel = node_weights[1]

        # Try to obtain a bias from the node that follows (None when no bias is available).
        bias = extract_bias_node_data(dwconv_func_node, graph)
        has_bias = bias is not None

        layer_weights = {DEPTHWISE_KERNEL: kernel}
        kernel_shape = kernel.shape
        # The first two kernel dims are used as the kernel size, the fourth as the depth multiplier.
        layer_attrs = {DEPTH_MULTIPLIER: kernel_shape[3], KERNEL_SIZE: kernel_shape[:2]}

        if len(dwconv_func_node.op_call_args) > 0:
            Logger.critical(f"node {dwconv_func_node.name} expected to have only kwargs but got args={dwconv_func_node.op_call_args}.")  # pragma: no cover

        call_kwargs = dwconv_func_node.op_call_kwargs

        if STRIDES in call_kwargs:
            strides = call_kwargs[STRIDES]
            if strides[0] > 1 or strides[3] > 1:
                # A stride above 1 in the first or last position can't be expressed by the layer.
                return graph  # pragma: no cover
            # Only the two middle entries are passed on as the layer strides.
            layer_attrs[STRIDES] = strides[1:3]

        if PADDING in call_kwargs:
            padding = call_kwargs[PADDING]
            if not isinstance(padding, str):
                # Explicit (non-string) padding: the layer only supports 'valid' or 'same'.
                return graph  # pragma: no cover
            layer_attrs[PADDING] = padding

        # RATE takes precedence; DILATIONS is consulted only when RATE is missing or None.
        dilation = call_kwargs.get(RATE)
        if dilation is None:
            dilation = call_kwargs.get(DILATIONS)
        if dilation is not None:
            layer_attrs[DILATION_RATE] = dilation

        if has_bias:
            layer_weights[BIAS] = bias
        else:
            layer_attrs[USE_BIAS] = False

        # The new layer node inherits name, shapes and reuse settings from the functional node.
        dw_layer_node = BaseNode(dwconv_func_node.name,
                                 layer_attrs,
                                 dwconv_func_node.input_shape,
                                 dwconv_func_node.output_shape,
                                 layer_weights,
                                 DepthwiseConv2D,
                                 **{REUSE: dwconv_func_node.reuse,
                                    REUSE_GROUP: dwconv_func_node.reuse_group})

        replace_conv_node(graph, dw_layer_node, dwconv_func_node, remove_add_node=has_bias)
        return graph
|
@@ -23,7 +23,7 @@ else:
|
|
23
23
|
from model_compression_toolkit.core.common import BaseNode
|
24
24
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, NodeFrameworkAttrMatcher
|
25
25
|
from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing, Op2DAddConstCollapsing
|
26
|
-
from model_compression_toolkit.core.keras.constants import KERNEL, KERNEL_SIZE, STRIDES,
|
26
|
+
from model_compression_toolkit.core.keras.constants import KERNEL, KERNEL_SIZE, STRIDES, DILATION_RATE, LINEAR, \
|
27
27
|
ACTIVATION, BIAS, USE_BIAS, LAYER_NAME, FILTERS, PADDING, GROUPS, DATA_FORMAT
|
28
28
|
from model_compression_toolkit.logger import Logger
|
29
29
|
|
@@ -122,7 +122,7 @@ def keras_linear_collapsing() -> Conv2DCollapsing:
|
|
122
122
|
USE_BIAS,
|
123
123
|
STRIDES,
|
124
124
|
PADDING,
|
125
|
-
|
125
|
+
DILATION_RATE,
|
126
126
|
GROUPS,
|
127
127
|
FILTERS,
|
128
128
|
data_format_str=DATA_FORMAT,
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py
CHANGED
@@ -65,6 +65,8 @@ class MatmulToDenseSubstitution(common.BaseSubstitution):
|
|
65
65
|
|
66
66
|
# read const from matmul inputs
|
67
67
|
w = matmul_node.weights.get(1)
|
68
|
+
if w is None:
|
69
|
+
w = np.array(matmul_node.op_call_kwargs['b'], dtype=np.float32) if 'b' in matmul_node.op_call_kwargs else None
|
68
70
|
if w is None:
|
69
71
|
Logger.critical(f"Matmul substitution failed: Unable to locate weight for node {matmul_node.name}.") # pragma: no cover
|
70
72
|
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/sigmoid_mul_to_swish.py
ADDED
@@ -0,0 +1,89 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductors Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from typing import Tuple, Union
|
17
|
+
import numpy as np
|
18
|
+
import tensorflow as tf
|
19
|
+
from packaging import version
|
20
|
+
if version.parse(tf.__version__) >= version.parse("2.13"):
|
21
|
+
from keras.src.layers.core import TFOpLambda
|
22
|
+
from keras.src.layers import Multiply, Activation
|
23
|
+
else:
|
24
|
+
from keras.layers.core import TFOpLambda
|
25
|
+
from keras.layers import Multiply, Activation
|
26
|
+
from model_compression_toolkit.core import common
|
27
|
+
from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
|
28
|
+
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
29
|
+
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, \
|
30
|
+
EdgeMatcher, NodeFrameworkAttrMatcher
|
31
|
+
from model_compression_toolkit.constants import REUSE, REUSE_GROUP
|
32
|
+
from model_compression_toolkit.core.keras.constants import FUNCTION, F_SWISH, ACTIVATION, SIGMOID
|
33
|
+
|
34
|
+
|
35
|
+
class MulSigmoidToSwish(common.BaseSubstitution):
    """
    Substitutes mul(x, sigmoid(x)) with swish.

    The matched pattern is an edge from a sigmoid node (tf.sigmoid, or an Activation layer whose
    activation attribute is sigmoid) into a multiply node (tf.math.multiply or a Multiply layer).
    """

    def __init__(self):
        """
        Initializes the MulSigmoidToSwish substitution matcher instance.
        """
        mul_matcher = NodeOperationMatcher(tf.math.multiply) | NodeOperationMatcher(Multiply)
        activation_sigmoid = NodeOperationMatcher(Activation) & NodeFrameworkAttrMatcher(ACTIVATION, SIGMOID)
        sigmoid_matcher = NodeOperationMatcher(tf.sigmoid) | activation_sigmoid
        super().__init__(matcher_instance=EdgeMatcher(sigmoid_matcher, mul_matcher))

    def substitute(self,
                   graph: Graph,
                   sigmoid_mul_edge: Tuple[FunctionalNode, Union[FunctionalNode, BaseNode], int]) -> Graph:
        """
        Substitutes mul(x, sigmoid(x)) with swish.

        The substitution is skipped (graph returned unchanged) when:
          - the sigmoid node is one of the graph outputs,
          - the sigmoid node has no input node in the graph,
          - the sigmoid output feeds more than one node,
          - the multiply node does not have exactly two inputs, or
          - the sigmoid's input is not also an input of the multiply node.

        Args:
            graph: The graph on which the substitution is applied.
            sigmoid_mul_edge: edge between sigmoid and multiply nodes

        Returns:
            The modified graph after applying the substitution.
        """

        sigmoid_node, mul_node, _ = sigmoid_mul_edge
        if sigmoid_node in [o.node for o in graph.output_nodes]:
            # Sigmoid node in outputs -> Skip substitution.
            return graph

        sigmoid_inputs = graph.get_prev_nodes(sigmoid_node)
        if len(sigmoid_inputs) == 0:
            # Sigmoid has no predecessor node in the graph, so there is no "x" to compare
            # against the multiply inputs -> Skip substitution instead of raising IndexError.
            return graph
        input_node = sigmoid_inputs[0]

        mul_inputs = graph.get_prev_nodes(mul_node)
        if len(graph.get_next_nodes(sigmoid_node)) > 1 \
                or len(mul_inputs) != 2 \
                or input_node not in mul_inputs:
            # Structure isn't exactly mul(x, sigmoid(x)) -> Skip substitution.
            # A multiply with additional inputs (e.g. Multiply()([x, sigmoid(x), y])) must not be
            # collapsed, otherwise the extra operands would be dropped from the graph.
            return graph

        _reuse_params = {REUSE: mul_node.reuse, REUSE_GROUP: mul_node.reuse_group}
        # The swish node takes the sigmoid's input shape and the multiply's output shape.
        swish_node = FunctionalNode(f'swish__{sigmoid_node.name}_{mul_node.name}', {FUNCTION: F_SWISH},
                                    sigmoid_node.input_shape, mul_node.output_shape, {}, TFOpLambda,
                                    op_call_args=[], op_call_kwargs={}, functional_op=tf.nn.silu, **_reuse_params)

        graph.add_node(swish_node)

        # Replace the sigmoid and multiply nodes with the single swish node:
        # inputs of sigmoid become inputs of swish, outputs of multiply become outputs of swish.
        graph.reconnect_in_edges(sigmoid_node, swish_node)
        graph.reconnect_out_edges(mul_node, swish_node)
        graph.replace_output_node(current_node=mul_node, new_node=swish_node)
        # Drop the two edges that fed the multiply node before removing the replaced nodes.
        graph.remove_edge(input_node, mul_node)
        graph.remove_edge(sigmoid_node, mul_node)
        graph.remove_node(sigmoid_node)
        graph.remove_node(mul_node)

        return graph
|
89
|
+
|
@@ -69,6 +69,9 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.acti
|
|
69
69
|
ActivationDecomposition
|
70
70
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.matmul_substitution import \
|
71
71
|
MatmulToDenseSubstitution
|
72
|
+
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.sigmoid_mul_to_swish import MulSigmoidToSwish
|
73
|
+
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.conv_funcs_to_layer import \
|
74
|
+
Conv2dFuncToConv2dLayer, DwConv2dFuncToDwConv2dLayer
|
72
75
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.softmax_shift import \
|
73
76
|
keras_softmax_shift
|
74
77
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.batchnorm_folding import \
|
@@ -242,8 +245,11 @@ class KerasImplementation(FrameworkImplementation):
|
|
242
245
|
Returns: A list of the framework substitutions used to prepare the graph.
|
243
246
|
|
244
247
|
"""
|
245
|
-
return [
|
248
|
+
return [MulSigmoidToSwish(),
|
249
|
+
SeparableConvDecomposition(),
|
246
250
|
MatmulToDenseSubstitution(),
|
251
|
+
Conv2dFuncToConv2dLayer(),
|
252
|
+
DwConv2dFuncToDwConv2dLayer(),
|
247
253
|
MultiHeadAttentionDecomposition(),
|
248
254
|
ActivationDecomposition(),
|
249
255
|
DwconvToConv(),
|
@@ -12,8 +12,8 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid
|
16
|
-
from torch.nn.functional import hardsigmoid, relu, relu6, softmax
|
15
|
+
from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid, GELU, SELU
|
16
|
+
from torch.nn.functional import hardsigmoid, relu, relu6, softmax, gelu, selu
|
17
17
|
from torch.nn import Conv2d, ConvTranspose2d, Linear
|
18
18
|
from torch import sigmoid
|
19
19
|
|
@@ -74,7 +74,12 @@ LAYER2MINMAX = {Softmax: (0, SOFTMAX_THRESHOLD),
|
|
74
74
|
ReLU: (0, None),
|
75
75
|
relu: (0, None),
|
76
76
|
ReLU6: (0, None),
|
77
|
-
relu6: (0, None)
|
77
|
+
relu6: (0, None),
|
78
|
+
GELU: (-0.17, None),
|
79
|
+
gelu: (-0.17, None),
|
80
|
+
SELU: (-1.76, None),
|
81
|
+
selu: (-1.76, None),
|
82
|
+
}
|
78
83
|
|
79
84
|
"""
|
80
85
|
Mapping from a QuantizationMethod to an activation quantizer function.
|
@@ -17,9 +17,9 @@ from typing import Tuple, Any, Callable
|
|
17
17
|
|
18
18
|
import numpy as np
|
19
19
|
import torch.nn.functional
|
20
|
-
from torch.nn import Conv2d, Linear, PReLU, ELU, Hardswish, Dropout, ZeroPad2d, SiLU
|
20
|
+
from torch.nn import Conv2d, Linear, PReLU, ELU, Hardswish, Dropout, ZeroPad2d, SiLU, GELU
|
21
21
|
from torch import reshape
|
22
|
-
from torch.nn.functional import hardswish, silu, prelu, elu
|
22
|
+
from torch.nn.functional import hardswish, silu, prelu, elu, gelu
|
23
23
|
from torch.nn.functional import avg_pool2d
|
24
24
|
|
25
25
|
from model_compression_toolkit.core import CoreConfig, FrameworkInfo
|
@@ -68,7 +68,9 @@ def shift_negative_activation_node_matchers():
|
|
68
68
|
NodeOperationMatcher(Hardswish) | \
|
69
69
|
NodeOperationMatcher(hardswish) | \
|
70
70
|
NodeOperationMatcher(SiLU) | \
|
71
|
-
NodeOperationMatcher(silu)
|
71
|
+
NodeOperationMatcher(silu) | \
|
72
|
+
NodeOperationMatcher(GELU) | \
|
73
|
+
NodeOperationMatcher(gelu)
|
72
74
|
|
73
75
|
# Match linear layers where we can add a correction.
|
74
76
|
linear_node = NodeOperationMatcher(Conv2d) | \
|
@@ -27,6 +27,8 @@ from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_
|
|
27
27
|
SchedulerInfo
|
28
28
|
from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph
|
29
29
|
from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
|
30
|
+
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_candidates_filter import \
|
31
|
+
filter_candidates_for_mixed_precision
|
30
32
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import \
|
31
33
|
requires_mixed_precision
|
32
34
|
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
|
@@ -137,6 +139,7 @@ def core_runner(in_model: Any,
|
|
137
139
|
if core_config.mixed_precision_enable:
|
138
140
|
if core_config.mixed_precision_config.configuration_overwrite is None:
|
139
141
|
|
142
|
+
filter_candidates_for_mixed_precision(graph, target_resource_utilization, fw_info, tpc)
|
140
143
|
bit_widths_config = search_bit_width(tg,
|
141
144
|
fw_info,
|
142
145
|
fw_impl,
|
@@ -115,12 +115,14 @@ class TensorboardUtils:
|
|
115
115
|
similarity_metrics (Dict[str, Dict[str, float]]): A dictionary containing similarity metrics between quantized and float models for both representative and validation datasets.
|
116
116
|
quantized_model_metadata (Dict): Metadata from the quantized model.
|
117
117
|
"""
|
118
|
-
# Add the computed max cut
|
119
|
-
maxcut_str = f"MaxCut: {quantized_model_metadata['scheduling_info'][MAX_CUT]}"
|
120
|
-
self.tb_writer.add_text(maxcut_str, MAX_CUT)
|
121
|
-
|
122
118
|
# Add output similarity between quantized and float models on representative and validation datasets
|
123
119
|
output_similarity_repr = f"Similarity Metrics on outputs using representative dataset: \n" + "\n".join([f"{key}: {value:.4f}" for key, value in similarity_metrics[OUTPUT_SIMILARITY_METRICS_REPR].items()])
|
124
120
|
output_similarity_val = f"Similarity Metrics on outputs using validation dataset: \n" + "\n".join([f"{key}: {value:.4f}" for key, value in similarity_metrics[OUTPUT_SIMILARITY_METRICS_VAL].items()])
|
125
121
|
self.tb_writer.add_text(output_similarity_repr, OUTPUT_SIMILARITY_METRICS_REPR)
|
126
122
|
self.tb_writer.add_text(output_similarity_val, OUTPUT_SIMILARITY_METRICS_VAL)
|
123
|
+
|
124
|
+
# Add the max cut if it was computed
|
125
|
+
if 'scheduling_info' in quantized_model_metadata:
|
126
|
+
maxcut_str = f"MaxCut: {quantized_model_metadata['scheduling_info'][MAX_CUT]}"
|
127
|
+
self.tb_writer.add_text(maxcut_str, MAX_CUT)
|
128
|
+
|
@@ -76,7 +76,8 @@ class KerasTensorboardUtils(TensorboardUtils):
|
|
76
76
|
# Read the quantized model into a graph structure.
|
77
77
|
quant_graph = model_reader(quantized_model)
|
78
78
|
|
79
|
-
|
79
|
+
if 'scheduling_info' in quantized_model_metadata:
|
80
|
+
insert_cut_info_into_graph(quant_graph, quantized_model_metadata)
|
80
81
|
|
81
82
|
# Iterate over each node in the graph.
|
82
83
|
for node in quant_graph.nodes:
|
@@ -79,7 +79,8 @@ class PytorchTensorboardUtils(TensorboardUtils):
|
|
79
79
|
to_tensor=self.fw_impl.to_tensor,
|
80
80
|
to_numpy=self.fw_impl.to_numpy)
|
81
81
|
|
82
|
-
|
82
|
+
if 'scheduling_info' in quantized_model_metadata:
|
83
|
+
insert_cut_info_into_graph(quant_graph, quantized_model_metadata, quantized_model)
|
83
84
|
|
84
85
|
# Iterate through each node in the graph
|
85
86
|
for node in quant_graph.nodes:
|
{mct_nightly-2.1.0.20240814.458.dist-info → mct_nightly-2.1.0.20240815.452.dist-info}/LICENSE.md
RENAMED
File without changes
|
File without changes
|
{mct_nightly-2.1.0.20240814.458.dist-info → mct_nightly-2.1.0.20240815.452.dist-info}/top_level.txt
RENAMED
File without changes
|