mct-nightly 2.2.0.20250115.152408__py3-none-any.whl → 2.2.0.20250117.527__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. {mct_nightly-2.2.0.20250115.152408.dist-info → mct_nightly-2.2.0.20250117.527.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20250115.152408.dist-info → mct_nightly-2.2.0.20250117.527.dist-info}/RECORD +23 -23
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +12 -2
  5. model_compression_toolkit/core/common/graph/memory_graph/cut.py +2 -2
  6. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +17 -13
  7. model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py +5 -1
  8. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +2 -0
  9. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +2 -2
  10. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -3
  11. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +5 -3
  12. model_compression_toolkit/gptq/keras/quantization_facade.py +5 -2
  13. model_compression_toolkit/gptq/pytorch/quantization_facade.py +4 -2
  14. model_compression_toolkit/pruning/keras/pruning_facade.py +6 -3
  15. model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -3
  16. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -1
  17. model_compression_toolkit/ptq/pytorch/quantization_facade.py +5 -3
  18. model_compression_toolkit/qat/keras/quantization_facade.py +5 -4
  19. model_compression_toolkit/qat/pytorch/quantization_facade.py +7 -4
  20. model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +3 -3
  21. {mct_nightly-2.2.0.20250115.152408.dist-info → mct_nightly-2.2.0.20250117.527.dist-info}/LICENSE.md +0 -0
  22. {mct_nightly-2.2.0.20250115.152408.dist-info → mct_nightly-2.2.0.20250117.527.dist-info}/WHEEL +0 -0
  23. {mct_nightly-2.2.0.20250115.152408.dist-info → mct_nightly-2.2.0.20250117.527.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: mct-nightly
3
- Version: 2.2.0.20250115.152408
3
+ Version: 2.2.0.20250117.527
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: Apache Software License
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=XhsbE4HIyMgv8e1ehZhkkE08uFoS_LDP9ZhVECCHkNM,1557
1
+ model_compression_toolkit/__init__.py,sha256=9wy-eBj_iVmaSe9zp5-Pq8QOeeSZS_srcqrDNVuDsuE,1557
2
2
  model_compression_toolkit/constants.py,sha256=i_R6uXBfO1ph_X6DNJych2x59SUojfJbn7dNjs_mZnc,3846
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -41,11 +41,11 @@ model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-
41
41
  model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=3el-A7j1oyoo1_9zq3faQp7IeRsFXFCvnrb3zZFXpU0,9803
42
42
  model_compression_toolkit/core/common/graph/memory_graph/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
43
43
  model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py,sha256=X6FK3C3y8ixFRPjC_wm3ClloCX8_06SOdA1TRi7o_LA,3800
44
- model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py,sha256=Na1lAoCJCSQw7XGYsV5xCZg762lbP6Y_uAhsUeLP0yM,2870
45
- model_compression_toolkit/core/common/graph/memory_graph/cut.py,sha256=GctdLnhsPJgY6UGwRcLNpKE8OLkfVWT3wgby2r9QDD4,2645
46
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256=ZBFIOBBRHuRsiEW31EMwCVb9J7dJo5XBShA_9nnkrRI,17521
44
+ model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py,sha256=S3m34BY9P8NPx1I4d9G94X1Zk93MobX5SOVmqipwCOE,3458
45
+ model_compression_toolkit/core/common/graph/memory_graph/cut.py,sha256=7Dfq4TVJIrnencHLJqjhxYKhY7ooUo_ml33WH2IIAgc,2576
46
+ model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256=-Gt4MTnQiyljQWtqMmYKtFKvtZBpj5cPH-Mf6n5Iimo,17753
47
47
  model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=ISD2BvJWj5mB91jrFjG8VQb0oOoLBoita_thCZWzCPI,4238
48
- model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=3OC8kMXuzBv-R7wWmKY-i1AQNAr5x3LBZ4aj7hHF-cQ,7791
48
+ model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=FCzK4HmX4lWI4qGoGv94wpGv7o6_f5wPBfeBPMerZ18,7752
49
49
  model_compression_toolkit/core/common/hessian/__init__.py,sha256=E7LK3K_1AwMCQokanNc1JODMwUKNOKmwXQiGQ7GO10I,1033
50
50
  model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=YynbVHdHH2gPlk1QHXH6GygIkXRZ9qxR14cpgKrHPT0,13238
51
51
  model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=1axmN0tjJSo_7hUr2d2KMv4y1pBi19cqWSQpi4BbdsA,1458
@@ -73,10 +73,10 @@ model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,s
73
73
  model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=UWgxzhKWFOoESLq0TFVz0M1PhkU9d9n6wccSA3RgUxk,7903
74
74
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
75
75
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=T5yVr7lay-6QLuTDBZNI1Ufj02EMBWuY_yHjC8eHx5I,3998
76
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=qkYrYORLL5wmdmCkEY3tDSgabsGYt3OaTDVsgHWYBfE,34885
76
+ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=eaZX_Sng1uBpqjKUKuWMQO8wUfnjoQJqEoGwPFD3gsw,35051
77
77
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=tfcbMppa5KP_brfkFWRiOX9LQVHGXJtlgxyAt9oDGuw,8529
78
78
  model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
79
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=ZvLxFIfMUPAyKKzPhJcuZyjjngLD9_1wWFU8e14vEbA,17176
79
+ model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=uhC0az5OVSfeYexcasoy0cT8ZOonFKIedk_1U-ZPLhA,17171
80
80
  model_compression_toolkit/core/common/network_editors/__init__.py,sha256=vZmu55bYqiaOQs3AjfwWDXHmuKZcLHt-wm7uR5fPEqg,1307
81
81
  model_compression_toolkit/core/common/network_editors/actions.py,sha256=nid0_j-Cn10xvmztT8yCKW_6uA7JEnom9SW9syx7wc0,19594
82
82
  model_compression_toolkit/core/common/network_editors/edit_network.py,sha256=dfgawi-nB0ocAJ0xcGn9E-Zv203oUnQLuMiXpX8vTgA,1748
@@ -159,7 +159,7 @@ model_compression_toolkit/core/keras/default_framework_info.py,sha256=IGEHKH3Icm
159
159
  model_compression_toolkit/core/keras/keras_implementation.py,sha256=HwbIR7x4t-TBNbWHVvVNFk8z-KFt6zM0LWAUXQuNZrk,31753
160
160
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
161
161
  model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
162
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=bbp9jn0pyxcVUkfm_356m-hY2IQUWe_QLz8kclDC7SQ,5453
162
+ model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=XBCmUrHy_fNQCfSjnXCpwuEtc7cda4hXySuiIzhFGqc,5696
163
163
  model_compression_toolkit/core/keras/tf_tensor_numpy.py,sha256=jzD8FGEEa8ZD7w8IpTRdp-Udf1MwOTgjg2XTS1Givic,2696
164
164
  model_compression_toolkit/core/keras/back2framework/__init__.py,sha256=rhIiXg_nBgUZ-baE3M6SzCuQbcnq4iebY1jtJBvKHOM,808
165
165
  model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha256=UIQgOOdexycrSKombTMJVvTthR7MlrCihoqM8Kg-rnE,2293
@@ -224,7 +224,7 @@ model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=NLdmiig5
224
224
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
225
225
  model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=7jzJ4TBKNwwQ9E7W-My8LkmYEJHHNn8weNuO1PCGS10,29830
226
226
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
227
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=3whyWyfMIkQYYV-NX6eSyMM2eKpmCnJJ00RqamZouRg,5374
227
+ model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=aIHl-dTAC4ISnWSKLD99c-1W3827vfRGyLjMBib-l3s,5618
228
228
  model_compression_toolkit/core/pytorch/utils.py,sha256=7VbgcLwtQvdEEc_AJgSOQ3U3KRKCICFPaBirN1fIQxg,3940
229
229
  model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN0elVWf3exgGYsi58imPoYDj5eYPeh6x4yfug,813
230
230
  model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,sha256=bwppTPRs6gL96nm7qPiKrNcBj4Krr0yEsOWjRF0aXmQ,2339
@@ -363,7 +363,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
363
363
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=2hzWzsbuVd5XcL85NM57YeOyHxRY0qMArKn8NvQ1UWw,7643
364
364
  model_compression_toolkit/gptq/keras/gptq_training.py,sha256=km9tcuugOkRvprGXQZrsq_GPtA3-7Du_-rnbR_Gyups,23228
365
365
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=zwoeHX67nJJ5-zYLjzvMXS9TLsy9BsizARbZiDVjVSA,4473
366
- model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=_3SM3aKJrSayArnOXVu8F5-XCsVmBzjNYHz9-3qRj4E,18534
366
+ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=onQSR1YPjQ6IZdqzeeqFMs3IeBT-nWLbI0yXuOkdpKs,18827
367
367
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
368
368
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=Rbl9urzkmACvVxICSEyJ02qFOBxWK0UQWtysFJzBVZw,4899
369
369
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -379,7 +379,7 @@ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=_07Zx_43bnNokwR5S8phI
379
379
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
380
380
  model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=WtehnyiYXdUXf8-uNpV0mdsalF7YF7eKnL7tcFrzZoE,19549
381
381
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
382
- model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=JuMzdeAaS2Ak2NdULsJpOoKju_Kv5L690-ftabr6quo,16631
382
+ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=Dkanqdv7Eo5lWRoa56aomU5VdH9yqA6zd8I4WE37hxk,16874
383
383
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
384
384
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
385
385
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
@@ -392,20 +392,20 @@ model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/__init__.py,sha256
392
392
  model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py,sha256=DOlLc4C05TTQN0hZ7xRuqV6wgGp9r2xq7JYun_Hi5jM,8712
393
393
  model_compression_toolkit/pruning/__init__.py,sha256=lQMZS8G0pvR1LVi53nnJHNXgLNTan_MWMdwsVxhjrow,1106
394
394
  model_compression_toolkit/pruning/keras/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
395
- model_compression_toolkit/pruning/keras/pruning_facade.py,sha256=gc00ebAnJEygRETXPxnjfUYE6Ze8zWKVpduhjD0APLs,9072
395
+ model_compression_toolkit/pruning/keras/pruning_facade.py,sha256=-cFNawfLeH0VxYVsauByTvjajt1uiycrkBQ0xcWHQEg,9350
396
396
  model_compression_toolkit/pruning/pytorch/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
397
- model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=dPGN78I8ZQdcCj_R3DB1hszUJmyRvKFEzZehxjZk-Ro,9757
397
+ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=FmUQvT0T247XaLv8Y6AxBv1G3fCgvndmP1RQdiE3pSU,10044
398
398
  model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
399
399
  model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
400
400
  model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
401
- model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=EeigR6O00Ir4X8nB_T3KsKE939Pg2lbQf5S3VA0orPE,11336
401
+ model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=f8sa46eUNHmeaVs3huhZv14DHm5j1X-VInCYdI7nXAY,11567
402
402
  model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
403
- model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=oqXBlHee7L10heMP11WkiFCNgVAc6RqmDa2HZFWGK0U,9771
403
+ model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=p5FwojKaybYdsOUVI7qBNa7R8Nge3EXdu38Jf2jHr84,10021
404
404
  model_compression_toolkit/qat/__init__.py,sha256=AaC4KBha4jDW_tyg2SOxZaKh_idIz0gZtDK3_zxs64E,1241
405
405
  model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
406
406
  model_compression_toolkit/qat/common/qat_config.py,sha256=xtfVSoyELGXynHNrw86dB9FU3Inu0zwehc3wLrh7JvY,2918
407
407
  model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
408
- model_compression_toolkit/qat/keras/quantization_facade.py,sha256=cmFRLBVsyv-fyYzz3YY0y1opztcDvDF08Dj9tbHzWvc,17626
408
+ model_compression_toolkit/qat/keras/quantization_facade.py,sha256=DsFAl4FtnR4QQoztUAMD1FgL6DgcdK5jdTp0lk9MHLY,17793
409
409
  model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
410
410
  model_compression_toolkit/qat/keras/quantizer/base_keras_qat_weight_quantizer.py,sha256=EbIt4lMlh6cU4awFLMBp0IlZ2zUUp-WtnlW5Wn19FDM,1793
411
411
  model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543
@@ -417,7 +417,7 @@ model_compression_toolkit/qat/keras/quantizer/ste_rounding/__init__.py,sha256=cc
417
417
  model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=lXeMPI-n24jbZDGrtOs5eQZ14QvmhFd0e7Y1_QRQxw0,8214
418
418
  model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py,sha256=ZdZwMwLa1Ws2eo3DiQYYTvPS1JfiswZL1xlQPtRnIgE,7067
419
419
  model_compression_toolkit/qat/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
420
- model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=yhMxF3Ah4PRjaTBlmVoRmoCX-pZ0mC9Bq9uslIo6Ud0,13780
420
+ model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=BvKYsLXyWvE3MXN7khYhBQXVLm-r-C17XpJkEwit7KM,14095
421
421
  model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=xYa4C8pr9cG1f3mQQcBXO_u3IdJN-zl7leZxuXDs86w,1003
422
422
  model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_weight_quantizer.py,sha256=gjzrnBAZr5c_OrDpSjxpQYa_jKImv7ll52cng07_2oE,1813
423
423
  model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py,sha256=lM10cGUkkTDtRyLLdWj5Rk0cgvcxp0uaCseyvrnk_Vg,5752
@@ -430,7 +430,7 @@ model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py,sha2
430
430
  model_compression_toolkit/target_platform_capabilities/__init__.py,sha256=9ZcT9JVlYzy8k7MlAXhj086gn6SxlGFsjMvy7ubcnfc,1392
431
431
  model_compression_toolkit/target_platform_capabilities/constants.py,sha256=BFSgDwYWU1sZShjoW2S7eH3AI0D4SqDOeOu_sQ971LE,1518
432
432
  model_compression_toolkit/target_platform_capabilities/immutable.py,sha256=YhROBiXEIB3TU-bAFrnL3qbAsb1yuWPBAQ_CLOJbYUU,1827
433
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=tbDBSSh7sJejDPOfLZ-riGnDfhPqBeIY4ZXqZjZd_eM,4136
433
+ model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=4ydTWWKv_PEOAFok2JtxFNj8rav-0IlqcXKF6lnhHNE,4157
434
434
  model_compression_toolkit/target_platform_capabilities/schema/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
435
435
  model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=PvO8eHxnb3A55gyExT5fZGnOUl3ce7BbbT5SPxCEXNo,541
436
436
  model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=vBkXxVJagm9JKB9cdm4Pvi7u_luriXUjvNn0-m8Zr0k,4653
@@ -523,8 +523,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
523
523
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
524
524
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
525
525
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
526
- mct_nightly-2.2.0.20250115.152408.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
527
- mct_nightly-2.2.0.20250115.152408.dist-info/METADATA,sha256=bRUgbaQMx5oDrEiJyTqJodZ-mKXrnM38lHCVuiWvSxA,26604
528
- mct_nightly-2.2.0.20250115.152408.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
529
- mct_nightly-2.2.0.20250115.152408.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
530
- mct_nightly-2.2.0.20250115.152408.dist-info/RECORD,,
526
+ mct_nightly-2.2.0.20250117.527.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
527
+ mct_nightly-2.2.0.20250117.527.dist-info/METADATA,sha256=ywy9ErTqUzchvEY5i9iwgRiAi2lr186UmPlZ59OADA4,26601
528
+ mct_nightly-2.2.0.20250117.527.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
529
+ mct_nightly-2.2.0.20250117.527.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
530
+ mct_nightly-2.2.0.20250117.527.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.2.0.20250115.152408"
30
+ __version__ = "2.2.0.20250117.000527"
@@ -13,9 +13,9 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from collections import namedtuple
16
-
17
16
  from typing import Tuple, List
18
17
 
18
+ from model_compression_toolkit.logger import Logger
19
19
  from model_compression_toolkit.constants import OPERATORS_SCHEDULING, MAX_CUT, CUTS, FUSED_NODES_MAPPING
20
20
  from model_compression_toolkit.core.common import BaseNode
21
21
  from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
@@ -49,7 +49,17 @@ def compute_graph_max_cut(memory_graph: MemoryGraph,
49
49
  it = 0
50
50
  while it < n_iter:
51
51
  estimate = (u_bound + l_bound) / 2
52
- schedule, max_cut_size, cuts = max_cut_astar.solve(estimate=estimate, iter_limit=astar_n_iter)
52
+ # Add a timeout of 5 minutes to the solver from the 2nd iteration.
53
+ try:
54
+ schedule, max_cut_size, cuts = max_cut_astar.solve(estimate=estimate, iter_limit=astar_n_iter,
55
+ time_limit=None if it == 0 else 300)
56
+ except TimeoutError:
57
+ if last_result[0] is None:
58
+ Logger.critical(f"Max-cut solver stopped on timeout in iteration {it} before finding a solution.") # pragma: no cover
59
+ else:
60
+ Logger.warning(f"Max-cut solver stopped on timeout in iteration {it}.")
61
+ return last_result
62
+
53
63
  if schedule is None:
54
64
  l_bound = estimate
55
65
  else:
@@ -67,7 +67,7 @@ class Cut:
67
67
  return False # pragma: no cover
68
68
 
69
69
  def __hash__(self):
70
- return hash((frozenset(self.op_order), frozenset(self.op_record), self.mem_elements))
70
+ return id(self)
71
71
 
72
72
  def __repr__(self):
73
- return f"<Cut: Nodes={[e.node_name for e in self.mem_elements.elements]}, size={self.memory_size()}>" # pragma: no cover
73
+ return f"<Cut: Nodes={[e.node_name for e in self.mem_elements.elements]}, size={self.memory_size()}>" # pragma: no cover
@@ -13,7 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  import copy
16
- from typing import List, Tuple, Dict
16
+ from typing import List, Tuple, Dict, Set
17
+ from time import time
17
18
 
18
19
  from model_compression_toolkit.core.common import BaseNode
19
20
  from model_compression_toolkit.constants import DUMMY_TENSOR, DUMMY_NODE
@@ -122,7 +123,7 @@ class MaxCutAstar:
122
123
  self.target_cut = Cut([], set(), MemoryElements(elements={target_dummy_b, target_dummy_b2},
123
124
  total_size=0))
124
125
 
125
- def solve(self, estimate: float, iter_limit: int = 500) -> Tuple[List[BaseNode], float, List[Cut]]:
126
+ def solve(self, estimate: float, iter_limit: int = 500, time_limit: int = None) -> Tuple[List[BaseNode], float, List[Cut]]:
126
127
  """
127
128
  The AStar solver function. This method runs an AStar-like search on the memory graph,
128
129
  using the given estimate as a heuristic gap for solutions to consider.
@@ -131,6 +132,7 @@ class MaxCutAstar:
131
132
  estimate: Cut size estimation to consider larger size of nodes in each
132
133
  expansion step, in order to fasten the algorithm divergence towards a solution.
133
134
  iter_limit: An upper limit for the number of expansion steps that the algorithm preforms.
135
+ time_limit: Optional time limit to the solver. Defaults to None which means no limit.
134
136
 
135
137
  Returns: A solution (if found within the steps limit) which contains:
136
138
  - A schedule for computation of the model (List of nodes).
@@ -139,14 +141,17 @@ class MaxCutAstar:
139
141
 
140
142
  """
141
143
 
142
- open_list = [self.src_cut]
143
- closed_list = []
144
+ open_list = {self.src_cut}
145
+ closed_list = set()
144
146
  costs = {self.src_cut: self.src_cut.memory_size()}
145
147
  routes = {self.src_cut: [self.src_cut]}
146
148
 
147
149
  expansion_count = 0
148
150
 
151
+ t1 = time()
149
152
  while expansion_count < iter_limit and len(open_list) > 0:
153
+ if time_limit is not None and time() - t1 > time_limit:
154
+ raise TimeoutError
150
155
  # Choose next node to expand
151
156
  next_cut = self._get_cut_to_expand(open_list, costs, routes, estimate)
152
157
 
@@ -159,22 +164,21 @@ class MaxCutAstar:
159
164
 
160
165
  if self.is_pivot(next_cut):
161
166
  # Can clear all search history
162
- open_list = []
163
- closed_list = []
167
+ open_list.clear()
168
+ closed_list.clear()
164
169
  routes = {}
165
170
  else:
166
171
  # Can remove only next_cut and put it in closed_list
167
172
  open_list.remove(next_cut)
168
173
  del routes[next_cut]
169
- closed_list.append(next_cut)
174
+ closed_list.add(next_cut)
170
175
 
171
176
  # Expand the chosen cut
172
177
  expanded_cuts = self.expand(next_cut)
173
178
  expansion_count += 1
174
179
 
175
180
  # Only consider nodes that where not already visited
176
- expanded_cuts = [_c for _c in expanded_cuts if _c not in closed_list]
177
- for c in expanded_cuts:
181
+ for c in filter(lambda _c: _c not in closed_list, expanded_cuts):
178
182
  cost = self.accumulate(cut_cost, c.memory_size())
179
183
  if c not in open_list:
180
184
  self._update_expanded_node(c, cost, cut_route, open_list, costs, routes)
@@ -192,7 +196,7 @@ class MaxCutAstar:
192
196
  return None, 0, None # pragma: no cover
193
197
 
194
198
  @staticmethod
195
- def _update_expanded_node(cut: Cut, cost: float, route: List[Cut], open_list: List[Cut],
199
+ def _update_expanded_node(cut: Cut, cost: float, route: List[Cut], open_list: Set[Cut],
196
200
  costs: Dict[Cut, float], routes: Dict[Cut, List[Cut]]):
197
201
  """
198
202
  An auxiliary method for updating search data structures according to an expanded node.
@@ -201,16 +205,16 @@ class MaxCutAstar:
201
205
  cut: A cut to expand the search to.
202
206
  cost: The cost of the cut.
203
207
  route: The rout to the cut.
204
- open_list: The search open list.
208
+ open_list: The search open set.
205
209
  costs: The search utility mapping between cuts and their cost.
206
210
  routes: The search utility mapping between cuts and their routes.
207
211
 
208
212
  """
209
- open_list.append(cut)
213
+ open_list.add(cut)
210
214
  costs.update({cut: cost})
211
215
  routes.update({cut: [cut] + route})
212
216
 
213
- def _get_cut_to_expand(self, open_list: List[Cut], costs: Dict[Cut, float], routes: Dict[Cut, List[Cut]],
217
+ def _get_cut_to_expand(self, open_list: Set[Cut], costs: Dict[Cut, float], routes: Dict[Cut, List[Cut]],
214
218
  estimate: float) -> Cut:
215
219
  """
216
220
  An auxiliary method for finding a cut for expanding the search out of a set of potential cuts for expansion.
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  from typing import List
16
16
  from operator import getitem
17
+ from functools import cache
17
18
 
18
19
  from model_compression_toolkit.core.common import Graph, BaseNode
19
20
  from model_compression_toolkit.core.common.graph.edge import EDGE_SOURCE_INDEX
@@ -82,7 +83,6 @@ class MemoryGraph(DirectedBipartiteGraph):
82
83
  inputs_tensors_memory = [sum([t.total_size for t in self.operation_node_children(n)])
83
84
  for n in nodes if n in model_graph.get_inputs()]
84
85
 
85
- # TODO maxcut: why both inputs and outputs of each nodes, while the A* solves for node outputs only???
86
86
  nodes_total_memory = [sum([t.total_size for t in self.operation_node_children(n)] +
87
87
  [t.total_size for t in self.operation_node_parents(n)])
88
88
  for n in nodes if n not in model_graph.get_inputs()]
@@ -117,6 +117,7 @@ class MemoryGraph(DirectedBipartiteGraph):
117
117
  """
118
118
  self.sinks_b = [n for n in self.b_nodes if len(list(self.successors(n))) == 0]
119
119
 
120
+ @cache
120
121
  def activation_tensor_children(self, activation_tensor: ActivationMemoryTensor) -> List[BaseNode]:
121
122
  """
122
123
  Returns the children nodes of a side B node (activation tensor) in the bipartite graph.
@@ -129,6 +130,7 @@ class MemoryGraph(DirectedBipartiteGraph):
129
130
  """
130
131
  return [oe[1] for oe in self.out_edges(activation_tensor)]
131
132
 
133
+ @cache
132
134
  def activation_tensor_parents(self, activation_tensor: ActivationMemoryTensor) -> List[BaseNode]:
133
135
  """
134
136
  Returns the parents nodes of a side B node (activation tensor) in the bipartite graph.
@@ -141,6 +143,7 @@ class MemoryGraph(DirectedBipartiteGraph):
141
143
  """
142
144
  return [ie[0] for ie in self.in_edges(activation_tensor)]
143
145
 
146
+ @cache
144
147
  def operation_node_children(self, op_node: BaseNode) -> List[ActivationMemoryTensor]:
145
148
  """
146
149
  Returns the children nodes of a side A node (operation) in the bipartite graph.
@@ -153,6 +156,7 @@ class MemoryGraph(DirectedBipartiteGraph):
153
156
  """
154
157
  return [oe[1] for oe in self.out_edges(op_node)]
155
158
 
159
+ @cache
156
160
  def operation_node_parents(self, op_node: BaseNode) -> List[ActivationMemoryTensor]:
157
161
  """
158
162
  Returns the parents nodes of a side A node (operation) in the bipartite graph.
@@ -17,6 +17,7 @@ from copy import deepcopy
17
17
  from enum import Enum, auto
18
18
  from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal, Sequence, Set
19
19
 
20
+ from model_compression_toolkit.logger import Logger
20
21
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
21
22
  from model_compression_toolkit.core import FrameworkInfo
22
23
  from model_compression_toolkit.core.common import Graph, BaseNode
@@ -169,6 +170,7 @@ class ResourceUtilizationCalculator:
169
170
  w_total, *_ = self.compute_weights_utilization(target_criterion, bitwidth_mode, w_qcs)
170
171
 
171
172
  if {RUTarget.ACTIVATION, RUTarget.TOTAL}.intersection(ru_targets):
173
+ Logger.warning("Using an experimental feature max-cut for activation memory utilization estimation.")
172
174
  a_total = self.compute_activations_utilization(target_criterion, bitwidth_mode, act_qcs)
173
175
 
174
176
  ru = ResourceUtilization()
@@ -16,7 +16,7 @@
16
16
  import numpy as np
17
17
  from pulp import *
18
18
  from tqdm import tqdm
19
- from typing import Dict, Tuple, Set, Any
19
+ from typing import Dict, Tuple, Any, Optional
20
20
 
21
21
  from model_compression_toolkit.logger import Logger
22
22
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget
@@ -182,7 +182,7 @@ def _add_ru_constraints(search_manager: MixedPrecisionSearchManager,
182
182
  target_resource_utilization: ResourceUtilization,
183
183
  indicators_matrix: np.ndarray,
184
184
  lp_problem: LpProblem,
185
- non_conf_ru_dict: Optional[Dict[RUTarget, np.ndarray]]):
185
+ non_conf_ru_dict: Dict[RUTarget, np.ndarray]):
186
186
  """
187
187
  Adding targets constraints for the Lp problem for the given target resource utilization.
188
188
  The update to the Lp problem object is done inplace.
@@ -13,13 +13,14 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Callable
16
+ from typing import Callable, Union
17
17
  from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, CoreConfig
18
18
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
19
19
  from model_compression_toolkit.logger import Logger
20
20
  from model_compression_toolkit.constants import TENSORFLOW
21
21
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
22
22
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
23
+ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
23
24
  from model_compression_toolkit.verify_packages import FOUND_TF
24
25
 
25
26
  if FOUND_TF:
@@ -38,7 +39,7 @@ if FOUND_TF:
38
39
  representative_data_gen: Callable,
39
40
  core_config: CoreConfig = CoreConfig(
40
41
  mixed_precision_config=MixedPrecisionQuantizationConfig()),
41
- target_platform_capabilities: TargetPlatformCapabilities = KERAS_DEFAULT_TPC
42
+ target_platform_capabilities: Union[TargetPlatformCapabilities, str] = KERAS_DEFAULT_TPC
42
43
  ) -> ResourceUtilization:
43
44
  """
44
45
  Computes resource utilization data that can be used to calculate the desired target resource utilization
@@ -50,7 +51,7 @@ if FOUND_TF:
50
51
  in_model (Model): Keras model to quantize.
51
52
  representative_data_gen (Callable): Dataset used for calibration.
52
53
  core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision of how the model should be quantized.
53
- target_platform_capabilities (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to optimize the Keras model according to.
54
+ target_platform_capabilities (Union[TargetPlatformCapabilities, str]): FrameworkQuantizationCapabilities to optimize the Keras model according to.
54
55
 
55
56
  Returns:
56
57
 
@@ -81,6 +82,7 @@ if FOUND_TF:
81
82
 
82
83
  fw_impl = KerasImplementation()
83
84
 
85
+ target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
84
86
  # Attach tpc model to framework
85
87
  attach2keras = AttachTpcToKeras()
86
88
  target_platform_capabilities = attach2keras.attach(
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Callable
16
+ from typing import Callable, Union
17
17
 
18
18
  from model_compression_toolkit.logger import Logger
19
19
  from model_compression_toolkit.constants import PYTORCH
@@ -23,6 +23,7 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
23
23
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
24
24
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
25
25
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
26
+ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
26
27
  from model_compression_toolkit.verify_packages import FOUND_TORCH
27
28
 
28
29
  if FOUND_TORCH:
@@ -40,7 +41,7 @@ if FOUND_TORCH:
40
41
  def pytorch_resource_utilization_data(in_model: Module,
41
42
  representative_data_gen: Callable,
42
43
  core_config: CoreConfig = CoreConfig(),
43
- target_platform_capabilities: TargetPlatformCapabilities= PYTORCH_DEFAULT_TPC
44
+ target_platform_capabilities: Union[TargetPlatformCapabilities, str] = PYTORCH_DEFAULT_TPC
44
45
  ) -> ResourceUtilization:
45
46
  """
46
47
  Computes resource utilization data that can be used to calculate the desired target resource utilization for mixed-precision quantization.
@@ -50,7 +51,7 @@ if FOUND_TORCH:
50
51
  in_model (Model): PyTorch model to quantize.
51
52
  representative_data_gen (Callable): Dataset used for calibration.
52
53
  core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision
53
- target_platform_capabilities (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to optimize the PyTorch model according to.
54
+ target_platform_capabilities (Union[TargetPlatformCapabilities, str]): FrameworkQuantizationCapabilities to optimize the PyTorch model according to.
54
55
 
55
56
  Returns:
56
57
 
@@ -81,6 +82,7 @@ if FOUND_TORCH:
81
82
 
82
83
  fw_impl = PytorchImplementation()
83
84
 
85
+ target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
84
86
  # Attach tpc model to framework
85
87
  attach2pytorch = AttachTpcToPytorch()
86
88
  target_platform_capabilities = (
@@ -25,6 +25,7 @@ from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_
25
25
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
26
26
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
27
27
  AttachTpcToKeras
28
+ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
28
29
  from model_compression_toolkit.verify_packages import FOUND_TF
29
30
  from model_compression_toolkit.core.common.user_info import UserInformation
30
31
  from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig, \
@@ -156,7 +157,8 @@ if FOUND_TF:
156
157
  gptq_representative_data_gen: Callable = None,
157
158
  target_resource_utilization: ResourceUtilization = None,
158
159
  core_config: CoreConfig = CoreConfig(),
159
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
160
+ target_platform_capabilities: Union[TargetPlatformCapabilities, str]
161
+ = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
160
162
  """
161
163
  Quantize a trained Keras model using post-training quantization. The model is quantized using a
162
164
  symmetric constraint quantization thresholds (power of two).
@@ -180,7 +182,7 @@ if FOUND_TF:
180
182
  gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
181
183
  target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired.
182
184
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
183
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
185
+ target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the Keras model according to.
184
186
 
185
187
  Returns:
186
188
 
@@ -241,6 +243,7 @@ if FOUND_TF:
241
243
 
242
244
  fw_impl = GPTQKerasImplemantation()
243
245
 
246
+ target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
244
247
  # Attach tpc model to framework
245
248
  attach2keras = AttachTpcToKeras()
246
249
  framework_platform_capabilities = attach2keras.attach(
@@ -32,6 +32,7 @@ from model_compression_toolkit.gptq.runner import gptq_runner
32
32
  from model_compression_toolkit.logger import Logger
33
33
  from model_compression_toolkit.metadata import create_model_metadata
34
34
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
35
+ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
35
36
  from model_compression_toolkit.verify_packages import FOUND_TORCH
36
37
 
37
38
 
@@ -145,7 +146,7 @@ if FOUND_TORCH:
145
146
  core_config: CoreConfig = CoreConfig(),
146
147
  gptq_config: GradientPTQConfig = None,
147
148
  gptq_representative_data_gen: Callable = None,
148
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
149
+ target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_PYTORCH_TPC):
149
150
  """
150
151
  Quantize a trained Pytorch module using post-training quantization.
151
152
  By default, the module is quantized using a symmetric constraint quantization thresholds
@@ -169,7 +170,7 @@ if FOUND_TORCH:
169
170
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
170
171
  gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
171
172
  gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
172
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
173
+ target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the PyTorch model according to.
173
174
 
174
175
  Returns:
175
176
  A quantized module and information the user may need to handle the quantized module.
@@ -214,6 +215,7 @@ if FOUND_TORCH:
214
215
 
215
216
  fw_impl = GPTQPytorchImplemantation()
216
217
 
218
+ target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
217
219
  # Attach tpc model to framework
218
220
  attach2pytorch = AttachTpcToPytorch()
219
221
  framework_quantization_capabilities = attach2pytorch.attach(target_platform_capabilities,
@@ -13,11 +13,12 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Callable, Tuple
16
+ from typing import Callable, Tuple, Union
17
17
 
18
18
  from model_compression_toolkit import get_target_platform_capabilities
19
19
  from model_compression_toolkit.constants import TENSORFLOW
20
20
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
21
+ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
21
22
  from model_compression_toolkit.verify_packages import FOUND_TF
22
23
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
23
24
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
@@ -43,7 +44,8 @@ if FOUND_TF:
43
44
  target_resource_utilization: ResourceUtilization,
44
45
  representative_data_gen: Callable,
45
46
  pruning_config: PruningConfig = PruningConfig(),
46
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
47
+ target_platform_capabilities: Union[TargetPlatformCapabilities, str]
48
+ = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
47
49
  """
48
50
  Perform structured pruning on a Keras model to meet a specified target resource utilization.
49
51
  This function prunes the provided model according to the target resource utilization by grouping and pruning
@@ -61,7 +63,7 @@ if FOUND_TF:
61
63
  target_resource_utilization (ResourceUtilization): The target Key Performance Indicators to be achieved through pruning.
62
64
  representative_data_gen (Callable): A function to generate representative data for pruning analysis.
63
65
  pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
64
- target_platform_capabilities (FrameworkQuantizationCapabilities): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC.
66
+ target_platform_capabilities (Union[TargetPlatformCapabilities, str]): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC.
65
67
 
66
68
  Returns:
67
69
  Tuple[Model, PruningInfo]: A tuple containing the pruned Keras model and associated pruning information.
@@ -112,6 +114,7 @@ if FOUND_TF:
112
114
  # Instantiate the Keras framework implementation.
113
115
  fw_impl = PruningKerasImplementation()
114
116
 
117
+ target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
115
118
  # Attach tpc model to framework
116
119
  attach2keras = AttachTpcToKeras()
117
120
  target_platform_capabilities = attach2keras.attach(target_platform_capabilities)
@@ -13,10 +13,11 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Callable, Tuple
16
+ from typing import Callable, Tuple, Union
17
17
  from model_compression_toolkit import get_target_platform_capabilities
18
18
  from model_compression_toolkit.constants import PYTORCH
19
19
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
20
+ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
20
21
  from model_compression_toolkit.verify_packages import FOUND_TORCH
21
22
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
22
23
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
@@ -47,7 +48,8 @@ if FOUND_TORCH:
47
48
  target_resource_utilization: ResourceUtilization,
48
49
  representative_data_gen: Callable,
49
50
  pruning_config: PruningConfig = PruningConfig(),
50
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \
51
+ target_platform_capabilities: Union[TargetPlatformCapabilities, str]
52
+ = DEFAULT_PYOTRCH_TPC) -> \
51
53
  Tuple[Module, PruningInfo]:
52
54
  """
53
55
  Perform structured pruning on a Pytorch model to meet a specified target resource utilization.
@@ -66,7 +68,7 @@ if FOUND_TORCH:
66
68
  target_resource_utilization (ResourceUtilization): Key Performance Indicators specifying the pruning targets.
67
69
  representative_data_gen (Callable): A function to generate representative data for pruning analysis.
68
70
  pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
69
- target_platform_capabilities (TargetPlatformCapabilities): Platform-specific constraints and capabilities.
71
+ target_platform_capabilities (Union[TargetPlatformCapabilities, str]): Platform-specific constraints and capabilities.
70
72
  Defaults to DEFAULT_PYTORCH_TPC.
71
73
 
72
74
  Returns:
@@ -118,6 +120,7 @@ if FOUND_TORCH:
118
120
  # Instantiate the Pytorch framework implementation.
119
121
  fw_impl = PruningPytorchImplementation()
120
122
 
123
+ target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
121
124
  # Attach TPC to framework
122
125
  attach2pytorch = AttachTpcToPytorch()
123
126
  framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities)
@@ -23,6 +23,7 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
23
23
  from model_compression_toolkit.logger import Logger
24
24
  from model_compression_toolkit.constants import TENSORFLOW
25
25
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
26
+ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
26
27
  from model_compression_toolkit.verify_packages import FOUND_TF
27
28
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
28
29
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
@@ -70,7 +71,7 @@ if FOUND_TF:
70
71
  representative_data_gen (Callable): Dataset used for calibration.
71
72
  target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired.
72
73
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
73
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
74
+ target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the Keras model according to.
74
75
 
75
76
  Returns:
76
77
 
@@ -137,6 +138,7 @@ if FOUND_TF:
137
138
 
138
139
  fw_impl = KerasImplementation()
139
140
 
141
+ target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
140
142
  attach2keras = AttachTpcToKeras()
141
143
  framework_platform_capabilities = attach2keras.attach(
142
144
  target_platform_capabilities,
@@ -14,12 +14,13 @@
14
14
  # ==============================================================================
15
15
  import copy
16
16
 
17
- from typing import Callable
17
+ from typing import Callable, Union
18
18
 
19
19
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
20
20
  from model_compression_toolkit.logger import Logger
21
21
  from model_compression_toolkit.constants import PYTORCH
22
22
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
23
+ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
23
24
  from model_compression_toolkit.verify_packages import FOUND_TORCH
24
25
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
25
26
  from model_compression_toolkit.core import CoreConfig
@@ -48,7 +49,7 @@ if FOUND_TORCH:
48
49
  representative_data_gen: Callable,
49
50
  target_resource_utilization: ResourceUtilization = None,
50
51
  core_config: CoreConfig = CoreConfig(),
51
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
52
+ target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_PYTORCH_TPC):
52
53
  """
53
54
  Quantize a trained Pytorch module using post-training quantization.
54
55
  By default, the module is quantized using a symmetric constraint quantization thresholds
@@ -67,7 +68,7 @@ if FOUND_TORCH:
67
68
  representative_data_gen (Callable): Dataset used for calibration.
68
69
  target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired.
69
70
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
70
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
71
+ target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the PyTorch model according to.
71
72
 
72
73
  Returns:
73
74
  A quantized module and information the user may need to handle the quantized module.
@@ -109,6 +110,7 @@ if FOUND_TORCH:
109
110
 
110
111
  fw_impl = PytorchImplementation()
111
112
 
113
+ target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
112
114
  # Attach tpc model to framework
113
115
  attach2pytorch = AttachTpcToPytorch()
114
116
  framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Callable
16
+ from typing import Callable, Union
17
17
  from functools import partial
18
18
 
19
19
  from model_compression_toolkit.core import CoreConfig
@@ -22,6 +22,7 @@ from model_compression_toolkit.logger import Logger
22
22
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
23
23
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
24
24
  AttachTpcToKeras
25
+ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
25
26
  from model_compression_toolkit.verify_packages import FOUND_TF
26
27
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
27
28
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
@@ -49,7 +50,6 @@ if FOUND_TF:
49
50
  from model_compression_toolkit.core import common
50
51
  from model_compression_toolkit.core.common import BaseNode
51
52
  from model_compression_toolkit.constants import TENSORFLOW
52
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
53
53
  from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
54
54
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
55
55
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
@@ -92,7 +92,7 @@ if FOUND_TF:
92
92
  target_resource_utilization: ResourceUtilization = None,
93
93
  core_config: CoreConfig = CoreConfig(),
94
94
  qat_config: QATConfig = QATConfig(),
95
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
95
+ target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_KERAS_TPC):
96
96
  """
97
97
  Prepare a trained Keras model for quantization aware training. First the model quantization is optimized
98
98
  with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -114,7 +114,7 @@ if FOUND_TF:
114
114
  target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired.
115
115
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
116
116
  qat_config (QATConfig): QAT configuration
117
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
117
+ target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the Keras model according to.
118
118
 
119
119
  Returns:
120
120
 
@@ -188,6 +188,7 @@ if FOUND_TF:
188
188
 
189
189
  fw_impl = KerasImplementation()
190
190
 
191
+ target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
191
192
  attach2keras = AttachTpcToKeras()
192
193
  target_platform_capabilities = attach2keras.attach(
193
194
  target_platform_capabilities,
@@ -12,13 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Callable
15
+ from typing import Callable, Union
16
16
  from functools import partial
17
17
 
18
18
  from model_compression_toolkit.constants import PYTORCH
19
19
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
20
20
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
21
21
  AttachTpcToPytorch
22
+ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
22
23
  from model_compression_toolkit.verify_packages import FOUND_TORCH
23
24
 
24
25
  from model_compression_toolkit.core import CoreConfig
@@ -78,7 +79,8 @@ if FOUND_TORCH:
78
79
  target_resource_utilization: ResourceUtilization = None,
79
80
  core_config: CoreConfig = CoreConfig(),
80
81
  qat_config: QATConfig = QATConfig(),
81
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
82
+ target_platform_capabilities: Union[TargetPlatformCapabilities, str]
83
+ = DEFAULT_PYTORCH_TPC):
82
84
  """
83
85
  Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized
84
86
  with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -100,7 +102,7 @@ if FOUND_TORCH:
100
102
  target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired.
101
103
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
102
104
  qat_config (QATConfig): QAT configuration
103
- target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Pytorch model according to.
105
+ target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the Pytorch model according to.
104
106
 
105
107
  Returns:
106
108
 
@@ -153,10 +155,11 @@ if FOUND_TORCH:
153
155
  tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
154
156
  fw_impl = PytorchImplementation()
155
157
 
158
+ target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
156
159
  # Attach tpc model to framework
157
160
  attach2pytorch = AttachTpcToPytorch()
158
161
  framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
159
- core_config.quantization_config.custom_tpc_opset_to_layer)
162
+ core_config.quantization_config.custom_tpc_opset_to_layer)
160
163
 
161
164
  # Ignore hessian scores service as we do not use it here
162
165
  tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
@@ -20,13 +20,13 @@ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_s
20
20
  import json
21
21
 
22
22
 
23
- def load_target_platform_model(tpc_obj_or_path: Union[TargetPlatformCapabilities, str]) -> TargetPlatformCapabilities:
23
+ def load_target_platform_capabilities(tpc_obj_or_path: Union[TargetPlatformCapabilities, str]) -> TargetPlatformCapabilities:
24
24
  """
25
25
  Parses the tpc input, which can be either a TargetPlatformCapabilities object
26
26
  or a string path to a JSON file.
27
27
 
28
28
  Parameters:
29
- tpc_obj_or_path (Union[TargetPlatformModel, str]): Input target platform model or path to .JSON file.
29
+ tpc_obj_or_path (Union[TargetPlatformCapabilities, str]): Input target platform model or path to .JSON file.
30
30
 
31
31
  Returns:
32
32
  TargetPlatformCapabilities: The parsed TargetPlatformCapabilities.
@@ -66,7 +66,7 @@ def load_target_platform_model(tpc_obj_or_path: Union[TargetPlatformCapabilities
66
66
  )
67
67
 
68
68
 
69
- def export_target_platform_model(model: TargetPlatformCapabilities, export_path: Union[str, Path]) -> None:
69
+ def export_target_platform_capabilities(model: TargetPlatformCapabilities, export_path: Union[str, Path]) -> None:
70
70
  """
71
71
  Exports a TargetPlatformCapabilities instance to a JSON file.
72
72