mct-nightly 2.0.0.20240408.430__py3-none-any.whl → 2.0.0.20240410.422__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.0.0.20240408.430.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/METADATA +1 -1
- {mct_nightly-2.0.0.20240408.430.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/RECORD +22 -20
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +22 -12
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +17 -15
- model_compression_toolkit/core/common/quantization/quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +5 -4
- model_compression_toolkit/core/common/similarity_analyzer.py +4 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +31 -4
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py +66 -0
- model_compression_toolkit/core/keras/keras_implementation.py +6 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py +69 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +4 -0
- model_compression_toolkit/gptq/keras/gptq_training.py +6 -2
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +6 -2
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +9 -1
- model_compression_toolkit/ptq/keras/quantization_facade.py +22 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +28 -9
- {mct_nightly-2.0.0.20240408.430.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.0.0.20240408.430.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/WHEEL +0 -0
- {mct_nightly-2.0.0.20240408.430.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/top_level.txt +0 -0
{mct_nightly-2.0.0.20240408.430.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/RECORD
RENAMED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
|
1
|
+
model_compression_toolkit/__init__.py,sha256=c33LV9Kt6hpVEoLixt_I5rqhtSzRBPSrdmFEifg-VHU,1573
|
|
2
2
|
model_compression_toolkit/constants.py,sha256=KW_HUEPmQEYqCvWGyORqkYxpvO7w5LViB5J5D-pm_6o,3648
|
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
|
5
5
|
model_compression_toolkit/core/__init__.py,sha256=TrRgkWpT1AN2Faw1M_1HXyJkJnbxfn9p-RigDZl7pg0,1982
|
|
6
|
-
model_compression_toolkit/core/analyzer.py,sha256=
|
|
6
|
+
model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
|
|
7
7
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=Ftqm59hT5TGWmSNkY9bFZkVfCacpGyZfCe-6yZR5WY0,10100
|
|
8
8
|
model_compression_toolkit/core/quantization_prep_runner.py,sha256=hFhDkS8GwzXZ7Ho_9qbbb8DAAWs3OONOfMSD5OU_b0o,6153
|
|
9
9
|
model_compression_toolkit/core/runner.py,sha256=NKSC6ujfQPy6dKtJVwxyK2zNDd64eyR5csYy9lBrCPA,11836
|
|
@@ -16,7 +16,7 @@ model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3u
|
|
|
16
16
|
model_compression_toolkit/core/common/model_collector.py,sha256=ofcepKtxc3j2Ouz6BpAKXTzPgjABnpRP47ndmJCXAkk,8352
|
|
17
17
|
model_compression_toolkit/core/common/model_validation.py,sha256=LaG8wd6aZl0OJgieE3SeiVDEPxtk8IHq9-3wSnmWhY4,1214
|
|
18
18
|
model_compression_toolkit/core/common/node_prior_info.py,sha256=WXX_PrGVG9M9I_REG5ZzFBohwmV4yf356sZnrja_FLo,2832
|
|
19
|
-
model_compression_toolkit/core/common/similarity_analyzer.py,sha256=
|
|
19
|
+
model_compression_toolkit/core/common/similarity_analyzer.py,sha256=98l9ttnXHf6VYxBW4852h2CPJKg3A6nLOovpHn-tnKs,8560
|
|
20
20
|
model_compression_toolkit/core/common/user_info.py,sha256=dSRMnT-oewmdOziIpEuW-s9K7vTSeyUBxT4z9neXurI,1648
|
|
21
21
|
model_compression_toolkit/core/common/back2framework/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
22
22
|
model_compression_toolkit/core/common/back2framework/base_model_builder.py,sha256=V1oShKzbSkdcTvREn8VnQQBzvm-tTHkWMXqMkYozF2s,2023
|
|
@@ -100,11 +100,11 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
|
|
|
100
100
|
model_compression_toolkit/core/common/quantization/core_config.py,sha256=KYdyfSmjSL4ye24nKlC_c4_AxYb14qoqaeMnZj4-8kE,2257
|
|
101
101
|
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
|
|
102
102
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
|
|
103
|
-
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=
|
|
104
|
-
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=
|
|
103
|
+
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=TCgpvtfyzFUedv4sZ6sKzsTyikaVl2ixLj_aHPSC2r0,27014
|
|
104
|
+
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=BieZDv9oc-Mc78S_LRMGo-s_2acbqiLE0ewaSE1v2VY,6818
|
|
105
105
|
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=T1nVWdRJfBQ_iuMQYQSIkjfkR-2n3lAOKGAz_rUZZN0,2190
|
|
106
106
|
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
|
|
107
|
-
model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=
|
|
107
|
+
model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
|
|
108
108
|
model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
|
|
109
109
|
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=9BEv2l0z2trDEsr40VB8tO3ToBA_b2sd_jH9uqZ5Wo8,11503
|
|
110
110
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
|
|
@@ -142,13 +142,13 @@ model_compression_toolkit/core/common/substitutions/virtual_activation_weights_c
|
|
|
142
142
|
model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=h85L2VlDOqbLd-N98wA3SdYWiblBgSsPceNuLanJd70,4737
|
|
143
143
|
model_compression_toolkit/core/common/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
|
144
144
|
model_compression_toolkit/core/common/visualization/final_config_visualizer.py,sha256=6I10jKLesB-RQKaXA75Xgz2wPvylQUrnPtCcQZIynGo,6371
|
|
145
|
-
model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=
|
|
145
|
+
model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=HOq7AObkmEZiDSZXUMJDAEJzUY-fSXUT0AMgwiyH7dg,7388
|
|
146
146
|
model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=4E4ZXZmqusGIJ4XQNH8FFt07htAHgT3gy5E7wPIaVBI,21951
|
|
147
147
|
model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
|
148
148
|
model_compression_toolkit/core/keras/constants.py,sha256=Uv3c0UdW55pIVQNW_1HQlgl-dHXREkltOLyzp8G1mTQ,3163
|
|
149
149
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
|
150
150
|
model_compression_toolkit/core/keras/default_framework_info.py,sha256=Ha4HTHuiw_KTS5Po1Xnv6GyK9eprpDhYWf-eooS62Ys,4961
|
|
151
|
-
model_compression_toolkit/core/keras/keras_implementation.py,sha256=
|
|
151
|
+
model_compression_toolkit/core/keras/keras_implementation.py,sha256=RS2UEtZ_anZeDxz7Zv6sNv7v9tFVct6d9KVrUlxTGpo,29309
|
|
152
152
|
model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
|
|
153
153
|
model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=Aqh31wOPaiZcJIOm-uJwzev0eTMdJyXaOk97rs4z7BU,3879
|
|
154
154
|
model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=Xmk2ZL5CaYdb7iG62HdtZ1F64vap7ffnrsuR3e3G5hc,4851
|
|
@@ -166,6 +166,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/activatio
|
|
|
166
166
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py,sha256=9YCNPiK5BD7tLs1meabPhzfb2VsyPxrZM17zMFsW_Fo,8158
|
|
167
167
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=GR1a3mCZpNUu4WxixJXF_aSm57phAdxaRoHecNx3hxw,3168
|
|
168
168
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=5df_xGfXkqNub4xVRnCWQvSohWqdv12axjJ6edVU2H0,2478
|
|
169
|
+
model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py,sha256=Hl4LEQ_bw_Vpmf3ZqHujYUqVdvTNsPlEMvr9dZhwg2U,2806
|
|
169
170
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py,sha256=R3U7cjc2E0zheMem16GHygp5jZFGSaomkNOTxTjcAgw,5794
|
|
170
171
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=V6hp67CkS_A3WqdsjLjs0ETtdZAOo4P9mhy4aT7W5FE,5940
|
|
171
172
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=dyhqZrxSTclXyarT2JYnI5WPX0OvWR_CQiwddIr632U,8143
|
|
@@ -209,7 +210,7 @@ model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKW
|
|
|
209
210
|
model_compression_toolkit/core/pytorch/constants.py,sha256=NI-J7REuxn06oEIHsmJ4GqtNC3TbV8xlkJjt5Ar-c4U,2626
|
|
210
211
|
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
|
|
211
212
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=IoMvTch5awAEPvB6Tg6ANhFGXvfSgv7JLsUBlxpMwk4,4330
|
|
212
|
-
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=
|
|
213
|
+
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=mT4jd8E1saCpAgrsClufQbnVJ0eYn1xaTQ3teALu4jk,27117
|
|
213
214
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=n_B4a6FMwM9D2w8kzy3oenBWZgXNZuIZgTJC6JEuTy0,3250
|
|
214
215
|
model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=E6ifk1HdO60k4IRH2EFBzAYWtwUlrGqJoQ66nknpHoQ,4983
|
|
215
216
|
model_compression_toolkit/core/pytorch/utils.py,sha256=dRPiteBg2dBNsHwZyYzXiCIAjnelSoeZZsDXlsTw5JQ,2880
|
|
@@ -228,6 +229,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/__init_
|
|
|
228
229
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_folding.py,sha256=j3q5DzbH3ys5MPFfSOVnAXdD7-g4XEKj2ADrdihVr30,8292
|
|
229
230
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=B7aC2TZNrQJ2oQVGBFhKAVqdUU5lYVJSMmwKhjxOHWk,2822
|
|
230
231
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=JDWOaNwYrZG0zTwd3HwoZUM3tKu7zPbzLOrqNQsu8xA,2162
|
|
232
|
+
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py,sha256=SBrR24ZAnWPftLinv4FuIqdBGjfYtfXbYQJN5mgy5V4,2861
|
|
231
233
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py,sha256=dYGyb5ebnoeFBF0EaHPQU7CkXvoARdznEEe0laM47LA,3919
|
|
232
234
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py,sha256=iX8bLHtw2osP42-peNLTRmbpX3cUxdGsAbEfw7NLpx0,3935
|
|
233
235
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_layer_norm.py,sha256=zKSgtVw_P9fUvdq4e7P9yaLDPG_vZ0cecM9sVPtm1ns,3799
|
|
@@ -336,9 +338,9 @@ model_compression_toolkit/gptq/common/gptq_training.py,sha256=rLA1xlOO-6gWfmc2dL
|
|
|
336
338
|
model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
337
339
|
model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
|
|
338
340
|
model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
|
|
339
|
-
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=
|
|
341
|
+
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=OhYfH6zxRHrRhCde0lbcV9Hu2oeDD9RXh-O8vOPgLbs,18875
|
|
340
342
|
model_compression_toolkit/gptq/keras/graph_info.py,sha256=5IvgGlJlgOmQYmldjdCBv7tuzAoY0HazatG5Pedrg0Q,4639
|
|
341
|
-
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=
|
|
343
|
+
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=zAkzWpWP9_aobWgMo_BlUm7-4fR5dHvoGx0sDqs2rZg,14299
|
|
342
344
|
model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
|
|
343
345
|
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=2YU-x4-Q5f6hkUJf0tw6vcwdNwRMHdefrFjhhyHYsvA,4782
|
|
344
346
|
model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
|
|
@@ -353,9 +355,9 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
|
|
|
353
355
|
model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
354
356
|
model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa94rpoodf82DksgjQCnL7nBc,2719
|
|
355
357
|
model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
|
|
356
|
-
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=
|
|
358
|
+
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=LN4vOwcMuSSFTSnHDACV9hX_Yd2YIXJRl7WkdODuA0k,16245
|
|
357
359
|
model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=yXJzDd24zfGs2_vfMovxD1WSh1RxXoPxN4GztOf3P5c,3967
|
|
358
|
-
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256
|
|
360
|
+
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=-4USg-tep6EQSArcTxBowhMeAuExrBTNLOWgHFpsIy4,12699
|
|
359
361
|
model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
|
|
360
362
|
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=TCA1hAc7raPnrjl06sjFtVM4XUtLtuwAhCGX4U3KGZo,4137
|
|
361
363
|
model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
|
|
@@ -375,9 +377,9 @@ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=cSuvHHCqgr7k9
|
|
|
375
377
|
model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
|
|
376
378
|
model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
|
|
377
379
|
model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
378
|
-
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=
|
|
380
|
+
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=T1_UqXmOc4I2a6IHkQAlFhGtcAYjsXSApMIdRlvgDvg,10154
|
|
379
381
|
model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
380
|
-
model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=
|
|
382
|
+
model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=eof9bo-Mv_lLY7fFpiVeT5pIde-MuTWkIAqRKH4j9MI,8646
|
|
381
383
|
model_compression_toolkit/qat/__init__.py,sha256=kj2qsZh_Ca7PncsHKcaL5EVT2H8g4hYtvaQ3KFxOkwE,1143
|
|
382
384
|
model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
|
|
383
385
|
model_compression_toolkit/qat/common/qat_config.py,sha256=zoq0Vb74vCY7WlWD8JH_KPrHDoUHSvMc3gcO53u7L2U,3394
|
|
@@ -469,8 +471,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
|
|
|
469
471
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
|
|
470
472
|
model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
|
471
473
|
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=7bbzqJN8ZAycVDvZr_5xC-niTAR5df8f03Kooev_pfg,3047
|
|
472
|
-
mct_nightly-2.0.0.
|
|
473
|
-
mct_nightly-2.0.0.
|
|
474
|
-
mct_nightly-2.0.0.
|
|
475
|
-
mct_nightly-2.0.0.
|
|
476
|
-
mct_nightly-2.0.0.
|
|
474
|
+
mct_nightly-2.0.0.20240410.422.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
|
475
|
+
mct_nightly-2.0.0.20240410.422.dist-info/METADATA,sha256=Xx2HTbZkpp4O8bS07IXSnaYSh9ZZTxe61I47ovv9fzE,18795
|
|
476
|
+
mct_nightly-2.0.0.20240410.422.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
477
|
+
mct_nightly-2.0.0.20240410.422.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
|
478
|
+
mct_nightly-2.0.0.20240410.422.dist-info/RECORD,,
|
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
|
27
27
|
from model_compression_toolkit import pruning
|
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
|
29
29
|
|
|
30
|
-
__version__ = "2.0.0.
|
|
30
|
+
__version__ = "2.0.0.20240410.000422"
|
|
@@ -30,7 +30,8 @@ from model_compression_toolkit.logger import Logger
|
|
|
30
30
|
|
|
31
31
|
def analyzer_model_quantization(representative_data_gen: Callable,
|
|
32
32
|
tb_w: TensorboardWriter,
|
|
33
|
-
|
|
33
|
+
float_graph: Graph,
|
|
34
|
+
quantized_graph: Graph,
|
|
34
35
|
fw_impl: FrameworkImplementation,
|
|
35
36
|
fw_info: FrameworkInfo):
|
|
36
37
|
"""
|
|
@@ -41,23 +42,32 @@ def analyzer_model_quantization(representative_data_gen: Callable,
|
|
|
41
42
|
Args:
|
|
42
43
|
representative_data_gen: Dataset used for calibration.
|
|
43
44
|
tb_w: TensorBoardWriter object to log events.
|
|
44
|
-
|
|
45
|
+
float_graph: Graph of float model.
|
|
46
|
+
quantized_graph: Graph of quantized model.
|
|
45
47
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
|
46
48
|
fw_info: Information needed for quantization about the specific framework.
|
|
47
49
|
|
|
48
50
|
"""
|
|
49
51
|
if tb_w is not None:
|
|
50
|
-
visual = NNVisualizer(
|
|
52
|
+
visual = NNVisualizer(float_graph,
|
|
53
|
+
quantized_graph,
|
|
51
54
|
fw_impl=fw_impl,
|
|
52
55
|
fw_info=fw_info)
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
if i >= NUM_SAMPLES_DISTANCE_TENSORBOARD:
|
|
56
|
-
break
|
|
57
|
-
figure = visual.plot_distance_graph(_data,
|
|
58
|
-
distance_fn=compute_cs,
|
|
59
|
-
convert_to_range=lambda a: 1 - 2 * a)
|
|
60
|
-
tb_w.add_figure(figure, f'similarity_distance_sample_{i}')
|
|
56
|
+
if not visual.has_compare_points():
|
|
57
|
+
Logger.error(f'No comparing points were found to plot analyze similarity.')
|
|
61
58
|
else:
|
|
62
|
-
|
|
59
|
+
visualized_samples = 0
|
|
60
|
+
for _data in representative_data_gen():
|
|
61
|
+
batch_size = _data[0].shape[0]
|
|
62
|
+
for sample_index in range(batch_size):
|
|
63
|
+
if visualized_samples >= NUM_SAMPLES_DISTANCE_TENSORBOARD:
|
|
64
|
+
break
|
|
65
|
+
figure = visual.plot_distance_graph(_data,
|
|
66
|
+
sample_index=sample_index,
|
|
67
|
+
distance_fn=compute_cs,
|
|
68
|
+
convert_to_range=lambda a: 1 - 2 * a)
|
|
69
|
+
tb_w.add_figure(figure, f'similarity_distance_sample_{visualized_samples}')
|
|
70
|
+
visualized_samples += 1
|
|
71
|
+
if visualized_samples < NUM_SAMPLES_DISTANCE_TENSORBOARD:
|
|
72
|
+
Logger.error(f'Not enough batches in representative dataset to generate {NUM_SAMPLES_DISTANCE_TENSORBOARD} figures')
|
|
63
73
|
tb_w.close()
|
|
@@ -41,24 +41,24 @@ class BaseNodeQuantizationConfig(object):
|
|
|
41
41
|
Base class for node quantization configuration
|
|
42
42
|
"""
|
|
43
43
|
|
|
44
|
-
def set_quant_config_attr(self,
|
|
44
|
+
def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any,
|
|
45
45
|
*args: List[Any], **kwargs: Dict[str, Any]):
|
|
46
46
|
"""
|
|
47
47
|
Changes a BaseNodeQuantizationConfig's parameter.
|
|
48
48
|
Note that arg and kwargs are only to allow clean override in the child classes.
|
|
49
49
|
|
|
50
50
|
Args:
|
|
51
|
-
|
|
52
|
-
|
|
51
|
+
config_parameter_name: parameter name to change.
|
|
52
|
+
config_parameter_value: parameter value to change.
|
|
53
53
|
args: A list of additional arguments.
|
|
54
54
|
kwargs: A dictionary with additional key arguments.
|
|
55
55
|
|
|
56
56
|
"""
|
|
57
57
|
|
|
58
|
-
if hasattr(self,
|
|
59
|
-
setattr(self,
|
|
58
|
+
if hasattr(self, config_parameter_name):
|
|
59
|
+
setattr(self, config_parameter_name, config_parameter_value)
|
|
60
60
|
else:
|
|
61
|
-
Logger.warning(f"Parameter {
|
|
61
|
+
Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config and "
|
|
62
62
|
f"was not updated!")
|
|
63
63
|
|
|
64
64
|
def __repr__(self) -> str:
|
|
@@ -106,6 +106,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
106
106
|
self.z_threshold = qc.z_threshold
|
|
107
107
|
self.shift_negative_ratio = qc.shift_negative_ratio
|
|
108
108
|
self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
|
|
109
|
+
self.concat_threshold_update = qc.concat_threshold_update
|
|
109
110
|
|
|
110
111
|
def quantize_node_output(self,
|
|
111
112
|
tensors: Any) -> Any:
|
|
@@ -219,7 +220,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
219
220
|
self.shift_negative_activation_correction == other.shift_negative_activation_correction and \
|
|
220
221
|
self.z_threshold == other.z_threshold and \
|
|
221
222
|
self.shift_negative_ratio == other.shift_negative_ratio and \
|
|
222
|
-
self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
|
|
223
|
+
self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
|
|
223
224
|
|
|
224
225
|
def __hash__(self):
|
|
225
226
|
return hash((self.activation_quantization_fn,
|
|
@@ -521,7 +522,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
521
522
|
f"{list(attrs_with_name.keys())}.")
|
|
522
523
|
return attrs_with_name
|
|
523
524
|
|
|
524
|
-
def set_quant_config_attr(self,
|
|
525
|
+
def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any, attr_name: str = None,
|
|
525
526
|
*args: List[Any], **kwargs: Dict[str, Any]):
|
|
526
527
|
"""
|
|
527
528
|
This method overrides the parent class set_quant_config_attr to enable setting a specific weights
|
|
@@ -529,26 +530,27 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
529
530
|
|
|
530
531
|
Args:
|
|
531
532
|
attr_name: attribute name to change.
|
|
532
|
-
|
|
533
|
-
|
|
533
|
+
config_parameter_name: parameter name to change.
|
|
534
|
+
config_parameter_value: parameter value to change.
|
|
534
535
|
args: A list of additional arguments.
|
|
535
536
|
kwargs: A dictionary with additional key arguments.
|
|
536
537
|
|
|
537
538
|
"""
|
|
538
539
|
|
|
539
540
|
if attr_name is None:
|
|
540
|
-
super(NodeWeightsQuantizationConfig, self).set_quant_config_attr(
|
|
541
|
+
super(NodeWeightsQuantizationConfig, self).set_quant_config_attr(config_parameter_name,
|
|
542
|
+
config_parameter_value,
|
|
541
543
|
*args, **kwargs)
|
|
542
544
|
else:
|
|
543
545
|
if self.has_attribute_config(attr_name):
|
|
544
546
|
attr_cfg = self.get_attr_config(attr_name)
|
|
545
|
-
if hasattr(attr_cfg,
|
|
546
|
-
setattr(attr_cfg,
|
|
547
|
+
if hasattr(attr_cfg, config_parameter_name):
|
|
548
|
+
setattr(attr_cfg, config_parameter_name, config_parameter_value)
|
|
547
549
|
else:
|
|
548
|
-
Logger.warning(f"Parameter {
|
|
550
|
+
Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
|
|
549
551
|
f"weights attribute {attr_name} and was not updated!")
|
|
550
552
|
else:
|
|
551
|
-
Logger.error(f"Weights attribute {attr_name} could not be found to set parameter {
|
|
553
|
+
Logger.error(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
|
|
552
554
|
|
|
553
555
|
def __eq__(self, other: Any) -> bool:
|
|
554
556
|
"""
|
|
@@ -62,7 +62,8 @@ class QuantizationConfig:
|
|
|
62
62
|
residual_collapsing: bool = True,
|
|
63
63
|
shift_negative_ratio: float = 0.05,
|
|
64
64
|
shift_negative_threshold_recalculation: bool = False,
|
|
65
|
-
shift_negative_params_search: bool = False
|
|
65
|
+
shift_negative_params_search: bool = False,
|
|
66
|
+
concat_threshold_update: bool = False):
|
|
66
67
|
"""
|
|
67
68
|
Class to wrap all different parameters the library quantize the input model according to.
|
|
68
69
|
|
|
@@ -117,6 +118,7 @@ class QuantizationConfig:
|
|
|
117
118
|
self.shift_negative_ratio = shift_negative_ratio
|
|
118
119
|
self.shift_negative_threshold_recalculation = shift_negative_threshold_recalculation
|
|
119
120
|
self.shift_negative_params_search = shift_negative_params_search
|
|
121
|
+
self.concat_threshold_update = concat_threshold_update
|
|
120
122
|
|
|
121
123
|
def __repr__(self):
|
|
122
124
|
return str(self.__dict__)
|
|
@@ -23,7 +23,7 @@ from model_compression_toolkit.core.common.quantization.quantize_node import get
|
|
|
23
23
|
from model_compression_toolkit.logger import Logger
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def quantize_graph_weights(
|
|
26
|
+
def quantize_graph_weights(graph_to_quantize: Graph) -> Graph:
|
|
27
27
|
"""
|
|
28
28
|
Get a graph representing a model, and quantize its nodes' weights.
|
|
29
29
|
Each node is quantized according to the passed framework info and quantization configuration.
|
|
@@ -31,12 +31,13 @@ def quantize_graph_weights(graph: Graph) -> Graph:
|
|
|
31
31
|
is calculated and subtracted from the original node's bias. The graph is quantized in-place.
|
|
32
32
|
|
|
33
33
|
Args:
|
|
34
|
-
|
|
34
|
+
graph_to_quantize: Graph to quantize its nodes.
|
|
35
35
|
|
|
36
36
|
"""
|
|
37
|
+
_quantized_graph = copy.deepcopy(graph_to_quantize)
|
|
37
38
|
# Iterate over nodes in the graph and quantize each node's weights and activations
|
|
38
39
|
# (according to operators groups in framework info).
|
|
39
|
-
for n in
|
|
40
|
+
for n in _quantized_graph.nodes():
|
|
40
41
|
for attr in n.get_node_weights_attributes():
|
|
41
42
|
if n.is_weights_quantization_enabled(attr):
|
|
42
43
|
quantized_attr, io_channels_axes = \
|
|
@@ -51,4 +52,4 @@ def quantize_graph_weights(graph: Graph) -> Graph:
|
|
|
51
52
|
# Set the attribute to be the quantized attribute.
|
|
52
53
|
n.set_weights_by_keys(attr, quantized_attr)
|
|
53
54
|
|
|
54
|
-
return
|
|
55
|
+
return _quantized_graph
|
|
@@ -146,7 +146,10 @@ def compute_mae(float_tensor: np.ndarray,
|
|
|
146
146
|
return error
|
|
147
147
|
|
|
148
148
|
|
|
149
|
-
def compute_cs(float_tensor: np.ndarray,
|
|
149
|
+
def compute_cs(float_tensor: np.ndarray,
|
|
150
|
+
fxp_tensor: np.ndarray,
|
|
151
|
+
eps: float = 1e-8,
|
|
152
|
+
batch: bool = False,
|
|
150
153
|
axis: int = None) -> float:
|
|
151
154
|
"""
|
|
152
155
|
Compute the similarity between two tensor using cosine similarity.
|
|
@@ -12,19 +12,20 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
|
|
15
16
|
from typing import Tuple, List, Callable
|
|
16
17
|
|
|
17
18
|
import numpy as np
|
|
18
19
|
from matplotlib import pyplot as plt
|
|
19
20
|
from matplotlib.figure import Figure
|
|
20
21
|
|
|
21
|
-
from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
|
|
22
22
|
from model_compression_toolkit.core.common import Graph
|
|
23
23
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
24
24
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
25
25
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
|
26
26
|
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
|
|
27
27
|
from model_compression_toolkit.core.common.similarity_analyzer import compute_cs
|
|
28
|
+
from model_compression_toolkit.logger import Logger
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
def _get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str]]:
|
|
@@ -57,17 +58,21 @@ class NNVisualizer:
|
|
|
57
58
|
|
|
58
59
|
def __init__(self,
|
|
59
60
|
graph_float: Graph,
|
|
61
|
+
graph_quantized: Graph,
|
|
60
62
|
fw_impl: FrameworkImplementation,
|
|
61
63
|
fw_info: FrameworkInfo):
|
|
62
64
|
"""
|
|
63
65
|
Initialize a NNVisualizer object.
|
|
64
66
|
Args:
|
|
65
67
|
graph_float: Float version of the graph.
|
|
68
|
+
graph_quantized: Quantized version of the graph.
|
|
69
|
+
fw_impl: Framework implementation with framework-specific methods implementation.
|
|
70
|
+
fw_info: Framework info with framework-specific information.
|
|
66
71
|
|
|
67
72
|
"""
|
|
68
73
|
|
|
69
74
|
self.graph_float = graph_float
|
|
70
|
-
self.graph_quantized =
|
|
75
|
+
self.graph_quantized = graph_quantized
|
|
71
76
|
self.fw_impl = fw_impl
|
|
72
77
|
self.fw_info = fw_info
|
|
73
78
|
|
|
@@ -75,6 +80,16 @@ class NNVisualizer:
|
|
|
75
80
|
self.compare_points, self.compare_points_name = _get_compare_points(self.graph_quantized)
|
|
76
81
|
self.compare_points_float, self.compare_points_name_float = _get_compare_points(self.graph_float)
|
|
77
82
|
|
|
83
|
+
if len(self.compare_points) != len(self.compare_points_float):
|
|
84
|
+
Logger.critical(f"Number of compare points in float and quantized models must be equal but "
|
|
85
|
+
f"num of quantized compare points: {len(self.compare_points)} and "
|
|
86
|
+
f"num of float compare points: {len(self.compare_points_float)}")
|
|
87
|
+
if len(self.compare_points_name) != len(self.compare_points_name_float):
|
|
88
|
+
Logger.critical(f"Number of compare points in float and quantized models must be equal "
|
|
89
|
+
f"but num of quantized compare points: {len(self.compare_points_name)}"
|
|
90
|
+
f" and num of float compare points: "
|
|
91
|
+
f"{len(self.compare_points_name_float)}")
|
|
92
|
+
|
|
78
93
|
self.quantized_model, _ = self.fw_impl.model_builder(self.graph_quantized,
|
|
79
94
|
mode=ModelBuilderMode.QUANTIZED,
|
|
80
95
|
append2output=self.compare_points,
|
|
@@ -85,8 +100,19 @@ class NNVisualizer:
|
|
|
85
100
|
append2output=self.compare_points_float,
|
|
86
101
|
fw_info=self.fw_info)
|
|
87
102
|
|
|
103
|
+
def has_compare_points(self) -> bool:
|
|
104
|
+
"""
|
|
105
|
+
|
|
106
|
+
Returns: Whether or not compare points were found.
|
|
107
|
+
|
|
108
|
+
"""
|
|
109
|
+
return len(self.compare_points_float) > 0 and len(self.compare_points) > 0 and len(
|
|
110
|
+
self.compare_points_name_float) > 0 and len(self.compare_points_name) > 0
|
|
111
|
+
|
|
112
|
+
|
|
88
113
|
def plot_distance_graph(self,
|
|
89
114
|
input_image: np.ndarray,
|
|
115
|
+
sample_index: int,
|
|
90
116
|
distance_fn: Callable = compute_cs,
|
|
91
117
|
convert_to_range: Callable = lambda a: a) -> Figure:
|
|
92
118
|
"""
|
|
@@ -95,6 +121,7 @@ class NNVisualizer:
|
|
|
95
121
|
|
|
96
122
|
Args:
|
|
97
123
|
input_image: Image to use as input to the networks.
|
|
124
|
+
sample_index: The index of the sample from input_image to use for comparison.
|
|
98
125
|
distance_fn: Distance function to calculate the distance between two tensors.
|
|
99
126
|
convert_to_range: Optional function to move the distance values into a specific range, e.g., when using
|
|
100
127
|
cosine similarity for distance, use 'lambda a: 1 - 2 * a' to convert the distance values to the range
|
|
@@ -108,7 +135,7 @@ class NNVisualizer:
|
|
|
108
135
|
# to make the difference more noticeable when exists
|
|
109
136
|
new_inputs = []
|
|
110
137
|
for single_input in input_image:
|
|
111
|
-
img = single_input[
|
|
138
|
+
img = single_input[sample_index]
|
|
112
139
|
new_inputs.append(np.expand_dims(img, axis=0))
|
|
113
140
|
|
|
114
141
|
# Get outputs
|
|
@@ -123,7 +150,7 @@ class NNVisualizer:
|
|
|
123
150
|
|
|
124
151
|
# Display the result: distance at every layer's output.
|
|
125
152
|
fig = plt.figure()
|
|
126
|
-
plt.plot(distance_array)
|
|
153
|
+
plt.plot(list(range(len(distance_array))), distance_array)
|
|
127
154
|
eps = 0.5
|
|
128
155
|
y_limits = (min(distance_array) - eps, max(distance_array) + eps)
|
|
129
156
|
plt.ylim(y_limits)
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
from tensorflow.keras.layers import Concatenate
|
|
18
|
+
import tensorflow as tf
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.core import common
|
|
21
|
+
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
22
|
+
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
|
23
|
+
from model_compression_toolkit.constants import THRESHOLD
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ConcatThresholdUpdate(common.BaseSubstitution):
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
"""
|
|
31
|
+
Find concat layers and match their prior layers thresholds unless prior layer outputs to multiple layers.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
def __init__(self):
|
|
35
|
+
"""
|
|
36
|
+
Initialize a threshold_updater object.
|
|
37
|
+
"""
|
|
38
|
+
concatination_node = NodeOperationMatcher(Concatenate) | \
|
|
39
|
+
NodeOperationMatcher(tf.concat)
|
|
40
|
+
super().__init__(matcher_instance=concatination_node)
|
|
41
|
+
|
|
42
|
+
def substitute(self,
|
|
43
|
+
graph: Graph,
|
|
44
|
+
node: BaseNode) -> Graph:
|
|
45
|
+
"""
|
|
46
|
+
Update previous layers thresholds to match concatinations quantization thresholds. No change if
|
|
47
|
+
previous layer outputs to multiple layers. No change in case of uniform quantization.
|
|
48
|
+
No change in case of multiple quantization candidates (mixed precision).
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
graph: Graph we apply the substitution on.
|
|
53
|
+
node: Node refference to edit previous nodes thresholds.
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
Graph after applying the substitution.
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
if len(node.candidates_quantization_cfg) == 1 and THRESHOLD in node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params:
|
|
60
|
+
concat_threshold = node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD]
|
|
61
|
+
prev_nodes = graph.get_prev_nodes(node)
|
|
62
|
+
for prev_node in prev_nodes:
|
|
63
|
+
if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != Concatenate and prev_node.type != tf.concat:
|
|
64
|
+
prev_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] = concat_threshold
|
|
65
|
+
|
|
66
|
+
return graph
|
|
@@ -80,7 +80,8 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.line
|
|
|
80
80
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.residual_collapsing import \
|
|
81
81
|
keras_residual_collapsing
|
|
82
82
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \
|
|
83
|
-
InputScalingWithPad
|
|
83
|
+
InputScalingWithPad
|
|
84
|
+
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.concat_threshold_update import ConcatThresholdUpdate
|
|
84
85
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
|
|
85
86
|
ReLUBoundToPowerOfTwo
|
|
86
87
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.multi_head_attention_decomposition import \
|
|
@@ -300,8 +301,8 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
300
301
|
"""
|
|
301
302
|
return keras_op2d_add_const_collapsing()
|
|
302
303
|
|
|
303
|
-
def get_substitutions_post_statistics_collection(self,
|
|
304
|
-
|
|
304
|
+
def get_substitutions_post_statistics_collection(self,
|
|
305
|
+
quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
|
|
305
306
|
"""
|
|
306
307
|
Return a list of the framework substitutions used after we collect statistics.
|
|
307
308
|
|
|
@@ -317,6 +318,8 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
317
318
|
if quant_config.input_scaling:
|
|
318
319
|
substitutions_list.append(InputScaling())
|
|
319
320
|
substitutions_list.append(InputScalingWithPad())
|
|
321
|
+
if quant_config.concat_threshold_update:
|
|
322
|
+
substitutions_list.append(ConcatThresholdUpdate())
|
|
320
323
|
return substitutions_list
|
|
321
324
|
|
|
322
325
|
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import List
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.core import common
|
|
21
|
+
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
22
|
+
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
|
23
|
+
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
|
24
|
+
from model_compression_toolkit.constants import THRESHOLD
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ConcatThresholdUpdate(common.BaseSubstitution):
|
|
28
|
+
"""
|
|
29
|
+
Find concat layers and match their prior layers thresholds unless prior layer outputs to multiple layers.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def __init__(self):
|
|
34
|
+
"""
|
|
35
|
+
Initialize a threshold_updater object.
|
|
36
|
+
"""
|
|
37
|
+
concatination_node = NodeOperationMatcher(torch.cat) | \
|
|
38
|
+
NodeOperationMatcher(torch.concat)
|
|
39
|
+
super().__init__(matcher_instance=concatination_node)
|
|
40
|
+
|
|
41
|
+
def substitute(self,
|
|
42
|
+
graph: Graph,
|
|
43
|
+
node: BaseNode) -> Graph:
|
|
44
|
+
"""
|
|
45
|
+
Update previous layers thresholds to match concatinations quantization thresholds. No change if
|
|
46
|
+
previous layer outputs to multiple layers. No change in case of uniform quantization.
|
|
47
|
+
No change in case of multiple quantization candidates (mixed precision).
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
graph: Graph we apply the substitution on.
|
|
52
|
+
node: Node refference to edit previous nodes thresholds.
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
Graph after applying the substitution.
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
if len(node.candidates_quantization_cfg) == 1 and THRESHOLD in node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params:
|
|
59
|
+
concat_threshold = node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD]
|
|
60
|
+
prev_nodes = graph.get_prev_nodes(node)
|
|
61
|
+
for prev_node in prev_nodes:
|
|
62
|
+
if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != torch.cat and prev_node.type != torch.concat:
|
|
63
|
+
prev_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] = concat_threshold
|
|
64
|
+
|
|
65
|
+
return graph
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
|
|
@@ -73,6 +73,8 @@ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.vi
|
|
|
73
73
|
VirtualActivationWeightsComposition
|
|
74
74
|
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.weights_activation_split import \
|
|
75
75
|
WeightsActivationSplit
|
|
76
|
+
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.concat_threshold_update import \
|
|
77
|
+
ConcatThresholdUpdate
|
|
76
78
|
from model_compression_toolkit.core.pytorch.hessian.activation_trace_hessian_calculator_pytorch import \
|
|
77
79
|
ActivationTraceHessianCalculatorPytorch
|
|
78
80
|
from model_compression_toolkit.core.pytorch.hessian.weights_trace_hessian_calculator_pytorch import \
|
|
@@ -302,6 +304,8 @@ class PytorchImplementation(FrameworkImplementation):
|
|
|
302
304
|
substitutions_list.append(pytorch_softmax_shift())
|
|
303
305
|
if quant_config.input_scaling:
|
|
304
306
|
Logger.critical('Input scaling is currently not supported for Pytorch.')
|
|
307
|
+
if quant_config.concat_threshold_update:
|
|
308
|
+
substitutions_list.append(ConcatThresholdUpdate())
|
|
305
309
|
return substitutions_list
|
|
306
310
|
|
|
307
311
|
|
|
@@ -337,12 +337,16 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
|
337
337
|
node = node[0]
|
|
338
338
|
kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
|
|
339
339
|
fw_info=self.fw_info)
|
|
340
|
+
# TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
|
|
341
|
+
# To enable GPTQ for other attributes, this code needs to be modified.
|
|
340
342
|
weights, weight_quant_config, activation_quant_config = \
|
|
341
343
|
layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
|
|
342
344
|
for weight_attr, weight in weights.items():
|
|
343
345
|
node.set_weights_by_keys(weight_attr, weight.numpy())
|
|
344
|
-
for
|
|
345
|
-
node.final_weights_quantization_cfg.set_quant_config_attr(
|
|
346
|
+
for config_parameter_name, config_parameter_value in weight_quant_config.items():
|
|
347
|
+
node.final_weights_quantization_cfg.set_quant_config_attr(config_parameter_name,
|
|
348
|
+
config_parameter_value,
|
|
349
|
+
attr_name=kernel_attribute)
|
|
346
350
|
for config_attr, config_value in activation_quant_config.items():
|
|
347
351
|
node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
|
|
348
352
|
if self.gptq_config.train_bias:
|
|
@@ -12,10 +12,12 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
import copy
|
|
15
16
|
|
|
16
17
|
from typing import Callable, Tuple
|
|
17
18
|
from packaging import version
|
|
18
19
|
|
|
20
|
+
from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
|
|
19
21
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
|
20
22
|
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
|
|
21
23
|
from model_compression_toolkit.logger import Logger
|
|
@@ -210,6 +212,8 @@ if FOUND_TF:
|
|
|
210
212
|
target_resource_utilization=target_resource_utilization,
|
|
211
213
|
tb_w=tb_w)
|
|
212
214
|
|
|
215
|
+
float_graph = copy.deepcopy(tg)
|
|
216
|
+
|
|
213
217
|
tg_gptq = gptq_runner(tg,
|
|
214
218
|
core_config,
|
|
215
219
|
gptq_config,
|
|
@@ -223,7 +227,12 @@ if FOUND_TF:
|
|
|
223
227
|
del hessian_info_service
|
|
224
228
|
|
|
225
229
|
if core_config.debug_config.analyze_similarity:
|
|
226
|
-
analyzer_model_quantization(representative_data_gen,
|
|
230
|
+
analyzer_model_quantization(representative_data_gen,
|
|
231
|
+
tb_w,
|
|
232
|
+
float_graph,
|
|
233
|
+
tg_gptq,
|
|
234
|
+
fw_impl,
|
|
235
|
+
DEFAULT_KERAS_INFO)
|
|
227
236
|
|
|
228
237
|
return get_exportable_keras_model(tg_gptq)
|
|
229
238
|
|
|
@@ -284,12 +284,16 @@ class PytorchGPTQTrainer(GPTQTrainer):
|
|
|
284
284
|
node = node[0]
|
|
285
285
|
kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
|
|
286
286
|
fw_info=self.fw_info)
|
|
287
|
+
# TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
|
|
288
|
+
# To enable GPTQ for other attributes, this code needs to be modified.
|
|
287
289
|
weights, weight_quant_config, activation_quant_config = \
|
|
288
290
|
layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
|
|
289
291
|
for weight_attr, weight in weights.items():
|
|
290
292
|
node.set_weights_by_keys(weight_attr, self.fw_impl.to_numpy(weight))
|
|
291
|
-
for
|
|
292
|
-
node.final_weights_quantization_cfg.set_quant_config_attr(
|
|
293
|
+
for config_parameter_name, config_parameter_value in weight_quant_config.items():
|
|
294
|
+
node.final_weights_quantization_cfg.set_quant_config_attr(config_parameter_name,
|
|
295
|
+
config_parameter_value,
|
|
296
|
+
attr_name=kernel_attribute)
|
|
293
297
|
for config_attr, config_value in activation_quant_config.items():
|
|
294
298
|
node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
|
|
295
299
|
if self.gptq_config.train_bias and hasattr(layer.layer, BIAS):
|
|
@@ -12,6 +12,8 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
import copy
|
|
16
|
+
|
|
15
17
|
from typing import Callable
|
|
16
18
|
from model_compression_toolkit.core import common
|
|
17
19
|
from model_compression_toolkit.constants import FOUND_TORCH
|
|
@@ -177,6 +179,7 @@ if FOUND_TORCH:
|
|
|
177
179
|
tpc=target_platform_capabilities,
|
|
178
180
|
target_resource_utilization=target_resource_utilization,
|
|
179
181
|
tb_w=tb_w)
|
|
182
|
+
float_graph = copy.deepcopy(graph)
|
|
180
183
|
|
|
181
184
|
# ---------------------- #
|
|
182
185
|
# GPTQ Runner
|
|
@@ -192,7 +195,12 @@ if FOUND_TORCH:
|
|
|
192
195
|
hessian_info_service=hessian_info_service)
|
|
193
196
|
|
|
194
197
|
if core_config.debug_config.analyze_similarity:
|
|
195
|
-
analyzer_model_quantization(representative_data_gen,
|
|
198
|
+
analyzer_model_quantization(representative_data_gen,
|
|
199
|
+
tb_w,
|
|
200
|
+
float_graph,
|
|
201
|
+
graph_gptq,
|
|
202
|
+
fw_impl,
|
|
203
|
+
DEFAULT_PYTORCH_INFO)
|
|
196
204
|
|
|
197
205
|
return get_exportable_pytorch_model(graph_gptq)
|
|
198
206
|
|
|
@@ -12,11 +12,13 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
import copy
|
|
15
16
|
|
|
16
17
|
from typing import Callable
|
|
17
18
|
|
|
18
19
|
from model_compression_toolkit.core import CoreConfig
|
|
19
20
|
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
|
|
21
|
+
from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
|
|
20
22
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
|
21
23
|
from model_compression_toolkit.logger import Logger
|
|
22
24
|
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
|
|
@@ -122,8 +124,8 @@ if FOUND_TF:
|
|
|
122
124
|
if core_config.mixed_precision_enable:
|
|
123
125
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
124
126
|
Logger.critical("Given quantization config to mixed-precision facade is not of type "
|
|
125
|
-
|
|
126
|
-
|
|
127
|
+
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
|
|
128
|
+
"API, or pass a valid mixed precision configuration.") # pragma: no cover
|
|
127
129
|
|
|
128
130
|
tb_w = init_tensorboard_writer(fw_info)
|
|
129
131
|
|
|
@@ -139,15 +141,30 @@ if FOUND_TF:
|
|
|
139
141
|
target_resource_utilization=target_resource_utilization,
|
|
140
142
|
tb_w=tb_w)
|
|
141
143
|
|
|
142
|
-
tg
|
|
144
|
+
# At this point, tg is a graph that went through substitutions (such as BN folding) and is
|
|
145
|
+
# ready for quantization (namely, it holds quantization params, etc.) but the weights are
|
|
146
|
+
# not quantized yet. For this reason, we use it to create a graph that acts as a "float" graph
|
|
147
|
+
# for things like similarity analyzer (because the quantized and float graph should have the same
|
|
148
|
+
# architecture to find the appropriate compare points for similarity computation).
|
|
149
|
+
similarity_baseline_graph = copy.deepcopy(tg)
|
|
150
|
+
|
|
151
|
+
graph_with_stats_correction = ptq_runner(tg,
|
|
152
|
+
representative_data_gen,
|
|
153
|
+
core_config,
|
|
154
|
+
fw_info,
|
|
155
|
+
fw_impl,
|
|
156
|
+
tb_w)
|
|
143
157
|
|
|
144
158
|
if core_config.debug_config.analyze_similarity:
|
|
159
|
+
quantized_graph = quantize_graph_weights(graph_with_stats_correction)
|
|
145
160
|
analyzer_model_quantization(representative_data_gen,
|
|
146
|
-
tb_w,
|
|
161
|
+
tb_w,
|
|
162
|
+
similarity_baseline_graph,
|
|
163
|
+
quantized_graph,
|
|
147
164
|
fw_impl,
|
|
148
165
|
fw_info)
|
|
149
166
|
|
|
150
|
-
return get_exportable_keras_model(
|
|
167
|
+
return get_exportable_keras_model(graph_with_stats_correction)
|
|
151
168
|
|
|
152
169
|
|
|
153
170
|
|
|
@@ -12,6 +12,8 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
import copy
|
|
16
|
+
|
|
15
17
|
from typing import Callable
|
|
16
18
|
|
|
17
19
|
from model_compression_toolkit.core import common
|
|
@@ -26,6 +28,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quant
|
|
|
26
28
|
from model_compression_toolkit.core.runner import core_runner
|
|
27
29
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
|
28
30
|
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
|
|
31
|
+
from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
|
|
29
32
|
|
|
30
33
|
|
|
31
34
|
if FOUND_TORCH:
|
|
@@ -90,14 +93,16 @@ if FOUND_TORCH:
|
|
|
90
93
|
|
|
91
94
|
"""
|
|
92
95
|
|
|
96
|
+
fw_info = DEFAULT_PYTORCH_INFO
|
|
97
|
+
|
|
93
98
|
if core_config.mixed_precision_enable:
|
|
94
99
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
95
100
|
Logger.critical("Given quantization config to mixed-precision facade is not of type "
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
101
|
+
"MixedPrecisionQuantizationConfig. Please use "
|
|
102
|
+
"pytorch_post_training_quantization API, or pass a valid mixed precision "
|
|
103
|
+
"configuration.") # pragma: no cover
|
|
99
104
|
|
|
100
|
-
tb_w = init_tensorboard_writer(
|
|
105
|
+
tb_w = init_tensorboard_writer(fw_info)
|
|
101
106
|
|
|
102
107
|
fw_impl = PytorchImplementation()
|
|
103
108
|
|
|
@@ -105,22 +110,36 @@ if FOUND_TORCH:
|
|
|
105
110
|
tg, bit_widths_config, _ = core_runner(in_model=in_module,
|
|
106
111
|
representative_data_gen=representative_data_gen,
|
|
107
112
|
core_config=core_config,
|
|
108
|
-
fw_info=
|
|
113
|
+
fw_info=fw_info,
|
|
109
114
|
fw_impl=fw_impl,
|
|
110
115
|
tpc=target_platform_capabilities,
|
|
111
116
|
target_resource_utilization=target_resource_utilization,
|
|
112
117
|
tb_w=tb_w)
|
|
113
118
|
|
|
114
|
-
tg
|
|
119
|
+
# At this point, tg is a graph that went through substitutions (such as BN folding) and is
|
|
120
|
+
# ready for quantization (namely, it holds quantization params, etc.) but the weights are
|
|
121
|
+
# not quantized yet. For this reason, we use it to create a graph that acts as a "float" graph
|
|
122
|
+
# for things like similarity analyzer (because the quantized and float graph should have the same
|
|
123
|
+
# architecture to find the appropriate compare points for similarity computation).
|
|
124
|
+
similarity_baseline_graph = copy.deepcopy(tg)
|
|
125
|
+
|
|
126
|
+
graph_with_stats_correction = ptq_runner(tg,
|
|
127
|
+
representative_data_gen,
|
|
128
|
+
core_config,
|
|
129
|
+
fw_info,
|
|
130
|
+
fw_impl,
|
|
131
|
+
tb_w)
|
|
115
132
|
|
|
116
133
|
if core_config.debug_config.analyze_similarity:
|
|
134
|
+
quantized_graph = quantize_graph_weights(graph_with_stats_correction)
|
|
117
135
|
analyzer_model_quantization(representative_data_gen,
|
|
118
136
|
tb_w,
|
|
119
|
-
|
|
137
|
+
similarity_baseline_graph,
|
|
138
|
+
quantized_graph,
|
|
120
139
|
fw_impl,
|
|
121
|
-
|
|
140
|
+
fw_info)
|
|
122
141
|
|
|
123
|
-
return get_exportable_pytorch_model(
|
|
142
|
+
return get_exportable_pytorch_model(graph_with_stats_correction)
|
|
124
143
|
|
|
125
144
|
|
|
126
145
|
else:
|
{mct_nightly-2.0.0.20240408.430.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/LICENSE.md
RENAMED
|
File without changes
|
|
File without changes
|
{mct_nightly-2.0.0.20240408.430.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/top_level.txt
RENAMED
|
File without changes
|