mct-nightly 2.1.0.20240807.445__py3-none-any.whl → 2.1.0.20240808.431__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/RECORD +33 -32
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/constants.py +14 -1
  5. model_compression_toolkit/core/common/fusion/graph_fuser.py +135 -0
  6. model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +4 -0
  7. model_compression_toolkit/core/common/quantization/debug_config.py +4 -1
  8. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +29 -1
  9. model_compression_toolkit/core/runner.py +21 -1
  10. model_compression_toolkit/gptq/keras/quantization_facade.py +13 -11
  11. model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -11
  12. model_compression_toolkit/metadata.py +61 -2
  13. model_compression_toolkit/ptq/keras/quantization_facade.py +12 -10
  14. model_compression_toolkit/ptq/pytorch/quantization_facade.py +12 -12
  15. model_compression_toolkit/qat/keras/quantization_facade.py +8 -8
  16. model_compression_toolkit/qat/pytorch/quantization_facade.py +8 -8
  17. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +10 -13
  18. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +68 -52
  19. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +35 -29
  20. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +35 -28
  21. model_compression_toolkit/xquant/common/constants.py +3 -0
  22. model_compression_toolkit/xquant/common/core_report_generator.py +9 -1
  23. model_compression_toolkit/xquant/common/framework_report_utils.py +5 -14
  24. model_compression_toolkit/xquant/common/tensorboard_utils.py +30 -5
  25. model_compression_toolkit/xquant/keras/facade_xquant_report.py +2 -0
  26. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -1
  27. model_compression_toolkit/xquant/keras/tensorboard_utils.py +101 -4
  28. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +2 -0
  29. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -2
  30. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +109 -3
  31. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/LICENSE.md +0 -0
  32. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/WHEEL +0 -0
  33. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.1.0.20240807.445
3
+ Version: 2.1.0.20240808.431
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,13 +1,13 @@
1
- model_compression_toolkit/__init__.py,sha256=5v5QMZsuecZeSiFdHfgNeoHe13N79F9En_BxUoMzw7E,1573
2
- model_compression_toolkit/constants.py,sha256=dexmfFCQ6VgoWuFBeM6MZykfgiVVdVxgkiSnpfjN8Dw,4005
1
+ model_compression_toolkit/__init__.py,sha256=eZGxT7hh5fuuq7hfe_N0zCm01ZqSyAOk0HqlhKGtXxY,1573
2
+ model_compression_toolkit/constants.py,sha256=0qrEGjX36Oo7Lt8mR0LD2aSe2xA7gKrhkzBGp7g5eiA,4345
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
5
- model_compression_toolkit/metadata.py,sha256=IyoON37lBv3TI0rZGCP4K5t3oYI4TOmYy-LRXOwHGpE,1136
5
+ model_compression_toolkit/metadata.py,sha256=UtXS5ClK-qPoxGRuytlDGZSzgLo911dMni2EFRcg6io,3623
6
6
  model_compression_toolkit/core/__init__.py,sha256=TrRgkWpT1AN2Faw1M_1HXyJkJnbxfn9p-RigDZl7pg0,1982
7
7
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
8
8
  model_compression_toolkit/core/graph_prep_runner.py,sha256=kM70wmNG3yMFiGQc0uO0wn9j4ZbSWxUEykpxDK55doc,10567
9
9
  model_compression_toolkit/core/quantization_prep_runner.py,sha256=K9eJ7VbB_rpeyxX4yEnorOmSxFW3DkvofzxS6QI8Hp8,6454
10
- model_compression_toolkit/core/runner.py,sha256=JvX0Ht164BOKIsPPxp6z-Nlk1Vlhlg7wKBl6lc2yIaQ,12675
10
+ model_compression_toolkit/core/runner.py,sha256=uXpyYaX1uFNhKituGmSfKb3ZkguXG2V_Cg6XCnprplg,13569
11
11
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
12
12
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
13
13
  model_compression_toolkit/core/common/framework_implementation.py,sha256=kSg2f7wS7e2EyvX6y0eKfNTTFvVFVrB8lvldJvcPvN8,20724
@@ -28,6 +28,7 @@ model_compression_toolkit/core/common/collectors/mean_collector.py,sha256=mjr3U_
28
28
  model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py,sha256=5oKsJEKdVmj4C7fKdHhmrFN5k4G2BaFETpmf_xKNs7s,5207
29
29
  model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=vcf7Pk1v09SJC4fbAWf_8AgTktE6tPizJbQpSmocP2U,7930
30
30
  model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
31
+ model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=8seu9jBpC7HartP1nJd7S_SYFICyemVpDV9ZJ0QUQ7E,6212
31
32
  model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=lOubqpc18TslhXZijWUJQAa1c3jIB2S-M-5HK78wJPQ,5548
32
33
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
33
34
  model_compression_toolkit/core/common/graph/base_graph.py,sha256=lg5QaBkRbmvM3tGZ0Q34S3m0CbFql3LUv5BaXLe5TG8,37824
@@ -39,7 +40,7 @@ model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-
39
40
  model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=3el-A7j1oyoo1_9zq3faQp7IeRsFXFCvnrb3zZFXpU0,9803
40
41
  model_compression_toolkit/core/common/graph/memory_graph/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
41
42
  model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py,sha256=X6FK3C3y8ixFRPjC_wm3ClloCX8_06SOdA1TRi7o_LA,3800
42
- model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py,sha256=MNQ957rjy5hbuZYP9fYR2cxuzxpyS8U3p5r4MYCzkjQ,2612
43
+ model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py,sha256=yJ0GncHVOP0Wj9ntzluklDQsgRFow89gteNCvIxRVXU,2857
43
44
  model_compression_toolkit/core/common/graph/memory_graph/cut.py,sha256=aPdXJPP5a5Rnu5Z5XqTZZkuGtdgHVu0RmX_NOfNM6Tc,2470
44
45
  model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256=crV2NCLVO8jx9MlryZBYuJKFe_G9HfM7rUR64fDymlw,17045
45
46
  model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=gRmBEFRmyJsNKezQfiwDwQu1cmbGd2wgKCRTH6iw8mw,3961
@@ -99,7 +100,7 @@ model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256
99
100
  model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
100
101
  model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=yU-Cr6S4wOSkDk57iH2NVe-WII0whOhLryejkomCOt4,4940
101
102
  model_compression_toolkit/core/common/quantization/core_config.py,sha256=KYdyfSmjSL4ye24nKlC_c4_AxYb14qoqaeMnZj4-8kE,2257
102
- model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
103
+ model_compression_toolkit/core/common/quantization/debug_config.py,sha256=8G8SpE_4rb8xBp8d6mMq8R_OnXJ_1oxB2g-Lxk9EJCM,1691
103
104
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
104
105
  model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=YycYN8_JMzvSR3pTVm5dT5x4zP3yBHn0Z9agnwrvOKI,26395
105
106
  model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=du0VdsxfkOSYaP1EU9gHA5qbXpfQNZL0jXrjk1wBA0U,7106
@@ -145,7 +146,7 @@ model_compression_toolkit/core/common/substitutions/weights_activation_split.py,
145
146
  model_compression_toolkit/core/common/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
146
147
  model_compression_toolkit/core/common/visualization/final_config_visualizer.py,sha256=6I10jKLesB-RQKaXA75Xgz2wPvylQUrnPtCcQZIynGo,6371
147
148
  model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=HOq7AObkmEZiDSZXUMJDAEJzUY-fSXUT0AMgwiyH7dg,7388
148
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=g5c5BCJnJ1Lgu2aQvHh2NeeQE954ndr3H4cKmvtr5IM,22510
149
+ model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=1-OQu3RNKXA55qfKG1MPq4JxTzmFeVKFDWv5i3TktRw,23676
149
150
  model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
150
151
  model_compression_toolkit/core/keras/constants.py,sha256=Uv3c0UdW55pIVQNW_1HQlgl-dHXREkltOLyzp8G1mTQ,3163
151
152
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
@@ -348,7 +349,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
348
349
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
349
350
  model_compression_toolkit/gptq/keras/gptq_training.py,sha256=NXTNsVrO9DTh0uvc8V7rFaM0fYg2OA18ZrYd-cKZ7Z4,19159
350
351
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=MKIfrRTRH3zCuxCR1g9ZVIFyuSSr0e0sDybqh4LDM7E,4672
351
- model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=SjmBTuSwki4JTPVhxvJMFK9uAsmEm2c6VV11NnM6eEo,15117
352
+ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=K2G9RTBDs9yNCDKyPI6-MbIMduRBGNGepEi2UKpgGbw,15444
352
353
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
353
354
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=2YU-x4-Q5f6hkUJf0tw6vcwdNwRMHdefrFjhhyHYsvA,4782
354
355
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -365,7 +366,7 @@ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa9
365
366
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
366
367
  model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=-daninmlPGfKsBNPB2C3gT6rK0G5YeyJsuOLA0JlfBU,16633
367
368
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
368
- model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=jcvRKBuMkrerNE8oWIJFp802pyFO0dnA-4hRnclKbWE,13569
369
+ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=gzg2QUzb3BO5rCtIONjBQr8TXb3qolFxHIkylSv8gMY,13896
369
370
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
370
371
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=TCA1hAc7raPnrjl06sjFtVM4XUtLtuwAhCGX4U3KGZo,4137
371
372
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
@@ -385,14 +386,14 @@ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=cSuvHHCqgr7k9
385
386
  model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
386
387
  model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
387
388
  model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
388
- model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=s6vBCK98l-R12yWASkutPSmNSfPX7457DazroJwhjpo,10517
389
+ model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=4sl28g4zw90hVfhbhboP8Vv1b3jySd5SPH7Euib4Ko0,10808
389
390
  model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
390
- model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=kkdgBXRBkblBTOW5EaySI_bN4_becSUwbdgOTb7FW2c,9012
391
+ model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=g3Fnk7MjZY9YSSJ5BcXgM0wvMT52IudDobu4eyM2uvc,9252
391
392
  model_compression_toolkit/qat/__init__.py,sha256=kj2qsZh_Ca7PncsHKcaL5EVT2H8g4hYtvaQ3KFxOkwE,1143
392
393
  model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
393
394
  model_compression_toolkit/qat/common/qat_config.py,sha256=zoq0Vb74vCY7WlWD8JH_KPrHDoUHSvMc3gcO53u7L2U,3394
394
395
  model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
395
- model_compression_toolkit/qat/keras/quantization_facade.py,sha256=AXwY6p1XFjPUzal_r_c1_su5Ji3ARtVTZYYWpDPZ09k,17026
396
+ model_compression_toolkit/qat/keras/quantization_facade.py,sha256=4yixHJ5j_RP0C6rTyPkMi-hoBmYKOBFunpUL5GnTdK4,17050
396
397
  model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
397
398
  model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py,sha256=0CB5M68zjPXv4yJZ-DzaYP9yYYWX_8J2gJLunxupOAM,2085
398
399
  model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543
@@ -404,7 +405,7 @@ model_compression_toolkit/qat/keras/quantizer/ste_rounding/__init__.py,sha256=cc
404
405
  model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=I4KlaGv17k71IyjuSG9M0OlXlD5P0pfvKa6oCyRQ5FE,13517
405
406
  model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py,sha256=EED6LfqhX_OhDRJ9e4GwbpgNC9vq7hoXyJS2VPvG2qc,10789
406
407
  model_compression_toolkit/qat/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
407
- model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=pRy2B5OsaLi33p4hozjr0rzAooT8Gic3_qxTl66J900,13375
408
+ model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=EJ4SPfyD30gyN_HithfITW1NWZ9pOwRvQ2cvDOJP5rQ,13399
408
409
  model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=xYa4C8pr9cG1f3mQQcBXO_u3IdJN-zl7leZxuXDs86w,1003
409
410
  model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py,sha256=2I_WcINn63lpT3mN_skXNL4Rfbm955_wzhYHaiwH2q4,2207
410
411
  model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py,sha256=sFWGu76PZ9dSRf3L0uZI6YwLIs0biBND1tl76I1piBQ,5721
@@ -433,9 +434,9 @@ model_compression_toolkit/target_platform_capabilities/target_platform/targetpla
433
434
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py,sha256=KP8IWlHzkXzVjqIiRtAW6sTYyHJ2wVFFX4hMt_N6o3s,9910
434
435
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities_component.py,sha256=FvrYI0Qy7DCmDp2gyUYyCZq5pY84JgLtJqSIiVTJ8Ss,1030
435
436
  model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
436
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py,sha256=-jCL-meZWFBF-Dp9wBYTX_14SKmyyUJE-BZ2IQDJIAk,3336
437
+ model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py,sha256=CWind2Kd91lzBTRAh1A9sHuNw17xXhMb3gV436RpK8c,3033
437
438
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
438
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py,sha256=KOSrFJAheWk360kU4UKQRVOaM0xIUaVdEdnU6b3t7Ww,5046
439
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py,sha256=Bfj4ek6-Aii_1FC7814cg-TNAG1nRAzQt7_3-jTlbXs,6028
439
440
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py,sha256=F5RG4MnuAwKcNXbfVbPFLQu30-lNax-7knqu20B6udQ,1522
440
441
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/__init__.py,sha256=1mMOREEMoNHu_KTMGDp4crN61opKWX6aFn1DrDLvqcc,717
441
442
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py,sha256=6mbv-fNVz559j5XCSX5e8aENUJACYuJzQcZBLPh12gU,11057
@@ -470,14 +471,14 @@ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/
470
471
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py,sha256=VSPTv6pt6OX8Zpjdit5GK9WattHpKAi4sVByBzTwsgw,6626
471
472
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py,sha256=j4xvBfGdw-wEctv_mlZ_ottxc656uJH9uXRVrZBtNjk,5840
472
473
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
473
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py,sha256=7KVcuz0LfngRKOsfcvBysxGVb9fqgoAO6MVTl1CmB5c,2082
474
+ model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py,sha256=uDsmbGZSPuTXjWGmHhhvXhIC3LmUBwuIDJC_-fuDLfA,2753
474
475
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py,sha256=UUvUCcTots_sehdRnDfgkaE8WPQ7dPbeuhDF4Qy2nzw,1510
475
476
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/__init__.py,sha256=t4JKsPcor-7KSCKzIwuaBv0NLNwfhuewAQGlDl6iBeo,717
476
477
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py,sha256=k1cYUXpVNAvuBVUinSZGu_wDZQvUGAp8e4x9xHBUAOE,8275
477
478
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_keras.py,sha256=h_hePXCggG2qktLuoNAOE1XNtc0qEwMyky7om1c8eC8,4483
478
479
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py,sha256=65WJPRCjliXEUL4AjZRxcyVS3y7KHTMDdkqy6D95kRw,3814
479
480
  model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
480
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py,sha256=Go0RJ1KcKoynCUSwGhxA1nsYsMmZEFSrxiL59iyE6LA,2077
481
+ model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py,sha256=xWA9GNqJrLZfGNAfWQtMQc64z1wUUvDIX3ozprxgwuQ,2749
481
482
  model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py,sha256=sK9PnyB2R9g0rqHr_9vyUFX7wSyrZe7x9yqYUlbaiqo,1505
482
483
  model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/__init__.py,sha256=t4JKsPcor-7KSCKzIwuaBv0NLNwfhuewAQGlDl6iBeo,717
483
484
  model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py,sha256=rxDkISGCxTB2RaVm59zJWxaJKxGgt4uceDgQ_9E_RmI,10033
@@ -501,32 +502,32 @@ model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=hu
501
502
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=MxylaVFPgN7zBiRBy6WV610EA4scLgRJFbMucKvvNDU,2896
502
503
  model_compression_toolkit/xquant/__init__.py,sha256=vdmr8sQw3jIBLF9ck7qrskPoXzDKtksHWlMOkU1JUnQ,1003
503
504
  model_compression_toolkit/xquant/common/__init__.py,sha256=ycb1Xt7PtixY2Uabr94JGSwBMcct66O8ZMVf3Qa3ud8,719
504
- model_compression_toolkit/xquant/common/constants.py,sha256=LRh7q0GtyLTSwOc-XL5yNcPKVq68RvKnORYEC4KK-Ss,1513
505
- model_compression_toolkit/xquant/common/core_report_generator.py,sha256=LQ9QUST9xyvm4B5sp68rjVPnpnxyosn_9jDBcyRciLs,4951
505
+ model_compression_toolkit/xquant/common/constants.py,sha256=k-9LOEv1n_m8dV4chX0dNOTWyhhF7S00E0lkUxtO84E,1592
506
+ model_compression_toolkit/xquant/common/core_report_generator.py,sha256=GHnJJpK6JxArc38vSmKrKj-UJMCmMI7aJAt00Zq_PSc,5403
506
507
  model_compression_toolkit/xquant/common/dataset_utils.py,sha256=91uXF9UwxdY7BvUT0FNkFm8a69c8oK8Xdl-y7lbuJxk,1649
507
- model_compression_toolkit/xquant/common/framework_report_utils.py,sha256=3hzTg5xqdcxHnxmxO8B06o5sW8R-NH1Ixa75U0kie-o,3891
508
+ model_compression_toolkit/xquant/common/framework_report_utils.py,sha256=YE49232ESflW6ZaUABF1pk_GGHBxa_F1X5oRN2Jogys,3734
508
509
  model_compression_toolkit/xquant/common/model_analyzer.py,sha256=T_8OetIQNqR0nkfSatWsEceXSPYpHfYjboBPIyR03-w,3953
509
510
  model_compression_toolkit/xquant/common/model_folding_utils.py,sha256=y5Vmc-hJ2rJhzWdM53HdY-PrT5LlspejTUNlXaCrq9Q,4720
510
511
  model_compression_toolkit/xquant/common/similarity_calculator.py,sha256=yCs_vlOThLzq7z-u2PkcEErLj7N7qCBPpRa6_5h34J8,10460
511
512
  model_compression_toolkit/xquant/common/similarity_functions.py,sha256=Atah1otdX9oUUch2JK-p-e291QHtkP_c4DfLG9WWo1Y,2935
512
- model_compression_toolkit/xquant/common/tensorboard_utils.py,sha256=YWvTvp7DyZDhybLnjte1Em90lev-NAa-hYp445BX-Y4,4473
513
+ model_compression_toolkit/xquant/common/tensorboard_utils.py,sha256=85ABGQGKPZzctyZCHLazK0GxZ2ZUtQA3hZ_9fPiuMs0,6533
513
514
  model_compression_toolkit/xquant/common/xquant_config.py,sha256=Qt56cra2tU1PeHlLx_Cqztf5q-ED8MPelhb8coSumFw,1675
514
515
  model_compression_toolkit/xquant/keras/__init__.py,sha256=zbtceCVRsi-Gvl_pOmq5laqVqu55vAU1ie2FR2RK1Po,709
515
516
  model_compression_toolkit/xquant/keras/dataset_utils.py,sha256=quvVymhvpcPIOneCu5J6K_QAqBHOCIj8IxZxSN2fItA,2258
516
- model_compression_toolkit/xquant/keras/facade_xquant_report.py,sha256=ZBwu1PwBgMbhQK-GvVCmn8CE6a1joKxZPluNNt9RqSw,3375
517
- model_compression_toolkit/xquant/keras/keras_report_utils.py,sha256=Yk-VpyNYi5NWKTVYz-alfLK0JvM9CZDwGXBLu6HNJtI,2987
517
+ model_compression_toolkit/xquant/keras/facade_xquant_report.py,sha256=uf5szQE2QY86It_3VsBrKxYN5fuQddCCpiUBa6u5gFo,3402
518
+ model_compression_toolkit/xquant/keras/keras_report_utils.py,sha256=zUvhqehKKRHEkk6y8g1xQH47b6fTMuPy6stGEZ6mI24,3081
518
519
  model_compression_toolkit/xquant/keras/model_analyzer.py,sha256=WXi9BPI9_TzRWn50lM1i-6cwPPRW0p43Shg_xpHFclU,6521
519
520
  model_compression_toolkit/xquant/keras/similarity_functions.py,sha256=P2qMJAo94Sz_BCao-bnhEeewKtjeLLDDH2r9luDXJ04,2710
520
- model_compression_toolkit/xquant/keras/tensorboard_utils.py,sha256=I1JMlSYe8eoYBpvHmc7H08iC9jdwgAWT4O5c7SMFOfc,4230
521
+ model_compression_toolkit/xquant/keras/tensorboard_utils.py,sha256=635ZcK6_5jdpa7G6Tjq0hkveEYLJQyYRXCFCKL0EioM,9163
521
522
  model_compression_toolkit/xquant/pytorch/__init__.py,sha256=ycb1Xt7PtixY2Uabr94JGSwBMcct66O8ZMVf3Qa3ud8,719
522
523
  model_compression_toolkit/xquant/pytorch/dataset_utils.py,sha256=KFKiFkhIPpEr1ZH5jekZFrgs20VzzKVxSV9YMgH68yI,2894
523
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py,sha256=g5uHlFW9vECkTsrgUs8iohbCCQ4_9tPUcoUv1QZH9uI,3146
524
+ model_compression_toolkit/xquant/pytorch/facade_xquant_report.py,sha256=GGx0YTw_Z22x0IJ_WJmF5W6jWjf10fuy8bwDIaq7KC4,3173
524
525
  model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-ihJBLy5Cic-MQiUM_ZGV6SCXoNdscE,5549
525
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=yrZNVRm2IRU7r7R-hjS2lOQ6wvEEvbeunvf2jKoWjXk,3277
526
+ model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
526
527
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
527
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=eyMoXt5o5EnMr6d-rpCwQdX5mAiYiymvbgKv4tf7-a0,4576
528
- mct_nightly-2.1.0.20240807.445.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
529
- mct_nightly-2.1.0.20240807.445.dist-info/METADATA,sha256=bMkhTL4ymQUwdgNfwOSYLZK8Mt63mh1f5VaKqLqOMuQ,19718
530
- mct_nightly-2.1.0.20240807.445.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
531
- mct_nightly-2.1.0.20240807.445.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
532
- mct_nightly-2.1.0.20240807.445.dist-info/RECORD,,
528
+ model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=yjghWXxqOtT-QXoXBOuJyh45yUpFI0pKjdDegum2i68,9705
529
+ mct_nightly-2.1.0.20240808.431.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
530
+ mct_nightly-2.1.0.20240808.431.dist-info/METADATA,sha256=TLgM4kXLvSU1OF14zuk7f_yLnXKJpufQZEBxyJtWYx8,19718
531
+ mct_nightly-2.1.0.20240808.431.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
532
+ mct_nightly-2.1.0.20240808.431.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
533
+ mct_nightly-2.1.0.20240808.431.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.1.0.20240807.000445"
30
+ __version__ = "2.1.0.20240808.000431"
@@ -130,4 +130,17 @@ GPTQ_HESSIAN_NUM_SAMPLES = 32
130
130
  MP_DEFAULT_NUM_SAMPLES = 32
131
131
 
132
132
  # Pruning constants
133
- PRUNING_NUM_SCORE_APPROXIMATIONS = 32
133
+ PRUNING_NUM_SCORE_APPROXIMATIONS = 32
134
+
135
+ # Scheduling information fields
136
+ OPERATORS_SCHEDULING = 'operators_scheduling'
137
+ MAX_CUT = 'max_cut'
138
+ CUTS = 'cuts'
139
+ FUSED_NODES_MAPPING = 'fused_nodes_mapping'
140
+ OP_ORDER = 'op_order'
141
+ OP_RECORD = 'op_record'
142
+ MEM_ELEMENTS = 'mem_elements'
143
+ SHAPE = 'shape'
144
+ NODE_NAME = 'node_name'
145
+ TOTAL_SIZE = 'total_size'
146
+ NODE_OUTPUT_INDEX = 'node_output_index'
@@ -0,0 +1,135 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Dict, List
17
+
18
+ from model_compression_toolkit.core.common import Graph, BaseNode
19
+ from model_compression_toolkit.core.common.graph.base_graph import OutTensor
20
+
21
+
22
+ class FusedLayerType:
23
+ """
24
+ Used to represent the type of fused layers, since __name__
25
+ is accessed when the graph is displayed.
26
+ """
27
+ def __init__(self):
28
+ self.__name__ = 'FusedLayer'
29
+ class GraphFuser:
30
+
31
+ def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
32
+ """
33
+ GraphFuser is responsible for fusing nodes in a networkx graph.
34
+ The fusion process involves:
35
+ 1. Creating new fused nodes to represent these groups.
36
+ 2. Updating the graph structure to replace the original nodes with fused nodes.
37
+ 3. Maintaining mapping mapping of original node names to their fused node names.
38
+
39
+ Args:
40
+ graph: Graph to sue its nodes.
41
+
42
+ Returns:
43
+ Mapping of original node names to their fused node names
44
+ """
45
+ fused_nodes_mapping = {}
46
+ # Iterate through each group of nodes to be fused
47
+ for fused_nodes_list in graph.fused_nodes:
48
+ new_fused_node = self._create_fused_node(fused_nodes_list)
49
+ self._replace_nodes_with_fused_node(graph, fused_nodes_list, new_fused_node)
50
+ # Update the mapping to keep track of which original nodes are now part of which fused nodes
51
+ for node in fused_nodes_list:
52
+ fused_nodes_mapping[node.name] = new_fused_node.name
53
+ return fused_nodes_mapping
54
+
55
+ def _create_fused_node(self, nodes: List[BaseNode]) -> BaseNode:
56
+ """
57
+ Create a new node that represents the fusion of the given nodes.
58
+
59
+ Args:
60
+ nodes: Nodes to create the fuse node that contain them.
61
+
62
+ Returns:
63
+ Node that represents the nodes to be fused.
64
+ """
65
+ # Create a new node with a name that reflects its components
66
+ # Use the input shape of the first node and output shape of the last node
67
+ fused_node = BaseNode(name='FusedNode_' + '_'.join([node.name for node in nodes]),
68
+ framework_attr={},
69
+ input_shape=nodes[0].input_shape,
70
+ output_shape=nodes[-1].output_shape,
71
+ weights={},
72
+ layer_class=FusedLayerType)
73
+
74
+ # Preserve the final activation quantization configuration
75
+ # This is important for maintaining the correct behavior of the fused node
76
+ fused_node.final_activation_quantization_cfg = nodes[-1].final_activation_quantization_cfg
77
+
78
+ return fused_node
79
+
80
+ def _replace_nodes_with_fused_node(self,
81
+ graph: Graph,
82
+ nodes_to_fuse: List[BaseNode],
83
+ fused_node: BaseNode):
84
+ """
85
+ Replace the specified nodes in the graph with a new fused node.
86
+
87
+ Args:
88
+ graph: Graph to replace the nodes_to_fuse with fused_node
89
+ nodes_to_fuse: List of nodes to replace with a new fused node.
90
+ fused_node: Node to add instead of nodes in fused_node.
91
+
92
+ """
93
+ if not nodes_to_fuse:
94
+ return
95
+
96
+ first_node = nodes_to_fuse[0]
97
+ last_node = nodes_to_fuse[-1]
98
+
99
+ # Update incoming edges: Connect predecessors of the first node to the fused node
100
+ for predecessor in graph.get_prev_nodes(first_node):
101
+ e_attr = graph.get_edge_data(predecessor, first_node)
102
+ graph.add_edge(predecessor, fused_node, **(e_attr[0]))
103
+ graph.remove_edge(predecessor, first_node)
104
+
105
+ # Update outgoing edges: Connect the fused node to successors of the last node
106
+ for successor in graph.get_next_nodes(last_node):
107
+ e_attr = graph.get_edge_data(last_node, successor)
108
+ graph.add_edge(fused_node, successor, **(e_attr[0]))
109
+ graph.remove_edge(last_node, successor)
110
+
111
+ # Remove internal edges between fused nodes
112
+ # This step is necessary to maintain graph consistency
113
+ for current_node in nodes_to_fuse[:-1]:
114
+ subsequent_nodes = graph.get_next_nodes(current_node)
115
+ for next_node in subsequent_nodes:
116
+ assert next_node in nodes_to_fuse # Ensure we're not removing edges outside the fusion
117
+ graph.remove_edge(current_node, next_node)
118
+
119
+ # Handle the case where fused nodes are part of the graph's outputs
120
+ graph_output_tensors = graph.get_outputs()
121
+ graph_output_nodes = [ot.node for ot in graph_output_tensors]
122
+ for node in nodes_to_fuse:
123
+ if node in graph_output_nodes:
124
+ # If a fused node was an output, update the graph's outputs to use the new fused node
125
+ node_to_remove_index = graph_output_nodes.index(node)
126
+ graph_output_tensors[node_to_remove_index] = OutTensor(node=fused_node,
127
+ node_out_index=graph_output_tensors[
128
+ node_to_remove_index].node_out_index)
129
+ graph.remove_node(node, new_graph_outputs=graph_output_tensors)
130
+ else:
131
+ # Remove the original node from the graph
132
+ graph.remove_node(node)
133
+
134
+ # Finally, add the new fused node to the graph
135
+ graph.add_node(fused_node)
@@ -12,13 +12,17 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from collections import namedtuple
16
+
15
17
  from typing import Tuple, List
16
18
 
19
+ from model_compression_toolkit.constants import OPERATORS_SCHEDULING, MAX_CUT, CUTS, FUSED_NODES_MAPPING
17
20
  from model_compression_toolkit.core.common import BaseNode
18
21
  from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
19
22
  from model_compression_toolkit.core.common.graph.memory_graph.max_cut_astar import MaxCutAstar
20
23
  from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph
21
24
 
25
+ SchedulerInfo = namedtuple('SchedulerInfo', [OPERATORS_SCHEDULING, MAX_CUT, CUTS, FUSED_NODES_MAPPING])
22
26
 
23
27
  def compute_graph_max_cut(memory_graph: MemoryGraph,
24
28
  n_iter: int = 50,
@@ -25,7 +25,8 @@ class DebugConfig:
25
25
  """
26
26
  def __init__(self,
27
27
  analyze_similarity: bool = False,
28
- network_editor: List[EditRule] = []):
28
+ network_editor: List[EditRule] = [],
29
+ simulate_scheduler: bool = False):
29
30
  """
30
31
 
31
32
  Args:
@@ -33,6 +34,8 @@ class DebugConfig:
33
34
  analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is
34
35
  enabled) or not. Can be used to pinpoint problematic layers in the quantization process.
35
36
  network_editor (List[EditRule]): A list of rules and actions to edit the network for quantization.
37
+ simulate_scheduler (bool): Simulate scheduler behaviour to compute operators order and cuts.
36
38
  """
37
39
  self.analyze_similarity = analyze_similarity
38
40
  self.network_editor = network_editor
41
+ self.simulate_scheduler = simulate_scheduler
@@ -26,9 +26,11 @@ from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
26
26
  from tensorboard.compat.proto.graph_pb2 import GraphDef
27
27
  from tensorboard.compat.proto.node_def_pb2 import NodeDef
28
28
  from tensorboard.compat.proto.step_stats_pb2 import StepStats, NodeExecStats, DeviceStepStats, AllocatorMemoryUsed
29
- from tensorboard.compat.proto.summary_pb2 import HistogramProto
29
+ from tensorboard.compat.proto.summary_pb2 import HistogramProto, SummaryMetadata
30
30
  from tensorboard.compat.proto.summary_pb2 import Summary
31
+ from tensorboard.compat.proto.tensor_pb2 import TensorProto
31
32
  from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
33
+ from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
32
34
  from tensorboard.summary.writer.event_file_writer import EventFileWriter
33
35
  from typing import List, Any, Dict
34
36
  from networkx import topological_sort
@@ -497,6 +499,32 @@ class TensorboardWriter(object):
497
499
  er.add_event(event)
498
500
  er.flush()
499
501
 
502
+ def add_text(self,
503
+ text: str,
504
+ main_tag_name: str):
505
+ """
506
+ Add a text summary to the TensorBoard log.
507
+
508
+ Args:
509
+ text: The text content to be added to the summary.
510
+ main_tag_name: The name of the tag under which the text will be grouped in TensorBoard.
511
+
512
+ """
513
+ plugin_data = SummaryMetadata.PluginData(
514
+ plugin_name="text", content=TextPluginData(version=0).SerializeToString()
515
+ )
516
+ smd = SummaryMetadata(plugin_data=plugin_data)
517
+ tensor = TensorProto(
518
+ dtype="DT_STRING",
519
+ string_val=[text.encode(encoding="utf_8")],
520
+ tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
521
+ )
522
+ event = Event(summary=Summary(value=[Summary.Value(tag=main_tag_name, metadata=smd, tensor=tensor)]))
523
+
524
+ # Get the event writer for this tag name
525
+ er = self.__get_event_writer_by_tag_name(main_tag_name)
526
+ er.add_event(event)
527
+ er.flush()
500
528
 
501
529
  def init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
502
530
  """
@@ -12,13 +12,20 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from collections import namedtuple
15
16
 
17
+ import copy
16
18
 
17
19
  from typing import Callable, Tuple, Any, List, Dict
18
20
 
19
21
  import numpy as np
20
22
 
21
23
  from model_compression_toolkit.core.common import FrameworkInfo
24
+ from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser
25
+
26
+ from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import compute_graph_max_cut, \
27
+ SchedulerInfo
28
+ from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph
22
29
  from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
23
30
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import \
24
31
  requires_mixed_precision
@@ -174,7 +181,20 @@ def core_runner(in_model: Any,
174
181
  if tb_w is not None:
175
182
  finalize_bitwidth_in_tb(tb_w, weights_conf_nodes_bitwidth, activation_conf_nodes_bitwidth)
176
183
 
177
- return tg, bit_widths_config, hessian_info_service
184
+ scheduler_info = None
185
+ if core_config.debug_config.simulate_scheduler:
186
+ graph_to_fuse = copy.deepcopy(tg)
187
+ fused_nodes_mapping = GraphFuser().create_fused_graph(graph_to_fuse)
188
+ memory_graph = MemoryGraph(graph_to_fuse)
189
+ schedule, max_cut, cuts = compute_graph_max_cut(memory_graph)
190
+ scheduler_info = SchedulerInfo(
191
+ operators_scheduling=schedule,
192
+ max_cut=float(max_cut),
193
+ cuts=cuts,
194
+ fused_nodes_mapping=fused_nodes_mapping
195
+ )
196
+
197
+ return tg, bit_widths_config, hessian_info_service, scheduler_info
178
198
 
179
199
 
180
200
  def _set_final_resource_utilization(graph: Graph,
@@ -31,7 +31,7 @@ from model_compression_toolkit.core.runner import core_runner
31
31
  from model_compression_toolkit.gptq.runner import gptq_runner
32
32
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
33
33
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
34
- from model_compression_toolkit.metadata import get_versions_dict
34
+ from model_compression_toolkit.metadata import get_versions_dict, create_model_metadata
35
35
 
36
36
  LR_DEFAULT = 0.15
37
37
  LR_REST_DEFAULT = 1e-4
@@ -208,15 +208,15 @@ if FOUND_TF:
208
208
 
209
209
  fw_impl = GPTQKerasImplemantation()
210
210
 
211
- tg, bit_widths_config, hessian_info_service = core_runner(in_model=in_model,
212
- representative_data_gen=representative_data_gen,
213
- core_config=core_config,
214
- fw_info=DEFAULT_KERAS_INFO,
215
- fw_impl=fw_impl,
216
- tpc=target_platform_capabilities,
217
- target_resource_utilization=target_resource_utilization,
218
- tb_w=tb_w,
219
- running_gptq=True)
211
+ tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
212
+ representative_data_gen=representative_data_gen,
213
+ core_config=core_config,
214
+ fw_info=DEFAULT_KERAS_INFO,
215
+ fw_impl=fw_impl,
216
+ tpc=target_platform_capabilities,
217
+ target_resource_utilization=target_resource_utilization,
218
+ tb_w=tb_w,
219
+ running_gptq=True)
220
220
 
221
221
  float_graph = copy.deepcopy(tg)
222
222
 
@@ -242,7 +242,9 @@ if FOUND_TF:
242
242
 
243
243
  exportable_model, user_info = get_exportable_keras_model(tg_gptq)
244
244
  if target_platform_capabilities.tp_model.add_metadata:
245
- exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
245
+ exportable_model = add_metadata(exportable_model,
246
+ create_model_metadata(tpc=target_platform_capabilities,
247
+ scheduling_info=scheduling_info))
246
248
  return exportable_model, user_info
247
249
 
248
250
  else:
@@ -31,7 +31,7 @@ from model_compression_toolkit.core.analyzer import analyzer_model_quantization
31
31
  from model_compression_toolkit.core import CoreConfig
32
32
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
33
33
  MixedPrecisionQuantizationConfig
34
- from model_compression_toolkit.metadata import get_versions_dict
34
+ from model_compression_toolkit.metadata import get_versions_dict, create_model_metadata
35
35
 
36
36
  LR_DEFAULT = 1e-4
37
37
  LR_REST_DEFAULT = 1e-4
@@ -177,15 +177,15 @@ if FOUND_TORCH:
177
177
  # ---------------------- #
178
178
  # Core Runner
179
179
  # ---------------------- #
180
- graph, bit_widths_config, hessian_info_service = core_runner(in_model=model,
181
- representative_data_gen=representative_data_gen,
182
- core_config=core_config,
183
- fw_info=DEFAULT_PYTORCH_INFO,
184
- fw_impl=fw_impl,
185
- tpc=target_platform_capabilities,
186
- target_resource_utilization=target_resource_utilization,
187
- tb_w=tb_w,
188
- running_gptq=True)
180
+ graph, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=model,
181
+ representative_data_gen=representative_data_gen,
182
+ core_config=core_config,
183
+ fw_info=DEFAULT_PYTORCH_INFO,
184
+ fw_impl=fw_impl,
185
+ tpc=target_platform_capabilities,
186
+ target_resource_utilization=target_resource_utilization,
187
+ tb_w=tb_w,
188
+ running_gptq=True)
189
189
 
190
190
  float_graph = copy.deepcopy(graph)
191
191
 
@@ -212,7 +212,9 @@ if FOUND_TORCH:
212
212
 
213
213
  exportable_model, user_info = get_exportable_pytorch_model(graph_gptq)
214
214
  if target_platform_capabilities.tp_model.add_metadata:
215
- exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
215
+ exportable_model = add_metadata(exportable_model,
216
+ create_model_metadata(tpc=target_platform_capabilities,
217
+ scheduling_info=scheduling_info))
216
218
  return exportable_model, user_info
217
219
 
218
220