mct-nightly 2.0.0.20240408.430__py3-none-any.whl → 2.0.0.20240410.422__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {mct_nightly-2.0.0.20240408.430.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.0.0.20240408.430.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/RECORD +22 -20
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +22 -12
  5. model_compression_toolkit/core/common/quantization/node_quantization_config.py +17 -15
  6. model_compression_toolkit/core/common/quantization/quantization_config.py +3 -1
  7. model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +5 -4
  8. model_compression_toolkit/core/common/similarity_analyzer.py +4 -1
  9. model_compression_toolkit/core/common/visualization/nn_visualizer.py +31 -4
  10. model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py +66 -0
  11. model_compression_toolkit/core/keras/keras_implementation.py +6 -3
  12. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py +69 -0
  13. model_compression_toolkit/core/pytorch/pytorch_implementation.py +4 -0
  14. model_compression_toolkit/gptq/keras/gptq_training.py +6 -2
  15. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -1
  16. model_compression_toolkit/gptq/pytorch/gptq_training.py +6 -2
  17. model_compression_toolkit/gptq/pytorch/quantization_facade.py +9 -1
  18. model_compression_toolkit/ptq/keras/quantization_facade.py +22 -5
  19. model_compression_toolkit/ptq/pytorch/quantization_facade.py +28 -9
  20. {mct_nightly-2.0.0.20240408.430.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/LICENSE.md +0 -0
  21. {mct_nightly-2.0.0.20240408.430.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/WHEEL +0 -0
  22. {mct_nightly-2.0.0.20240408.430.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.0.0.20240408.430
3
+ Version: 2.0.0.20240410.422
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,9 +1,9 @@
1
- model_compression_toolkit/__init__.py,sha256=anO19rotJl4O84TFH2Q-g_vw5me9I3_WDoj1NZgiLbo,1573
1
+ model_compression_toolkit/__init__.py,sha256=c33LV9Kt6hpVEoLixt_I5rqhtSzRBPSrdmFEifg-VHU,1573
2
2
  model_compression_toolkit/constants.py,sha256=KW_HUEPmQEYqCvWGyORqkYxpvO7w5LViB5J5D-pm_6o,3648
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
5
5
  model_compression_toolkit/core/__init__.py,sha256=TrRgkWpT1AN2Faw1M_1HXyJkJnbxfn9p-RigDZl7pg0,1982
6
- model_compression_toolkit/core/analyzer.py,sha256=dbsD61pakp_9JXNyAScLdtJvcXny9jr_cMbET0Bd3Sg,2975
6
+ model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
7
7
  model_compression_toolkit/core/graph_prep_runner.py,sha256=Ftqm59hT5TGWmSNkY9bFZkVfCacpGyZfCe-6yZR5WY0,10100
8
8
  model_compression_toolkit/core/quantization_prep_runner.py,sha256=hFhDkS8GwzXZ7Ho_9qbbb8DAAWs3OONOfMSD5OU_b0o,6153
9
9
  model_compression_toolkit/core/runner.py,sha256=NKSC6ujfQPy6dKtJVwxyK2zNDd64eyR5csYy9lBrCPA,11836
@@ -16,7 +16,7 @@ model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3u
16
16
  model_compression_toolkit/core/common/model_collector.py,sha256=ofcepKtxc3j2Ouz6BpAKXTzPgjABnpRP47ndmJCXAkk,8352
17
17
  model_compression_toolkit/core/common/model_validation.py,sha256=LaG8wd6aZl0OJgieE3SeiVDEPxtk8IHq9-3wSnmWhY4,1214
18
18
  model_compression_toolkit/core/common/node_prior_info.py,sha256=WXX_PrGVG9M9I_REG5ZzFBohwmV4yf356sZnrja_FLo,2832
19
- model_compression_toolkit/core/common/similarity_analyzer.py,sha256=2w-q7guEb5bpLY4Vk_TMjR8TzLYEymR3tPFlrVq7K68,8515
19
+ model_compression_toolkit/core/common/similarity_analyzer.py,sha256=98l9ttnXHf6VYxBW4852h2CPJKg3A6nLOovpHn-tnKs,8560
20
20
  model_compression_toolkit/core/common/user_info.py,sha256=dSRMnT-oewmdOziIpEuW-s9K7vTSeyUBxT4z9neXurI,1648
21
21
  model_compression_toolkit/core/common/back2framework/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
22
22
  model_compression_toolkit/core/common/back2framework/base_model_builder.py,sha256=V1oShKzbSkdcTvREn8VnQQBzvm-tTHkWMXqMkYozF2s,2023
@@ -100,11 +100,11 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
100
100
  model_compression_toolkit/core/common/quantization/core_config.py,sha256=KYdyfSmjSL4ye24nKlC_c4_AxYb14qoqaeMnZj4-8kE,2257
101
101
  model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
102
102
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
103
- model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=lXONxIOSvYMgkN9M1st4tV1V5JSpijUGxF0hZWRvtUI,26737
104
- model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=hQMKm55EXS1oV-Upt6IQtsYhpuhMvYeWRJhh6lhv_Ko,6699
103
+ model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=TCgpvtfyzFUedv4sZ6sKzsTyikaVl2ixLj_aHPSC2r0,27014
104
+ model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=BieZDv9oc-Mc78S_LRMGo-s_2acbqiLE0ewaSE1v2VY,6818
105
105
  model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=T1nVWdRJfBQ_iuMQYQSIkjfkR-2n3lAOKGAz_rUZZN0,2190
106
106
  model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
107
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=xnM9O9LshYw3dprqfsnK9mw7ipOEAkI85o20auyfswg,2626
107
+ model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
108
108
  model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
109
109
  model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=9BEv2l0z2trDEsr40VB8tO3ToBA_b2sd_jH9uqZ5Wo8,11503
110
110
  model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
@@ -142,13 +142,13 @@ model_compression_toolkit/core/common/substitutions/virtual_activation_weights_c
142
142
  model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=h85L2VlDOqbLd-N98wA3SdYWiblBgSsPceNuLanJd70,4737
143
143
  model_compression_toolkit/core/common/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
144
144
  model_compression_toolkit/core/common/visualization/final_config_visualizer.py,sha256=6I10jKLesB-RQKaXA75Xgz2wPvylQUrnPtCcQZIynGo,6371
145
- model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=jjsOng-fLxeQFQNshIsOu_w1d5a3fJ359Hcnt85Te-o,5921
145
+ model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=HOq7AObkmEZiDSZXUMJDAEJzUY-fSXUT0AMgwiyH7dg,7388
146
146
  model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=4E4ZXZmqusGIJ4XQNH8FFt07htAHgT3gy5E7wPIaVBI,21951
147
147
  model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
148
148
  model_compression_toolkit/core/keras/constants.py,sha256=Uv3c0UdW55pIVQNW_1HQlgl-dHXREkltOLyzp8G1mTQ,3163
149
149
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
150
150
  model_compression_toolkit/core/keras/default_framework_info.py,sha256=Ha4HTHuiw_KTS5Po1Xnv6GyK9eprpDhYWf-eooS62Ys,4961
151
- model_compression_toolkit/core/keras/keras_implementation.py,sha256=NDHLl19I-xQrQGcsAwTcFjnIjCRn31xaPrqDYm8g_dg,29027
151
+ model_compression_toolkit/core/keras/keras_implementation.py,sha256=RS2UEtZ_anZeDxz7Zv6sNv7v9tFVct6d9KVrUlxTGpo,29309
152
152
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
153
153
  model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=Aqh31wOPaiZcJIOm-uJwzev0eTMdJyXaOk97rs4z7BU,3879
154
154
  model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=Xmk2ZL5CaYdb7iG62HdtZ1F64vap7ffnrsuR3e3G5hc,4851
@@ -166,6 +166,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/activatio
166
166
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py,sha256=9YCNPiK5BD7tLs1meabPhzfb2VsyPxrZM17zMFsW_Fo,8158
167
167
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=GR1a3mCZpNUu4WxixJXF_aSm57phAdxaRoHecNx3hxw,3168
168
168
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=5df_xGfXkqNub4xVRnCWQvSohWqdv12axjJ6edVU2H0,2478
169
+ model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py,sha256=Hl4LEQ_bw_Vpmf3ZqHujYUqVdvTNsPlEMvr9dZhwg2U,2806
169
170
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py,sha256=R3U7cjc2E0zheMem16GHygp5jZFGSaomkNOTxTjcAgw,5794
170
171
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=V6hp67CkS_A3WqdsjLjs0ETtdZAOo4P9mhy4aT7W5FE,5940
171
172
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=dyhqZrxSTclXyarT2JYnI5WPX0OvWR_CQiwddIr632U,8143
@@ -209,7 +210,7 @@ model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKW
209
210
  model_compression_toolkit/core/pytorch/constants.py,sha256=NI-J7REuxn06oEIHsmJ4GqtNC3TbV8xlkJjt5Ar-c4U,2626
210
211
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
211
212
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=IoMvTch5awAEPvB6Tg6ANhFGXvfSgv7JLsUBlxpMwk4,4330
212
- model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=dakO4Nj-tFfs53y6dJyXbpoljx2n3ZqmMoB4CFWGNSQ,26868
213
+ model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=mT4jd8E1saCpAgrsClufQbnVJ0eYn1xaTQ3teALu4jk,27117
213
214
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=n_B4a6FMwM9D2w8kzy3oenBWZgXNZuIZgTJC6JEuTy0,3250
214
215
  model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=E6ifk1HdO60k4IRH2EFBzAYWtwUlrGqJoQ66nknpHoQ,4983
215
216
  model_compression_toolkit/core/pytorch/utils.py,sha256=dRPiteBg2dBNsHwZyYzXiCIAjnelSoeZZsDXlsTw5JQ,2880
@@ -228,6 +229,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/__init_
228
229
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_folding.py,sha256=j3q5DzbH3ys5MPFfSOVnAXdD7-g4XEKj2ADrdihVr30,8292
229
230
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=B7aC2TZNrQJ2oQVGBFhKAVqdUU5lYVJSMmwKhjxOHWk,2822
230
231
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=JDWOaNwYrZG0zTwd3HwoZUM3tKu7zPbzLOrqNQsu8xA,2162
232
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py,sha256=SBrR24ZAnWPftLinv4FuIqdBGjfYtfXbYQJN5mgy5V4,2861
231
233
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py,sha256=dYGyb5ebnoeFBF0EaHPQU7CkXvoARdznEEe0laM47LA,3919
232
234
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py,sha256=iX8bLHtw2osP42-peNLTRmbpX3cUxdGsAbEfw7NLpx0,3935
233
235
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_layer_norm.py,sha256=zKSgtVw_P9fUvdq4e7P9yaLDPG_vZ0cecM9sVPtm1ns,3799
@@ -336,9 +338,9 @@ model_compression_toolkit/gptq/common/gptq_training.py,sha256=rLA1xlOO-6gWfmc2dL
336
338
  model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
337
339
  model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
338
340
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
339
- model_compression_toolkit/gptq/keras/gptq_training.py,sha256=6TseqtBzZkLqyc3hiRVdA1dv01us6Y_Su05CBboGjjc,18438
341
+ model_compression_toolkit/gptq/keras/gptq_training.py,sha256=OhYfH6zxRHrRhCde0lbcV9Hu2oeDD9RXh-O8vOPgLbs,18875
340
342
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=5IvgGlJlgOmQYmldjdCBv7tuzAoY0HazatG5Pedrg0Q,4639
341
- model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=yRgiikyeELar4jlsdcf5pO9HQcxiyhKiAXY3lsMixew,13913
343
+ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=zAkzWpWP9_aobWgMo_BlUm7-4fR5dHvoGx0sDqs2rZg,14299
342
344
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
343
345
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=2YU-x4-Q5f6hkUJf0tw6vcwdNwRMHdefrFjhhyHYsvA,4782
344
346
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -353,9 +355,9 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
353
355
  model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
354
356
  model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa94rpoodf82DksgjQCnL7nBc,2719
355
357
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
356
- model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=ksFXtepAsk-66xk7OPciG05kU9sgAUrWqjOgplsGSnw,15808
358
+ model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=LN4vOwcMuSSFTSnHDACV9hX_Yd2YIXJRl7WkdODuA0k,16245
357
359
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=yXJzDd24zfGs2_vfMovxD1WSh1RxXoPxN4GztOf3P5c,3967
358
- model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=bImrTw9rrAVc3VD4nmrXmBo_K4fuf5m5XPPf8ybOThs,12430
360
+ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=-4USg-tep6EQSArcTxBowhMeAuExrBTNLOWgHFpsIy4,12699
359
361
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
360
362
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=TCA1hAc7raPnrjl06sjFtVM4XUtLtuwAhCGX4U3KGZo,4137
361
363
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
@@ -375,9 +377,9 @@ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=cSuvHHCqgr7k9
375
377
  model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
376
378
  model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
377
379
  model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
378
- model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=5Doa5Rwer84DRgJxLa2e6aX9B4yGYdmFGgiv71_wD9o,8992
380
+ model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=T1_UqXmOc4I2a6IHkQAlFhGtcAYjsXSApMIdRlvgDvg,10154
379
381
  model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
380
- model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=QW7lMRuuonupLPZ2w2PDIQd7qpDZ_euLInhskTc1Yes,7518
382
+ model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=eof9bo-Mv_lLY7fFpiVeT5pIde-MuTWkIAqRKH4j9MI,8646
381
383
  model_compression_toolkit/qat/__init__.py,sha256=kj2qsZh_Ca7PncsHKcaL5EVT2H8g4hYtvaQ3KFxOkwE,1143
382
384
  model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
383
385
  model_compression_toolkit/qat/common/qat_config.py,sha256=zoq0Vb74vCY7WlWD8JH_KPrHDoUHSvMc3gcO53u7L2U,3394
@@ -469,8 +471,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
469
471
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
470
472
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
471
473
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=7bbzqJN8ZAycVDvZr_5xC-niTAR5df8f03Kooev_pfg,3047
472
- mct_nightly-2.0.0.20240408.430.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
473
- mct_nightly-2.0.0.20240408.430.dist-info/METADATA,sha256=5s1f6HnTLamnYz4r2orEDjLgg0h-L-v3dcQkXT9K-c4,18795
474
- mct_nightly-2.0.0.20240408.430.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
475
- mct_nightly-2.0.0.20240408.430.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
476
- mct_nightly-2.0.0.20240408.430.dist-info/RECORD,,
474
+ mct_nightly-2.0.0.20240410.422.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
475
+ mct_nightly-2.0.0.20240410.422.dist-info/METADATA,sha256=Xx2HTbZkpp4O8bS07IXSnaYSh9ZZTxe61I47ovv9fzE,18795
476
+ mct_nightly-2.0.0.20240410.422.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
477
+ mct_nightly-2.0.0.20240410.422.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
478
+ mct_nightly-2.0.0.20240410.422.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.0.0.20240408.000430"
30
+ __version__ = "2.0.0.20240410.000422"
@@ -30,7 +30,8 @@ from model_compression_toolkit.logger import Logger
30
30
 
31
31
  def analyzer_model_quantization(representative_data_gen: Callable,
32
32
  tb_w: TensorboardWriter,
33
- tg: Graph,
33
+ float_graph: Graph,
34
+ quantized_graph: Graph,
34
35
  fw_impl: FrameworkImplementation,
35
36
  fw_info: FrameworkInfo):
36
37
  """
@@ -41,23 +42,32 @@ def analyzer_model_quantization(representative_data_gen: Callable,
41
42
  Args:
42
43
  representative_data_gen: Dataset used for calibration.
43
44
  tb_w: TensorBoardWriter object to log events.
44
- tg: Graph of quantized model.
45
+ float_graph: Graph of float model.
46
+ quantized_graph: Graph of quantized model.
45
47
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
46
48
  fw_info: Information needed for quantization about the specific framework.
47
49
 
48
50
  """
49
51
  if tb_w is not None:
50
- visual = NNVisualizer(tg,
52
+ visual = NNVisualizer(float_graph,
53
+ quantized_graph,
51
54
  fw_impl=fw_impl,
52
55
  fw_info=fw_info)
53
-
54
- for i, _data in enumerate(representative_data_gen()):
55
- if i >= NUM_SAMPLES_DISTANCE_TENSORBOARD:
56
- break
57
- figure = visual.plot_distance_graph(_data,
58
- distance_fn=compute_cs,
59
- convert_to_range=lambda a: 1 - 2 * a)
60
- tb_w.add_figure(figure, f'similarity_distance_sample_{i}')
56
+ if not visual.has_compare_points():
57
+ Logger.error(f'No comparing points were found to plot analyze similarity.')
61
58
  else:
62
- Logger.warning(f'Not enough batches in representative dataset to generate {NUM_SAMPLES_DISTANCE_TENSORBOARD} figures')
59
+ visualized_samples = 0
60
+ for _data in representative_data_gen():
61
+ batch_size = _data[0].shape[0]
62
+ for sample_index in range(batch_size):
63
+ if visualized_samples >= NUM_SAMPLES_DISTANCE_TENSORBOARD:
64
+ break
65
+ figure = visual.plot_distance_graph(_data,
66
+ sample_index=sample_index,
67
+ distance_fn=compute_cs,
68
+ convert_to_range=lambda a: 1 - 2 * a)
69
+ tb_w.add_figure(figure, f'similarity_distance_sample_{visualized_samples}')
70
+ visualized_samples += 1
71
+ if visualized_samples < NUM_SAMPLES_DISTANCE_TENSORBOARD:
72
+ Logger.error(f'Not enough batches in representative dataset to generate {NUM_SAMPLES_DISTANCE_TENSORBOARD} figures')
63
73
  tb_w.close()
@@ -41,24 +41,24 @@ class BaseNodeQuantizationConfig(object):
41
41
  Base class for node quantization configuration
42
42
  """
43
43
 
44
- def set_quant_config_attr(self, parameter_name: str, parameter_value: Any,
44
+ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any,
45
45
  *args: List[Any], **kwargs: Dict[str, Any]):
46
46
  """
47
47
  Changes a BaseNodeQuantizationConfig's parameter.
48
48
  Note that arg and kwargs are only to allow clean override in the child classes.
49
49
 
50
50
  Args:
51
- parameter_name: parameter name to change.
52
- parameter_value: parameter value to change.
51
+ config_parameter_name: parameter name to change.
52
+ config_parameter_value: parameter value to change.
53
53
  args: A list of additional arguments.
54
54
  kwargs: A dictionary with additional key arguments.
55
55
 
56
56
  """
57
57
 
58
- if hasattr(self, parameter_name):
59
- setattr(self, parameter_name, parameter_value)
58
+ if hasattr(self, config_parameter_name):
59
+ setattr(self, config_parameter_name, config_parameter_value)
60
60
  else:
61
- Logger.warning(f"Parameter {parameter_name} could not be found in the node quantization config and "
61
+ Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config and "
62
62
  f"was not updated!")
63
63
 
64
64
  def __repr__(self) -> str:
@@ -106,6 +106,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
106
106
  self.z_threshold = qc.z_threshold
107
107
  self.shift_negative_ratio = qc.shift_negative_ratio
108
108
  self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
109
+ self.concat_threshold_update = qc.concat_threshold_update
109
110
 
110
111
  def quantize_node_output(self,
111
112
  tensors: Any) -> Any:
@@ -219,7 +220,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
219
220
  self.shift_negative_activation_correction == other.shift_negative_activation_correction and \
220
221
  self.z_threshold == other.z_threshold and \
221
222
  self.shift_negative_ratio == other.shift_negative_ratio and \
222
- self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
223
+ self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
223
224
 
224
225
  def __hash__(self):
225
226
  return hash((self.activation_quantization_fn,
@@ -521,7 +522,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
521
522
  f"{list(attrs_with_name.keys())}.")
522
523
  return attrs_with_name
523
524
 
524
- def set_quant_config_attr(self, parameter_name: str, parameter_value: Any, attr_name: str = None,
525
+ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any, attr_name: str = None,
525
526
  *args: List[Any], **kwargs: Dict[str, Any]):
526
527
  """
527
528
  This method overrides the parent class set_quant_config_attr to enable setting a specific weights
@@ -529,26 +530,27 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
529
530
 
530
531
  Args:
531
532
  attr_name: attribute name to change.
532
- parameter_name: parameter name to change.
533
- parameter_value: parameter value to change.
533
+ config_parameter_name: parameter name to change.
534
+ config_parameter_value: parameter value to change.
534
535
  args: A list of additional arguments.
535
536
  kwargs: A dictionary with additional key arguments.
536
537
 
537
538
  """
538
539
 
539
540
  if attr_name is None:
540
- super(NodeWeightsQuantizationConfig, self).set_quant_config_attr(parameter_name, parameter_value,
541
+ super(NodeWeightsQuantizationConfig, self).set_quant_config_attr(config_parameter_name,
542
+ config_parameter_value,
541
543
  *args, **kwargs)
542
544
  else:
543
545
  if self.has_attribute_config(attr_name):
544
546
  attr_cfg = self.get_attr_config(attr_name)
545
- if hasattr(attr_cfg, parameter_name):
546
- setattr(attr_cfg, parameter_name, parameter_value)
547
+ if hasattr(attr_cfg, config_parameter_name):
548
+ setattr(attr_cfg, config_parameter_name, config_parameter_value)
547
549
  else:
548
- Logger.warning(f"Parameter {parameter_name} could not be found in the node quantization config of "
550
+ Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
549
551
  f"weights attribute {attr_name} and was not updated!")
550
552
  else:
551
- Logger.error(f"Weights attribute {attr_name} could not be found to set parameter {parameter_name}.")
553
+ Logger.error(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
552
554
 
553
555
  def __eq__(self, other: Any) -> bool:
554
556
  """
@@ -62,7 +62,8 @@ class QuantizationConfig:
62
62
  residual_collapsing: bool = True,
63
63
  shift_negative_ratio: float = 0.05,
64
64
  shift_negative_threshold_recalculation: bool = False,
65
- shift_negative_params_search: bool = False):
65
+ shift_negative_params_search: bool = False,
66
+ concat_threshold_update: bool = False):
66
67
  """
67
68
  Class to wrap all different parameters the library quantize the input model according to.
68
69
 
@@ -117,6 +118,7 @@ class QuantizationConfig:
117
118
  self.shift_negative_ratio = shift_negative_ratio
118
119
  self.shift_negative_threshold_recalculation = shift_negative_threshold_recalculation
119
120
  self.shift_negative_params_search = shift_negative_params_search
121
+ self.concat_threshold_update = concat_threshold_update
120
122
 
121
123
  def __repr__(self):
122
124
  return str(self.__dict__)
@@ -23,7 +23,7 @@ from model_compression_toolkit.core.common.quantization.quantize_node import get
23
23
  from model_compression_toolkit.logger import Logger
24
24
 
25
25
 
26
- def quantize_graph_weights(graph: Graph) -> Graph:
26
+ def quantize_graph_weights(graph_to_quantize: Graph) -> Graph:
27
27
  """
28
28
  Get a graph representing a model, and quantize its nodes' weights.
29
29
  Each node is quantized according to the passed framework info and quantization configuration.
@@ -31,12 +31,13 @@ def quantize_graph_weights(graph: Graph) -> Graph:
31
31
  is calculated and subtracted from the original node's bias. The graph is quantized in-place.
32
32
 
33
33
  Args:
34
- graph: Graph to quantize its nodes.
34
+ graph_to_quantize: Graph to quantize its nodes.
35
35
 
36
36
  """
37
+ _quantized_graph = copy.deepcopy(graph_to_quantize)
37
38
  # Iterate over nodes in the graph and quantize each node's weights and activations
38
39
  # (according to operators groups in framework info).
39
- for n in graph.nodes():
40
+ for n in _quantized_graph.nodes():
40
41
  for attr in n.get_node_weights_attributes():
41
42
  if n.is_weights_quantization_enabled(attr):
42
43
  quantized_attr, io_channels_axes = \
@@ -51,4 +52,4 @@ def quantize_graph_weights(graph: Graph) -> Graph:
51
52
  # Set the attribute to be the quantized attribute.
52
53
  n.set_weights_by_keys(attr, quantized_attr)
53
54
 
54
- return graph
55
+ return _quantized_graph
@@ -146,7 +146,10 @@ def compute_mae(float_tensor: np.ndarray,
146
146
  return error
147
147
 
148
148
 
149
- def compute_cs(float_tensor: np.ndarray, fxp_tensor: np.ndarray, eps: float = 1e-8, batch: bool = False,
149
+ def compute_cs(float_tensor: np.ndarray,
150
+ fxp_tensor: np.ndarray,
151
+ eps: float = 1e-8,
152
+ batch: bool = False,
150
153
  axis: int = None) -> float:
151
154
  """
152
155
  Compute the similarity between two tensor using cosine similarity.
@@ -12,19 +12,20 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+
15
16
  from typing import Tuple, List, Callable
16
17
 
17
18
  import numpy as np
18
19
  from matplotlib import pyplot as plt
19
20
  from matplotlib.figure import Figure
20
21
 
21
- from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
22
22
  from model_compression_toolkit.core.common import Graph
23
23
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
24
24
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
25
25
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
26
26
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
27
27
  from model_compression_toolkit.core.common.similarity_analyzer import compute_cs
28
+ from model_compression_toolkit.logger import Logger
28
29
 
29
30
 
30
31
  def _get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str]]:
@@ -57,17 +58,21 @@ class NNVisualizer:
57
58
 
58
59
  def __init__(self,
59
60
  graph_float: Graph,
61
+ graph_quantized: Graph,
60
62
  fw_impl: FrameworkImplementation,
61
63
  fw_info: FrameworkInfo):
62
64
  """
63
65
  Initialize a NNVisualizer object.
64
66
  Args:
65
67
  graph_float: Float version of the graph.
68
+ graph_quantized: Quantized version of the graph.
69
+ fw_impl: Framework implementation with framework-specific methods implementation.
70
+ fw_info: Framework info with framework-specific information.
66
71
 
67
72
  """
68
73
 
69
74
  self.graph_float = graph_float
70
- self.graph_quantized = quantize_graph_weights(graph_float)
75
+ self.graph_quantized = graph_quantized
71
76
  self.fw_impl = fw_impl
72
77
  self.fw_info = fw_info
73
78
 
@@ -75,6 +80,16 @@ class NNVisualizer:
75
80
  self.compare_points, self.compare_points_name = _get_compare_points(self.graph_quantized)
76
81
  self.compare_points_float, self.compare_points_name_float = _get_compare_points(self.graph_float)
77
82
 
83
+ if len(self.compare_points) != len(self.compare_points_float):
84
+ Logger.critical(f"Number of compare points in float and quantized models must be equal but "
85
+ f"num of quantized compare points: {len(self.compare_points)} and "
86
+ f"num of float compare points: {len(self.compare_points_float)}")
87
+ if len(self.compare_points_name) != len(self.compare_points_name_float):
88
+ Logger.critical(f"Number of compare points in float and quantized models must be equal "
89
+ f"but num of quantized compare points: {len(self.compare_points_name)}"
90
+ f" and num of float compare points: "
91
+ f"{len(self.compare_points_name_float)}")
92
+
78
93
  self.quantized_model, _ = self.fw_impl.model_builder(self.graph_quantized,
79
94
  mode=ModelBuilderMode.QUANTIZED,
80
95
  append2output=self.compare_points,
@@ -85,8 +100,19 @@ class NNVisualizer:
85
100
  append2output=self.compare_points_float,
86
101
  fw_info=self.fw_info)
87
102
 
103
+ def has_compare_points(self) -> bool:
104
+ """
105
+
106
+ Returns: Whether or not compare points were found.
107
+
108
+ """
109
+ return len(self.compare_points_float) > 0 and len(self.compare_points) > 0 and len(
110
+ self.compare_points_name_float) > 0 and len(self.compare_points_name) > 0
111
+
112
+
88
113
  def plot_distance_graph(self,
89
114
  input_image: np.ndarray,
115
+ sample_index: int,
90
116
  distance_fn: Callable = compute_cs,
91
117
  convert_to_range: Callable = lambda a: a) -> Figure:
92
118
  """
@@ -95,6 +121,7 @@ class NNVisualizer:
95
121
 
96
122
  Args:
97
123
  input_image: Image to use as input to the networks.
124
+ sample_index: The index of the sample from input_image to use for comparison.
98
125
  distance_fn: Distance function to calculate the distance between two tensors.
99
126
  convert_to_range: Optional function to move the distance values into a specific range, e.g., when using
100
127
  cosine similarity for distance, use 'lambda a: 1 - 2 * a' to convert the distance values to the range
@@ -108,7 +135,7 @@ class NNVisualizer:
108
135
  # to make the difference more noticeable when exists
109
136
  new_inputs = []
110
137
  for single_input in input_image:
111
- img = single_input[0]
138
+ img = single_input[sample_index]
112
139
  new_inputs.append(np.expand_dims(img, axis=0))
113
140
 
114
141
  # Get outputs
@@ -123,7 +150,7 @@ class NNVisualizer:
123
150
 
124
151
  # Display the result: distance at every layer's output.
125
152
  fig = plt.figure()
126
- plt.plot(distance_array)
153
+ plt.plot(list(range(len(distance_array))), distance_array)
127
154
  eps = 0.5
128
155
  y_limits = (min(distance_array) - eps, max(distance_array) + eps)
129
156
  plt.ylim(y_limits)
@@ -0,0 +1,66 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from tensorflow.keras.layers import Concatenate
18
+ import tensorflow as tf
19
+
20
+ from model_compression_toolkit.core import common
21
+ from model_compression_toolkit.core.common import Graph, BaseNode
22
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
23
+ from model_compression_toolkit.constants import THRESHOLD
24
+
25
+
26
+
27
+ class ConcatThresholdUpdate(common.BaseSubstitution):
28
+
29
+
30
+ """
31
+ Find concat layers and match their prior layers thresholds unless prior layer outputs to multiple layers.
32
+ """
33
+
34
+ def __init__(self):
35
+ """
36
+ Initialize a threshold_updater object.
37
+ """
38
+ concatination_node = NodeOperationMatcher(Concatenate) | \
39
+ NodeOperationMatcher(tf.concat)
40
+ super().__init__(matcher_instance=concatination_node)
41
+
42
+ def substitute(self,
43
+ graph: Graph,
44
+ node: BaseNode) -> Graph:
45
+ """
46
+ Update previous layers' thresholds to match the concatenation's quantization thresholds. No change if
47
+ previous layer outputs to multiple layers. No change in case of uniform quantization.
48
+ No change in case of multiple quantization candidates (mixed precision).
49
+
50
+
51
+ Args:
52
+ graph: Graph we apply the substitution on.
53
+ node: Node reference used to edit previous nodes' thresholds.
54
+
55
+ Returns:
56
+ Graph after applying the substitution.
57
+ """
58
+
59
+ if len(node.candidates_quantization_cfg) == 1 and THRESHOLD in node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params:
60
+ concat_threshold = node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD]
61
+ prev_nodes = graph.get_prev_nodes(node)
62
+ for prev_node in prev_nodes:
63
+ if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != Concatenate and prev_node.type != tf.concat:
64
+ prev_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] = concat_threshold
65
+
66
+ return graph
@@ -80,7 +80,8 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.line
80
80
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.residual_collapsing import \
81
81
  keras_residual_collapsing
82
82
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \
83
- InputScalingWithPad
83
+ InputScalingWithPad
84
+ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.concat_threshold_update import ConcatThresholdUpdate
84
85
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
85
86
  ReLUBoundToPowerOfTwo
86
87
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.multi_head_attention_decomposition import \
@@ -300,8 +301,8 @@ class KerasImplementation(FrameworkImplementation):
300
301
  """
301
302
  return keras_op2d_add_const_collapsing()
302
303
 
303
- def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) \
304
- -> List[common.BaseSubstitution]:
304
+ def get_substitutions_post_statistics_collection(self,
305
+ quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
305
306
  """
306
307
  Return a list of the framework substitutions used after we collect statistics.
307
308
 
@@ -317,6 +318,8 @@ class KerasImplementation(FrameworkImplementation):
317
318
  if quant_config.input_scaling:
318
319
  substitutions_list.append(InputScaling())
319
320
  substitutions_list.append(InputScalingWithPad())
321
+ if quant_config.concat_threshold_update:
322
+ substitutions_list.append(ConcatThresholdUpdate())
320
323
  return substitutions_list
321
324
 
322
325
 
@@ -0,0 +1,69 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import List
17
+
18
+ import torch
19
+
20
+ from model_compression_toolkit.core import common
21
+ from model_compression_toolkit.core.common.graph.base_graph import Graph
22
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
23
+ from model_compression_toolkit.core.common.graph.base_node import BaseNode
24
+ from model_compression_toolkit.constants import THRESHOLD
25
+
26
+
27
+ class ConcatThresholdUpdate(common.BaseSubstitution):
28
+ """
29
+ Find concat layers and match their prior layers thresholds unless prior layer outputs to multiple layers.
30
+ """
31
+
32
+
33
+ def __init__(self):
34
+ """
35
+ Initialize a threshold_updater object.
36
+ """
37
+ concatination_node = NodeOperationMatcher(torch.cat) | \
38
+ NodeOperationMatcher(torch.concat)
39
+ super().__init__(matcher_instance=concatination_node)
40
+
41
+ def substitute(self,
42
+ graph: Graph,
43
+ node: BaseNode) -> Graph:
44
+ """
45
+ Update previous layers' thresholds to match the concatenation's quantization thresholds. No change if
46
+ previous layer outputs to multiple layers. No change in case of uniform quantization.
47
+ No change in case of multiple quantization candidates (mixed precision).
48
+
49
+
50
+ Args:
51
+ graph: Graph we apply the substitution on.
52
+ node: Node reference used to edit previous nodes' thresholds.
53
+
54
+ Returns:
55
+ Graph after applying the substitution.
56
+ """
57
+
58
+ if len(node.candidates_quantization_cfg) == 1 and THRESHOLD in node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params:
59
+ concat_threshold = node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD]
60
+ prev_nodes = graph.get_prev_nodes(node)
61
+ for prev_node in prev_nodes:
62
+ if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != torch.cat and prev_node.type != torch.concat:
63
+ prev_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] = concat_threshold
64
+
65
+ return graph
66
+
67
+
68
+
69
+
@@ -73,6 +73,8 @@ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.vi
73
73
  VirtualActivationWeightsComposition
74
74
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.weights_activation_split import \
75
75
  WeightsActivationSplit
76
+ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.concat_threshold_update import \
77
+ ConcatThresholdUpdate
76
78
  from model_compression_toolkit.core.pytorch.hessian.activation_trace_hessian_calculator_pytorch import \
77
79
  ActivationTraceHessianCalculatorPytorch
78
80
  from model_compression_toolkit.core.pytorch.hessian.weights_trace_hessian_calculator_pytorch import \
@@ -302,6 +304,8 @@ class PytorchImplementation(FrameworkImplementation):
302
304
  substitutions_list.append(pytorch_softmax_shift())
303
305
  if quant_config.input_scaling:
304
306
  Logger.critical('Input scaling is currently not supported for Pytorch.')
307
+ if quant_config.concat_threshold_update:
308
+ substitutions_list.append(ConcatThresholdUpdate())
305
309
  return substitutions_list
306
310
 
307
311
 
@@ -337,12 +337,16 @@ class KerasGPTQTrainer(GPTQTrainer):
337
337
  node = node[0]
338
338
  kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
339
339
  fw_info=self.fw_info)
340
+ # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
341
+ # To enable GPTQ for other attributes, this code needs to be modified.
340
342
  weights, weight_quant_config, activation_quant_config = \
341
343
  layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
342
344
  for weight_attr, weight in weights.items():
343
345
  node.set_weights_by_keys(weight_attr, weight.numpy())
344
- for config_attr, config_value in weight_quant_config.items():
345
- node.final_weights_quantization_cfg.set_quant_config_attr(config_attr, config_value)
346
+ for config_parameter_name, config_parameter_value in weight_quant_config.items():
347
+ node.final_weights_quantization_cfg.set_quant_config_attr(config_parameter_name,
348
+ config_parameter_value,
349
+ attr_name=kernel_attribute)
346
350
  for config_attr, config_value in activation_quant_config.items():
347
351
  node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
348
352
  if self.gptq_config.train_bias:
@@ -12,10 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import copy
15
16
 
16
17
  from typing import Callable, Tuple
17
18
  from packaging import version
18
19
 
20
+ from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
19
21
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
20
22
  from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
21
23
  from model_compression_toolkit.logger import Logger
@@ -210,6 +212,8 @@ if FOUND_TF:
210
212
  target_resource_utilization=target_resource_utilization,
211
213
  tb_w=tb_w)
212
214
 
215
+ float_graph = copy.deepcopy(tg)
216
+
213
217
  tg_gptq = gptq_runner(tg,
214
218
  core_config,
215
219
  gptq_config,
@@ -223,7 +227,12 @@ if FOUND_TF:
223
227
  del hessian_info_service
224
228
 
225
229
  if core_config.debug_config.analyze_similarity:
226
- analyzer_model_quantization(representative_data_gen, tb_w, tg_gptq, fw_impl, fw_info)
230
+ analyzer_model_quantization(representative_data_gen,
231
+ tb_w,
232
+ float_graph,
233
+ tg_gptq,
234
+ fw_impl,
235
+ DEFAULT_KERAS_INFO)
227
236
 
228
237
  return get_exportable_keras_model(tg_gptq)
229
238
 
@@ -284,12 +284,16 @@ class PytorchGPTQTrainer(GPTQTrainer):
284
284
  node = node[0]
285
285
  kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
286
286
  fw_info=self.fw_info)
287
+ # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
288
+ # To enable GPTQ for other attributes, this code needs to be modified.
287
289
  weights, weight_quant_config, activation_quant_config = \
288
290
  layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
289
291
  for weight_attr, weight in weights.items():
290
292
  node.set_weights_by_keys(weight_attr, self.fw_impl.to_numpy(weight))
291
- for config_attr, config_value in weight_quant_config.items():
292
- node.final_weights_quantization_cfg.set_quant_config_attr(config_attr, config_value)
293
+ for config_parameter_name, config_parameter_value in weight_quant_config.items():
294
+ node.final_weights_quantization_cfg.set_quant_config_attr(config_parameter_name,
295
+ config_parameter_value,
296
+ attr_name=kernel_attribute)
293
297
  for config_attr, config_value in activation_quant_config.items():
294
298
  node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
295
299
  if self.gptq_config.train_bias and hasattr(layer.layer, BIAS):
@@ -12,6 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import copy
16
+
15
17
  from typing import Callable
16
18
  from model_compression_toolkit.core import common
17
19
  from model_compression_toolkit.constants import FOUND_TORCH
@@ -177,6 +179,7 @@ if FOUND_TORCH:
177
179
  tpc=target_platform_capabilities,
178
180
  target_resource_utilization=target_resource_utilization,
179
181
  tb_w=tb_w)
182
+ float_graph = copy.deepcopy(graph)
180
183
 
181
184
  # ---------------------- #
182
185
  # GPTQ Runner
@@ -192,7 +195,12 @@ if FOUND_TORCH:
192
195
  hessian_info_service=hessian_info_service)
193
196
 
194
197
  if core_config.debug_config.analyze_similarity:
195
- analyzer_model_quantization(representative_data_gen, tb_w, graph_gptq, fw_impl, DEFAULT_PYTORCH_INFO)
198
+ analyzer_model_quantization(representative_data_gen,
199
+ tb_w,
200
+ float_graph,
201
+ graph_gptq,
202
+ fw_impl,
203
+ DEFAULT_PYTORCH_INFO)
196
204
 
197
205
  return get_exportable_pytorch_model(graph_gptq)
198
206
 
@@ -12,11 +12,13 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import copy
15
16
 
16
17
  from typing import Callable
17
18
 
18
19
  from model_compression_toolkit.core import CoreConfig
19
20
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
21
+ from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
20
22
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
21
23
  from model_compression_toolkit.logger import Logger
22
24
  from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
@@ -122,8 +124,8 @@ if FOUND_TF:
122
124
  if core_config.mixed_precision_enable:
123
125
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
124
126
  Logger.critical("Given quantization config to mixed-precision facade is not of type "
125
- "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
126
- "API, or pass a valid mixed precision configuration.") # pragma: no cover
127
+ "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
128
+ "API, or pass a valid mixed precision configuration.") # pragma: no cover
127
129
 
128
130
  tb_w = init_tensorboard_writer(fw_info)
129
131
 
@@ -139,15 +141,30 @@ if FOUND_TF:
139
141
  target_resource_utilization=target_resource_utilization,
140
142
  tb_w=tb_w)
141
143
 
142
- tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
144
+ # At this point, tg is a graph that went through substitutions (such as BN folding) and is
145
+ # ready for quantization (namely, it holds quantization params, etc.) but the weights are
146
+ # not quantized yet. For this reason, we use it to create a graph that acts as a "float" graph
147
+ # for things like similarity analyzer (because the quantized and float graph should have the same
148
+ # architecture to find the appropriate compare points for similarity computation).
149
+ similarity_baseline_graph = copy.deepcopy(tg)
150
+
151
+ graph_with_stats_correction = ptq_runner(tg,
152
+ representative_data_gen,
153
+ core_config,
154
+ fw_info,
155
+ fw_impl,
156
+ tb_w)
143
157
 
144
158
  if core_config.debug_config.analyze_similarity:
159
+ quantized_graph = quantize_graph_weights(graph_with_stats_correction)
145
160
  analyzer_model_quantization(representative_data_gen,
146
- tb_w, tg,
161
+ tb_w,
162
+ similarity_baseline_graph,
163
+ quantized_graph,
147
164
  fw_impl,
148
165
  fw_info)
149
166
 
150
- return get_exportable_keras_model(tg)
167
+ return get_exportable_keras_model(graph_with_stats_correction)
151
168
 
152
169
 
153
170
 
@@ -12,6 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import copy
16
+
15
17
  from typing import Callable
16
18
 
17
19
  from model_compression_toolkit.core import common
@@ -26,6 +28,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quant
26
28
  from model_compression_toolkit.core.runner import core_runner
27
29
  from model_compression_toolkit.ptq.runner import ptq_runner
28
30
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
31
+ from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
29
32
 
30
33
 
31
34
  if FOUND_TORCH:
@@ -90,14 +93,16 @@ if FOUND_TORCH:
90
93
 
91
94
  """
92
95
 
96
+ fw_info = DEFAULT_PYTORCH_INFO
97
+
93
98
  if core_config.mixed_precision_enable:
94
99
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
95
100
  Logger.critical("Given quantization config to mixed-precision facade is not of type "
96
- "MixedPrecisionQuantizationConfig. Please use "
97
- "pytorch_post_training_quantization API, or pass a valid mixed precision "
98
- "configuration.") # pragma: no cover
101
+ "MixedPrecisionQuantizationConfig. Please use "
102
+ "pytorch_post_training_quantization API, or pass a valid mixed precision "
103
+ "configuration.") # pragma: no cover
99
104
 
100
- tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
105
+ tb_w = init_tensorboard_writer(fw_info)
101
106
 
102
107
  fw_impl = PytorchImplementation()
103
108
 
@@ -105,22 +110,36 @@ if FOUND_TORCH:
105
110
  tg, bit_widths_config, _ = core_runner(in_model=in_module,
106
111
  representative_data_gen=representative_data_gen,
107
112
  core_config=core_config,
108
- fw_info=DEFAULT_PYTORCH_INFO,
113
+ fw_info=fw_info,
109
114
  fw_impl=fw_impl,
110
115
  tpc=target_platform_capabilities,
111
116
  target_resource_utilization=target_resource_utilization,
112
117
  tb_w=tb_w)
113
118
 
114
- tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w)
119
+ # At this point, tg is a graph that went through substitutions (such as BN folding) and is
120
+ # ready for quantization (namely, it holds quantization params, etc.) but the weights are
121
+ # not quantized yet. For this reason, we use it to create a graph that acts as a "float" graph
122
+ # for things like similarity analyzer (because the quantized and float graph should have the same
123
+ # architecture to find the appropriate compare points for similarity computation).
124
+ similarity_baseline_graph = copy.deepcopy(tg)
125
+
126
+ graph_with_stats_correction = ptq_runner(tg,
127
+ representative_data_gen,
128
+ core_config,
129
+ fw_info,
130
+ fw_impl,
131
+ tb_w)
115
132
 
116
133
  if core_config.debug_config.analyze_similarity:
134
+ quantized_graph = quantize_graph_weights(graph_with_stats_correction)
117
135
  analyzer_model_quantization(representative_data_gen,
118
136
  tb_w,
119
- tg,
137
+ similarity_baseline_graph,
138
+ quantized_graph,
120
139
  fw_impl,
121
- DEFAULT_PYTORCH_INFO)
140
+ fw_info)
122
141
 
123
- return get_exportable_pytorch_model(tg)
142
+ return get_exportable_pytorch_model(graph_with_stats_correction)
124
143
 
125
144
 
126
145
  else: