mct-nightly 2.2.0.20241022.507__py3-none-any.whl → 2.2.0.20241024.501__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/RECORD +38 -31
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +43 -29
  5. model_compression_toolkit/core/common/hessian/__init__.py +1 -1
  6. model_compression_toolkit/core/common/hessian/hessian_info_service.py +222 -371
  7. model_compression_toolkit/core/common/hessian/hessian_scores_request.py +27 -41
  8. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +8 -10
  9. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +11 -9
  10. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +10 -6
  11. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +17 -15
  12. model_compression_toolkit/core/keras/data_util.py +67 -0
  13. model_compression_toolkit/core/keras/keras_implementation.py +7 -1
  14. model_compression_toolkit/core/keras/tf_tensor_numpy.py +1 -1
  15. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  16. model_compression_toolkit/core/pytorch/data_util.py +163 -0
  17. model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +6 -31
  18. model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py +11 -21
  19. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +9 -7
  20. model_compression_toolkit/core/pytorch/pytorch_implementation.py +8 -2
  21. model_compression_toolkit/core/pytorch/utils.py +22 -19
  22. model_compression_toolkit/core/quantization_prep_runner.py +2 -1
  23. model_compression_toolkit/core/runner.py +1 -2
  24. model_compression_toolkit/gptq/common/gptq_config.py +0 -2
  25. model_compression_toolkit/gptq/common/gptq_training.py +58 -114
  26. model_compression_toolkit/gptq/keras/gptq_training.py +15 -6
  27. model_compression_toolkit/gptq/pytorch/gptq_loss.py +3 -2
  28. model_compression_toolkit/gptq/pytorch/gptq_training.py +97 -64
  29. model_compression_toolkit/gptq/pytorch/quantization_facade.py +0 -2
  30. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +4 -3
  31. tests_pytest/keras/__init__.py +14 -0
  32. tests_pytest/keras/core/__init__.py +14 -0
  33. tests_pytest/keras/core/test_data_util.py +91 -0
  34. tests_pytest/pytorch/core/__init__.py +14 -0
  35. tests_pytest/pytorch/core/test_data_util.py +125 -0
  36. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/LICENSE.md +0 -0
  37. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/WHEEL +0 -0
  38. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.2.0.20241022.507
3
+ Version: 2.2.0.20241024.501
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=0HVjlkuiw1WkOHhfE-m3X9oF9OXrUyMhp0klnmfPEDM,1573
1
+ model_compression_toolkit/__init__.py,sha256=1QmmvVxsZd2Hv0uinv8FVjWpnz6neE89qwpot6MfqqY,1573
2
2
  model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -7,11 +7,11 @@ model_compression_toolkit/verify_packages.py,sha256=TlS-K1EP-QsghqWUW7SDPkAJiUf7
7
7
  model_compression_toolkit/core/__init__.py,sha256=tnDtL9KmT0vsOU27SsJ19TKDEbIH-tXYeGxTo5YnNUM,2077
8
8
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
9
9
  model_compression_toolkit/core/graph_prep_runner.py,sha256=7-b7Jd5jBVaXOWg5nSqbEyzBtdaGDbCxs8aqMV6GZ6I,11287
10
- model_compression_toolkit/core/quantization_prep_runner.py,sha256=K9eJ7VbB_rpeyxX4yEnorOmSxFW3DkvofzxS6QI8Hp8,6454
11
- model_compression_toolkit/core/runner.py,sha256=Wd0cNVMLOPX5cGY5kwz0J64rm87JKd-onJ2k01S9nLo,14362
10
+ model_compression_toolkit/core/quantization_prep_runner.py,sha256=OtL6g2rTC5mfdKrkzm47EPPW-voGGVYMYxpy2_sfu1U,6547
11
+ model_compression_toolkit/core/runner.py,sha256=lahkYyfdsb3HJPJ5Lui7hp4vVWyIOJLXJQ5ATxiIyos,14264
12
12
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
13
13
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
14
- model_compression_toolkit/core/common/framework_implementation.py,sha256=kSg2f7wS7e2EyvX6y0eKfNTTFvVFVrB8lvldJvcPvN8,20724
14
+ model_compression_toolkit/core/common/framework_implementation.py,sha256=7k66e0b06eLnFLmu67onWPiM2lJfhWiuyQZPsRJm3lk,21294
15
15
  model_compression_toolkit/core/common/framework_info.py,sha256=1ZMMGS9ip-kSflqkartyNRt9aQ5ub1WepuTRcTy-YSQ,6337
16
16
  model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
17
17
  model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
@@ -46,11 +46,11 @@ model_compression_toolkit/core/common/graph/memory_graph/cut.py,sha256=aPdXJPP5a
46
46
  model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256=crV2NCLVO8jx9MlryZBYuJKFe_G9HfM7rUR64fDymlw,17045
47
47
  model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=gRmBEFRmyJsNKezQfiwDwQu1cmbGd2wgKCRTH6iw8mw,3961
48
48
  model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=gw4av_rzn_3oEAPpD3B7PHZDqnxHMjIESevl6ppPnkk,7175
49
- model_compression_toolkit/core/common/hessian/__init__.py,sha256=Sj3I9mLBq-yrcBFxpUkOy0Rb5pxJQBPcECvgyOqhHSY,1064
50
- model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=fUgW-AUhRu609_RSRd1WKaQAfPk2SmLnlkT74v6TZwY,23769
49
+ model_compression_toolkit/core/common/hessian/__init__.py,sha256=E7LK3K_1AwMCQokanNc1JODMwUKNOKmwXQiGQ7GO10I,1033
50
+ model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=OH8Xadv0ZSD_yoymgSfaNg8tqr4vxUfAbNLCBMRz6pQ,13233
51
51
  model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=1axmN0tjJSo_7hUr2d2KMv4y1pBi19cqWSQpi4BbdsA,1458
52
52
  model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py,sha256=Pe4uKerx-MeDQPJ7Slr8fvFUHfv02q33w3gbQK5kBKs,4186
53
- model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=fYXcOMa2bpbJjQ2S4r021WOvhoDWFa_jy95hofqVBFA,3632
53
+ model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=U2n5fz6fK633HWzIvEuQ7N6dekMqH9-DecOXAgd3v4E,3140
54
54
  model_compression_toolkit/core/common/matchers/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
55
55
  model_compression_toolkit/core/common/matchers/base_graph_filter.py,sha256=mTk54z0mIbFmPOb4h0xfLtLDookcFyNh8H0pIN5js_M,3091
56
56
  model_compression_toolkit/core/common/matchers/base_matcher.py,sha256=JCj-NLAXOJa-GcSX-94PVUTWjooQUd0NemiyNg5uKGQ,2210
@@ -67,7 +67,7 @@ model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates
67
67
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=r1t025_QHshyoop-PZvL7x6UuXaeplCCU3h4VNBhJHo,4309
68
68
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=pk8HRoShDhiUprBC4m1AFQv1SacS4hOrj0MRdbq-5gY,7556
69
69
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=TTTux4YiOnQqt-2h7Y38959XaDwNZc0eufLMx_yws5U,37578
70
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=QdxFQ0JxsrcSfk5LlUU_3oZpEK7bYwKelGzEHh0mnJY,27558
70
+ model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=CYp2VuxXf95lYivolIuVRjAyaY5dFsDn2qh8ZhTmc9A,27525
71
71
  model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
72
72
  model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=KifDMbm7qkSfvSl6pcZzQ82naIXzeKL6aT-VsvWZYyc,7901
73
73
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -95,7 +95,7 @@ model_compression_toolkit/core/common/pruning/pruning_section.py,sha256=I4vxh5iP
95
95
  model_compression_toolkit/core/common/pruning/importance_metrics/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
96
96
  model_compression_toolkit/core/common/pruning/importance_metrics/base_importance_metric.py,sha256=qMAtLWs5fjbSco8nhbig5TkuacdhnDW7cy3avMHRGX4,1988
97
97
  model_compression_toolkit/core/common/pruning/importance_metrics/importance_metric_factory.py,sha256=E-fKuRfrNYlN3nNcRAbnkJkFNwClvyrL_Js1qDPxIKA,1999
98
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py,sha256=sJGsR8vTs43SV7Wfz929pCwLM-_7aXUyO5nBUig9K9s,14055
98
+ model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py,sha256=WeJZ_2LCwdKuGe7VOEHBDBTxr-JIwYmB-YNi5AFaWEE,14073
99
99
  model_compression_toolkit/core/common/pruning/mask/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
100
100
  model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=APY8BsM9B7ZxVCH6n1xs9fSCTB_A9ou9gHrCQl1DOdI,5131
101
101
  model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=4ohJrJHNzZk5uMnZEYkwLx2TDGzkh5kRhLGNVYNC6dc,5978
@@ -113,12 +113,12 @@ model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha
113
113
  model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
114
114
  model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=IjqFX0EGk4YCTaQsJp4-UycCVc2Ec6GTbu890dkGVns,21318
115
115
  model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
116
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=Fd_gxr5js-mqEwucaRR1CQAZ1W_wna19L1gAPeOzxRQ,23610
116
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=fzUvqmXVgzp_IV5ER-20kKzl4m8U_shZsAKs-ehhjFo,23887
117
117
  model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=RL-PklAjGyC-26anSt8fU07a6pB_LBQFQy9o4e9giN0,8739
118
118
  model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
119
119
  model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=y-mEST-0fVbyLiprQu7elOQawSc70TkVdpPsL7o1BmM,11197
120
120
  model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=pKmaeu7jrxqSI-SHmY8SFwPCRV6FrqiqJS9EAYQLbK4,4606
121
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=oME8T6Slgl1SJNpXV4oY3UhuX0YmKYbcWDsLiCYq7oE,8651
121
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=wV5RMqKkhedzRFBFwLYgc9BvCKlIKDKmJC0lmkpOvTM,8784
122
122
  model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=Nv_b3DECVjQnlrUet2kbuSvSKVnxcc-gf2zhFb2jSZk,43482
123
123
  model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=UI-NW9K-yA6qxtk3Uin1wKmo59FNy0LUnySpxodgeEs,3796
124
124
  model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=iPukBikpzuJhKfwnnBgyJ71HhaDIpSoTUuYsjt4rR7w,12587
@@ -153,12 +153,13 @@ model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256
153
153
  model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
154
154
  model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiOcnLAq1v0MMBCJqMJzzFk,3225
155
155
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
156
+ model_compression_toolkit/core/keras/data_util.py,sha256=JdomIJZfep0QYPtx2jlg0xJ40cd9S_I7BakaWQi0wKw,2681
156
157
  model_compression_toolkit/core/keras/default_framework_info.py,sha256=PYcER89eEXjKtR0T7-2Y4f7cckqoD5OQbpHePoRkMec,5030
157
- model_compression_toolkit/core/keras/keras_implementation.py,sha256=uOTGpsgH4h9MBduVBp8v7mm2S8njbkC72qvXcrZUjeI,30604
158
+ model_compression_toolkit/core/keras/keras_implementation.py,sha256=Tn4_rkcx9bH3x-pEoUbGu94S7_nj3Hl3BfvL8SPIL3g,30957
158
159
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
159
160
  model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
160
161
  model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=s56UIgiPipUQRNd2sd1xW6GFfYNMBmrocRCNtvpYLbY,4977
161
- model_compression_toolkit/core/keras/tf_tensor_numpy.py,sha256=AJMPD_cAwf7nzTlLMf_Y1kofXkh_xm8Ji7J6yDpbAKc,2691
162
+ model_compression_toolkit/core/keras/tf_tensor_numpy.py,sha256=jzD8FGEEa8ZD7w8IpTRdp-Udf1MwOTgjg2XTS1Givic,2696
162
163
  model_compression_toolkit/core/keras/back2framework/__init__.py,sha256=rhIiXg_nBgUZ-baE3M6SzCuQbcnq4iebY1jtJBvKHOM,808
163
164
  model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha256=UIQgOOdexycrSKombTMJVvTthR7MlrCihoqM8Kg-rnE,2293
164
165
  model_compression_toolkit/core/keras/back2framework/float_model_builder.py,sha256=9SFHhX-JnkB8PvYIIHRYlReBDI_RkZY9LditzW_ElLk,2444
@@ -216,18 +217,19 @@ model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_c
216
217
  model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
217
218
  model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
218
219
  model_compression_toolkit/core/pytorch/constants.py,sha256=YwD_joIF0vK8UG2vW1NVvg36pCNWA0vHOXjAgy_XWn0,2794
220
+ model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
219
221
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=-Vls1P_8Ckm_18nnOsmQkZ71SmzHwtQLbQ383Z4Rb-U,4365
220
222
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
221
- model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=2RGf4ii9zxJwGLA3mp-qzDp4khFaYNUNN95bNuNNZ0c,27868
223
+ model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=TWA5Eu_85TIoCii1Owx2yx_ECckOnGg7xgQkiueuZPE,28245
222
224
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
223
225
  model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=xpKj99OZKT9NT0vKIl_cOe8d89d2gef1gKoNT6PFElE,4989
224
- model_compression_toolkit/core/pytorch/utils.py,sha256=GE7T8q93I5C4As0iOias_dk9HpOvXM1N6---dJlyD60,3863
226
+ model_compression_toolkit/core/pytorch/utils.py,sha256=7VbgcLwtQvdEEc_AJgSOQ3U3KRKCICFPaBirN1fIQxg,3940
225
227
  model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN0elVWf3exgGYsi58imPoYDj5eYPeh6x4yfug,813
226
228
  model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,sha256=bwppTPRs6gL96nm7qPiKrNcBj4Krr0yEsOWjRF0aXmQ,2339
227
229
  model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
228
230
  model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
229
231
  model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
230
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=BJeKGMv5VU4Z3jLOIQ-Ifs_2vGELQSmEQmje3ZmaUl4,19948
232
+ model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=duQkaURjJv2FbDtX8udDRZbvPZSsIujLAL9Oa40dMK0,19934
231
233
  model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
232
234
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
233
235
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
@@ -255,9 +257,9 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transfo
255
257
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=WmEa8Xjji-_tIbthDxlLAGSr69nWk-YKcHNaVqLa7sg,1375
256
258
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/weights_activation_split.py,sha256=tp78axmUQc0Zpj3KwVmV0PGYHvCf7sAW_sRmXXw7gsY,1616
257
259
  model_compression_toolkit/core/pytorch/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
258
- model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py,sha256=fKeql1cXOieHTbxQDOIMpFO1sVktqXVCRBgZkv3R13Q,10929
259
- model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py,sha256=vXluX-awgavv7DGihG9HrlvLhak8qIHy837PPTOd4jg,3471
260
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py,sha256=C4-7naBQUh8TN6fEwkyKY6rlY_nvHSAmCnWT4iMBs8E,8497
260
+ model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py,sha256=Tt1oDiTWDROQP2b5OtlV3VGM-jInjGoGjEV-OVwW2lI,9854
261
+ model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py,sha256=8f_XlM8ZFVQPNGr1iECr1hv8QusYDrNU_vTkLQZE9RU,2477
262
+ model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py,sha256=UzWxWDbr8koKZatEcPn8RCb0Zjm_7fKTvIGb98sp18k,8487
261
263
  model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
262
264
  model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=-6oep2WJ85-JmIxZa-e2AmBpbORoKe4Xdduz2ZidwvM,4871
263
265
  model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py,sha256=KVZTKCYzJqqzF5nFEiuGMv_sNeVuBTxhmxWMFacKOxE,6337
@@ -344,15 +346,15 @@ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantiz
344
346
  model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
345
347
  model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
346
348
  model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
347
- model_compression_toolkit/gptq/common/gptq_config.py,sha256=GP4lcDeyVgXA-QFArDW28UucOOKY0zeYJpq2pvyNVM8,6510
349
+ model_compression_toolkit/gptq/common/gptq_config.py,sha256=Z6T5B3q4k2Tlr2bBWvC6TAF3d2opyA7ZT_D_mz6D1_0,6297
348
350
  model_compression_toolkit/gptq/common/gptq_constants.py,sha256=QSm6laLkIV0LYmU0BLtmKp3Fi3SqDfbncFQWOGA1cGU,611
349
351
  model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
350
352
  model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
351
- model_compression_toolkit/gptq/common/gptq_training.py,sha256=dRNEjjKdVqlazbGWjZNE9q-MsU0PBffGKHfDpy3NX5Q,16661
353
+ model_compression_toolkit/gptq/common/gptq_training.py,sha256=tt4O8PjSChquzl4c6NojvQWZmvCdTxcMLtmEVIGx1ns,13252
352
354
  model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
353
355
  model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
354
356
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
355
- model_compression_toolkit/gptq/keras/gptq_training.py,sha256=NXTNsVrO9DTh0uvc8V7rFaM0fYg2OA18ZrYd-cKZ7Z4,19159
357
+ model_compression_toolkit/gptq/keras/gptq_training.py,sha256=tFHucF7YHKtHmYGkdMpqSf14H9c7x60Il7ZTMNXSesE,19751
356
358
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=MKIfrRTRH3zCuxCR1g9ZVIFyuSSr0e0sDybqh4LDM7E,4672
357
359
  model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=iSHnMEdoIqHYqLCTsdK8uxhKbZuuaDOu_BeQ10Z492U,15715
358
360
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
@@ -367,11 +369,11 @@ model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quanti
367
369
  model_compression_toolkit/gptq/keras/quantizer/ste_rounding/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
368
370
  model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=pgZADwaNWUwm9QTrYaW6yXE3-zfedPZSa9TKBVedNd4,8356
369
371
  model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
370
- model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=nVW3URcCWQywoXfmTOBMxliZVvosshf4-G0Sq7dNwzU,3877
372
+ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=_07Zx_43bnNokwR5S8phIqeu5-_7_5VBT4DT-FCw7Do,3892
371
373
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
372
- model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=j_FZcs8ey_9voI83TrL4q1Mne59zO2_v0MzdhZcxWuY,20071
374
+ model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=QBxTnwVvLyZDTdpkR81wjj9o5aGtmp9qiBt5FR8ImJ0,21777
373
375
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
374
- model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=7UPaLBx66mJIlDTpT1uLI9LpHPzOr8EtywZ0aawveDA,16527
376
+ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=yv2DWPWpFVRmtB_FhcRwnLUumyPPHC_hHaMxeQBTQ1k,16333
375
377
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
376
378
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
377
379
  model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py,sha256=nngu2TeXjngkqt_6-wciFmCvo-dbpeh_tJJxBV_cfHk,3686
@@ -379,7 +381,7 @@ model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl
379
381
  model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=5EyAzvlU01vLyXmMwY_8dNyb7GwYktXmnrvUON8n8WI,4696
380
382
  model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=H6pARLK-jq3cKoaipY0SK9wMGrqy6CSEZTk14KdrKA0,2105
381
383
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
382
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=vlQEhif-R49UstORkXmpMA4ZE82Aqh-mJqKCnB31gag,3005
384
+ model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=f7B95Bx-MX-HKheqAUn1GG8cVHFI2ldFReXrUPwk2tY,3002
383
385
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=kLVQC1hXzDpP4Jx7AwnA764oGnY5AMEuvUUhAvhz09M,12347
384
386
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py,sha256=FgPSKoV8p8y-gLNz359XdOPD6w_wpDvcJFtTNLWqYb0,9099
385
387
  model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -552,14 +554,19 @@ model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3
552
554
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
553
555
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
554
556
  tests_pytest/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
557
+ tests_pytest/keras/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
558
+ tests_pytest/keras/core/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
559
+ tests_pytest/keras/core/test_data_util.py,sha256=XSoPu_ci1xy2EtK-3OWGpESr-Meg1GDaxuSvcj3yt-w,3915
555
560
  tests_pytest/pytorch/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
561
+ tests_pytest/pytorch/core/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
562
+ tests_pytest/pytorch/core/test_data_util.py,sha256=Bg3c21YVfXE1SAUlTao553gXcITTKF4CPeKtl3peBTE,5604
556
563
  tests_pytest/pytorch/gptq/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
557
564
  tests_pytest/pytorch/gptq/test_annealing_cfg.py,sha256=hGC7L6mp3N1ygcJ3OctgS_Fz2JY75q5aswolJkbHkZM,2208
558
565
  tests_pytest/pytorch/gptq/test_gradual_act_quantization.py,sha256=tI01aFIUaiCILL5Qn--p1E_rLBUelxLdSY3k52lwcx0,4594
559
566
  tests_pytest/pytorch/trainable_infrastructure/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
560
567
  tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py,sha256=eNOpSp0GoLxtEdiRypBp8jaujXfdNxBwKh5Rd-P7WLs,1786
561
- mct_nightly-2.2.0.20241022.507.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
562
- mct_nightly-2.2.0.20241022.507.dist-info/METADATA,sha256=HYlWsJmgchcFkR7Hxd7QgyThGdYJouwoYJnnetjMfv8,20830
563
- mct_nightly-2.2.0.20241022.507.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
564
- mct_nightly-2.2.0.20241022.507.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
565
- mct_nightly-2.2.0.20241022.507.dist-info/RECORD,,
568
+ mct_nightly-2.2.0.20241024.501.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
569
+ mct_nightly-2.2.0.20241024.501.dist-info/METADATA,sha256=5WlF1OFmMSVFGGpe1Co0cav94dCyyPdoBtXZ3NbKiMo,20830
570
+ mct_nightly-2.2.0.20241024.501.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
571
+ mct_nightly-2.2.0.20241024.501.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
572
+ mct_nightly-2.2.0.20241024.501.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.2.0.20241022.000507"
30
+ __version__ = "2.2.0.20241024.000501"
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from abc import ABC, abstractmethod
16
- from typing import Callable, Any, List, Tuple, Dict
16
+ from typing import Callable, Any, List, Tuple, Dict, Generator
17
17
 
18
18
  import numpy as np
19
19
 
@@ -46,7 +46,7 @@ class FrameworkImplementation(ABC):
46
46
  Returns: Module of the framework constants.
47
47
 
48
48
  """
49
- raise NotImplemented(f'{self.__class__.__name__} did not supply a constants module.') # pragma: no cover
49
+ raise NotImplementedError(f'{self.__class__.__name__} did not supply a constants module.') # pragma: no cover
50
50
 
51
51
  @abstractmethod
52
52
  def get_hessian_scores_calculator(self,
@@ -64,7 +64,7 @@ class FrameworkImplementation(ABC):
64
64
 
65
65
  Returns: HessianScoresCalculator to use for the hessian approximation scores computation for this request.
66
66
  """
67
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
67
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
68
68
  f'framework\'s get_hessian_scores_calculator method.') # pragma: no cover
69
69
 
70
70
  @abstractmethod
@@ -77,7 +77,7 @@ class FrameworkImplementation(ABC):
77
77
  Returns:
78
78
  Numpy array converted from the input tensor.
79
79
  """
80
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
80
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
81
81
  f'framework\'s to_numpy method.') # pragma: no cover
82
82
 
83
83
  @abstractmethod
@@ -90,7 +90,7 @@ class FrameworkImplementation(ABC):
90
90
  Returns:
91
91
  Framework's tensor converted from the input Numpy array.
92
92
  """
93
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
93
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
94
94
  f'framework\'s to_tensor method.') # pragma: no cover
95
95
 
96
96
  @abstractmethod
@@ -106,7 +106,7 @@ class FrameworkImplementation(ABC):
106
106
  Returns:
107
107
  Graph representing the input model.
108
108
  """
109
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
109
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
110
110
  f'framework\'s model_reader method.') # pragma: no cover
111
111
 
112
112
  @abstractmethod
@@ -131,7 +131,7 @@ class FrameworkImplementation(ABC):
131
131
  Returns:
132
132
  A tuple with the model and additional relevant supporting objects.
133
133
  """
134
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
134
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
135
135
  f'framework\'s model_builder method.') # pragma: no cover
136
136
 
137
137
  @abstractmethod
@@ -148,7 +148,7 @@ class FrameworkImplementation(ABC):
148
148
  Returns:
149
149
  The frameworks model's output.
150
150
  """
151
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
151
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
152
152
  f'framework\'s run_model_inference method.') # pragma: no cover
153
153
 
154
154
  @abstractmethod
@@ -167,7 +167,7 @@ class FrameworkImplementation(ABC):
167
167
  Returns:
168
168
  Graph after SNC.
169
169
  """
170
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
170
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
171
171
  f'framework\'s apply_shift_negative_correction method.') # pragma: no cover
172
172
 
173
173
  @abstractmethod
@@ -184,7 +184,7 @@ class FrameworkImplementation(ABC):
184
184
  Returns:
185
185
  A list of the framework substitutions used after we collect statistics.
186
186
  """
187
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
187
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
188
188
  f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover
189
189
 
190
190
  @abstractmethod
@@ -194,7 +194,7 @@ class FrameworkImplementation(ABC):
194
194
  Returns: A list of the framework substitutions used to prepare the graph.
195
195
 
196
196
  """
197
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
197
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
198
198
  f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover
199
199
 
200
200
  @abstractmethod
@@ -208,7 +208,7 @@ class FrameworkImplementation(ABC):
208
208
  Returns: A list of the framework substitutions used before we collect statistics.
209
209
 
210
210
  """
211
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
211
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
212
212
  f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover
213
213
 
214
214
  @abstractmethod
@@ -216,7 +216,7 @@ class FrameworkImplementation(ABC):
216
216
  """
217
217
  Returns: linear collapsing substitution
218
218
  """
219
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
219
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
220
220
  f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover
221
221
 
222
222
  @abstractmethod
@@ -224,7 +224,7 @@ class FrameworkImplementation(ABC):
224
224
  """
225
225
  Returns: conv2d add const collapsing substitution
226
226
  """
227
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
227
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
228
228
  f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover
229
229
 
230
230
  @abstractmethod
@@ -239,7 +239,7 @@ class FrameworkImplementation(ABC):
239
239
  Returns:
240
240
  A list of the framework substitutions used for statistics correction.
241
241
  """
242
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
242
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
243
243
  f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover
244
244
 
245
245
  @abstractmethod
@@ -247,7 +247,7 @@ class FrameworkImplementation(ABC):
247
247
  """
248
248
  Returns: A list of the framework substitutions used for residual collapsing
249
249
  """
250
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
250
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
251
251
  f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover
252
252
 
253
253
 
@@ -263,7 +263,7 @@ class FrameworkImplementation(ABC):
263
263
  Returns:
264
264
  A list of the framework substitutions used after we collect statistics.
265
265
  """
266
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
266
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
267
267
  f'framework\'s get_substitutions_post_statistics_collection method.') # pragma: no cover
268
268
 
269
269
  @abstractmethod
@@ -272,7 +272,7 @@ class FrameworkImplementation(ABC):
272
272
  Returns: A list of Keras substitutions used to build a virtual graph with composed activation-weights pairs.
273
273
  """
274
274
 
275
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
275
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
276
276
  f'framework\'s get_substitutions_virtual_weights_activation_coupling '
277
277
  f'method.') # pragma: no cover
278
278
 
@@ -288,7 +288,7 @@ class FrameworkImplementation(ABC):
288
288
  Returns:
289
289
  A list of the framework substitutions used after we apply second moment statistics.
290
290
  """
291
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
291
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
292
292
  f'framework\'s get_substitutions_after_second_moment_correction '
293
293
  f'method.') # pragma: no cover
294
294
 
@@ -316,7 +316,7 @@ class FrameworkImplementation(ABC):
316
316
  A function that computes the metric.
317
317
  """
318
318
 
319
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
319
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
320
320
  f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover
321
321
 
322
322
  def get_node_prior_info(self, node: BaseNode,
@@ -334,7 +334,7 @@ class FrameworkImplementation(ABC):
334
334
  NodePriorInfo with information about the node.
335
335
  """
336
336
 
337
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
337
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
338
338
  f'framework\'s get_node_prior_info method.') # pragma: no cover
339
339
 
340
340
  def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
@@ -345,7 +345,7 @@ class FrameworkImplementation(ABC):
345
345
  Returns: True if the node should be considered an interest point, False otherwise.
346
346
  """
347
347
 
348
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
348
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
349
349
  f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover
350
350
 
351
351
  def get_mp_node_distance_fn(self, n: BaseNode,
@@ -364,7 +364,7 @@ class FrameworkImplementation(ABC):
364
364
  Returns: A distance function between two tensors and a axis on which the distance is computed (if exists).
365
365
  """
366
366
 
367
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
367
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
368
368
  f'framework\'s get_mp_node_distance_fn method.') # pragma: no cover
369
369
 
370
370
 
@@ -381,7 +381,7 @@ class FrameworkImplementation(ABC):
381
381
 
382
382
  """
383
383
 
384
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
384
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
385
385
  f'framework\'s is_output_node_compatible_for_hessian_score_computation method.') # pragma: no cover
386
386
 
387
387
  @abstractmethod
@@ -398,7 +398,7 @@ class FrameworkImplementation(ABC):
398
398
  Returns: The MAC count of the operation
399
399
  """
400
400
 
401
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
401
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
402
402
  f'framework\'s get_node_mac_operations method.') # pragma: no cover
403
403
 
404
404
  @abstractmethod
@@ -419,7 +419,7 @@ class FrameworkImplementation(ABC):
419
419
  Returns:
420
420
  A Graph after second moment correction.
421
421
  """
422
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
422
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
423
423
  f'framework\'s apply_second_moment_correction method.') # pragma: no cover
424
424
 
425
425
  @abstractmethod
@@ -436,7 +436,7 @@ class FrameworkImplementation(ABC):
436
436
  Returns:
437
437
  The output of the model inference on the given input.
438
438
  """
439
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
439
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
440
440
  f'framework\'s sensitivity_eval_inference method.') # pragma: no cover
441
441
 
442
442
  def get_inferable_quantizers(self, node: BaseNode):
@@ -452,5 +452,19 @@ class FrameworkImplementation(ABC):
452
452
 
453
453
  """
454
454
 
455
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
456
- f'framework\'s get_inferable_quantizers method.') # pragma: no cover
455
+ raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
456
+ f'framework\'s get_inferable_quantizers method.') # pragma: no cover
457
+
458
+ @staticmethod
459
+ def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
460
+ """
461
+ Create DataLoader based on samples yielded by data_gen.
462
+
463
+ Args:
464
+ data_gen_fn: data generator factory.
465
+ batch_size: target batch size.
466
+
467
+ Returns:
468
+ Framework dataloader.
469
+ """
470
+ raise NotImplementedError() # pragma: no cover
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from model_compression_toolkit.core.common.hessian.hessian_scores_request import (
16
- HessianScoresRequest, HessianMode, HessianScoresGranularity, HessianEstimationDistribution
16
+ HessianScoresRequest, HessianMode, HessianScoresGranularity
17
17
  )
18
18
  from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
19
19
  import model_compression_toolkit.core.common.hessian.hessian_info_utils as hessian_utils