mct-nightly 2.2.0.20241022.507__py3-none-any.whl → 2.2.0.20241024.501__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/RECORD +38 -31
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +43 -29
- model_compression_toolkit/core/common/hessian/__init__.py +1 -1
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +222 -371
- model_compression_toolkit/core/common/hessian/hessian_scores_request.py +27 -41
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +8 -10
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +11 -9
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +10 -6
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +17 -15
- model_compression_toolkit/core/keras/data_util.py +67 -0
- model_compression_toolkit/core/keras/keras_implementation.py +7 -1
- model_compression_toolkit/core/keras/tf_tensor_numpy.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/data_util.py +163 -0
- model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +6 -31
- model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py +11 -21
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +9 -7
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +8 -2
- model_compression_toolkit/core/pytorch/utils.py +22 -19
- model_compression_toolkit/core/quantization_prep_runner.py +2 -1
- model_compression_toolkit/core/runner.py +1 -2
- model_compression_toolkit/gptq/common/gptq_config.py +0 -2
- model_compression_toolkit/gptq/common/gptq_training.py +58 -114
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -6
- model_compression_toolkit/gptq/pytorch/gptq_loss.py +3 -2
- model_compression_toolkit/gptq/pytorch/gptq_training.py +97 -64
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +0 -2
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +4 -3
- tests_pytest/keras/__init__.py +14 -0
- tests_pytest/keras/core/__init__.py +14 -0
- tests_pytest/keras/core/test_data_util.py +91 -0
- tests_pytest/pytorch/core/__init__.py +14 -0
- tests_pytest/pytorch/core/test_data_util.py +125 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/top_level.txt +0 -0
{mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=1QmmvVxsZd2Hv0uinv8FVjWpnz6neE89qwpot6MfqqY,1573
|
2
2
|
model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
@@ -7,11 +7,11 @@ model_compression_toolkit/verify_packages.py,sha256=TlS-K1EP-QsghqWUW7SDPkAJiUf7
|
|
7
7
|
model_compression_toolkit/core/__init__.py,sha256=tnDtL9KmT0vsOU27SsJ19TKDEbIH-tXYeGxTo5YnNUM,2077
|
8
8
|
model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
|
9
9
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=7-b7Jd5jBVaXOWg5nSqbEyzBtdaGDbCxs8aqMV6GZ6I,11287
|
10
|
-
model_compression_toolkit/core/quantization_prep_runner.py,sha256=
|
11
|
-
model_compression_toolkit/core/runner.py,sha256=
|
10
|
+
model_compression_toolkit/core/quantization_prep_runner.py,sha256=OtL6g2rTC5mfdKrkzm47EPPW-voGGVYMYxpy2_sfu1U,6547
|
11
|
+
model_compression_toolkit/core/runner.py,sha256=lahkYyfdsb3HJPJ5Lui7hp4vVWyIOJLXJQ5ATxiIyos,14264
|
12
12
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
13
13
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
14
|
-
model_compression_toolkit/core/common/framework_implementation.py,sha256=
|
14
|
+
model_compression_toolkit/core/common/framework_implementation.py,sha256=7k66e0b06eLnFLmu67onWPiM2lJfhWiuyQZPsRJm3lk,21294
|
15
15
|
model_compression_toolkit/core/common/framework_info.py,sha256=1ZMMGS9ip-kSflqkartyNRt9aQ5ub1WepuTRcTy-YSQ,6337
|
16
16
|
model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
|
17
17
|
model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
|
@@ -46,11 +46,11 @@ model_compression_toolkit/core/common/graph/memory_graph/cut.py,sha256=aPdXJPP5a
|
|
46
46
|
model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256=crV2NCLVO8jx9MlryZBYuJKFe_G9HfM7rUR64fDymlw,17045
|
47
47
|
model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=gRmBEFRmyJsNKezQfiwDwQu1cmbGd2wgKCRTH6iw8mw,3961
|
48
48
|
model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=gw4av_rzn_3oEAPpD3B7PHZDqnxHMjIESevl6ppPnkk,7175
|
49
|
-
model_compression_toolkit/core/common/hessian/__init__.py,sha256=
|
50
|
-
model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=
|
49
|
+
model_compression_toolkit/core/common/hessian/__init__.py,sha256=E7LK3K_1AwMCQokanNc1JODMwUKNOKmwXQiGQ7GO10I,1033
|
50
|
+
model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=OH8Xadv0ZSD_yoymgSfaNg8tqr4vxUfAbNLCBMRz6pQ,13233
|
51
51
|
model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=1axmN0tjJSo_7hUr2d2KMv4y1pBi19cqWSQpi4BbdsA,1458
|
52
52
|
model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py,sha256=Pe4uKerx-MeDQPJ7Slr8fvFUHfv02q33w3gbQK5kBKs,4186
|
53
|
-
model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=
|
53
|
+
model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=U2n5fz6fK633HWzIvEuQ7N6dekMqH9-DecOXAgd3v4E,3140
|
54
54
|
model_compression_toolkit/core/common/matchers/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
55
55
|
model_compression_toolkit/core/common/matchers/base_graph_filter.py,sha256=mTk54z0mIbFmPOb4h0xfLtLDookcFyNh8H0pIN5js_M,3091
|
56
56
|
model_compression_toolkit/core/common/matchers/base_matcher.py,sha256=JCj-NLAXOJa-GcSX-94PVUTWjooQUd0NemiyNg5uKGQ,2210
|
@@ -67,7 +67,7 @@ model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates
|
|
67
67
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=r1t025_QHshyoop-PZvL7x6UuXaeplCCU3h4VNBhJHo,4309
|
68
68
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=pk8HRoShDhiUprBC4m1AFQv1SacS4hOrj0MRdbq-5gY,7556
|
69
69
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=TTTux4YiOnQqt-2h7Y38959XaDwNZc0eufLMx_yws5U,37578
|
70
|
-
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=
|
70
|
+
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=CYp2VuxXf95lYivolIuVRjAyaY5dFsDn2qh8ZhTmc9A,27525
|
71
71
|
model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
|
72
72
|
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=KifDMbm7qkSfvSl6pcZzQ82naIXzeKL6aT-VsvWZYyc,7901
|
73
73
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
@@ -95,7 +95,7 @@ model_compression_toolkit/core/common/pruning/pruning_section.py,sha256=I4vxh5iP
|
|
95
95
|
model_compression_toolkit/core/common/pruning/importance_metrics/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
|
96
96
|
model_compression_toolkit/core/common/pruning/importance_metrics/base_importance_metric.py,sha256=qMAtLWs5fjbSco8nhbig5TkuacdhnDW7cy3avMHRGX4,1988
|
97
97
|
model_compression_toolkit/core/common/pruning/importance_metrics/importance_metric_factory.py,sha256=E-fKuRfrNYlN3nNcRAbnkJkFNwClvyrL_Js1qDPxIKA,1999
|
98
|
-
model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py,sha256=
|
98
|
+
model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py,sha256=WeJZ_2LCwdKuGe7VOEHBDBTxr-JIwYmB-YNi5AFaWEE,14073
|
99
99
|
model_compression_toolkit/core/common/pruning/mask/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
100
100
|
model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=APY8BsM9B7ZxVCH6n1xs9fSCTB_A9ou9gHrCQl1DOdI,5131
|
101
101
|
model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=4ohJrJHNzZk5uMnZEYkwLx2TDGzkh5kRhLGNVYNC6dc,5978
|
@@ -113,12 +113,12 @@ model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha
|
|
113
113
|
model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
|
114
114
|
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=IjqFX0EGk4YCTaQsJp4-UycCVc2Ec6GTbu890dkGVns,21318
|
115
115
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
|
116
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=
|
116
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=fzUvqmXVgzp_IV5ER-20kKzl4m8U_shZsAKs-ehhjFo,23887
|
117
117
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=RL-PklAjGyC-26anSt8fU07a6pB_LBQFQy9o4e9giN0,8739
|
118
118
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
|
119
119
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=y-mEST-0fVbyLiprQu7elOQawSc70TkVdpPsL7o1BmM,11197
|
120
120
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=pKmaeu7jrxqSI-SHmY8SFwPCRV6FrqiqJS9EAYQLbK4,4606
|
121
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=
|
121
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=wV5RMqKkhedzRFBFwLYgc9BvCKlIKDKmJC0lmkpOvTM,8784
|
122
122
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=Nv_b3DECVjQnlrUet2kbuSvSKVnxcc-gf2zhFb2jSZk,43482
|
123
123
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=UI-NW9K-yA6qxtk3Uin1wKmo59FNy0LUnySpxodgeEs,3796
|
124
124
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=iPukBikpzuJhKfwnnBgyJ71HhaDIpSoTUuYsjt4rR7w,12587
|
@@ -153,12 +153,13 @@ model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256
|
|
153
153
|
model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
154
154
|
model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiOcnLAq1v0MMBCJqMJzzFk,3225
|
155
155
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
156
|
+
model_compression_toolkit/core/keras/data_util.py,sha256=JdomIJZfep0QYPtx2jlg0xJ40cd9S_I7BakaWQi0wKw,2681
|
156
157
|
model_compression_toolkit/core/keras/default_framework_info.py,sha256=PYcER89eEXjKtR0T7-2Y4f7cckqoD5OQbpHePoRkMec,5030
|
157
|
-
model_compression_toolkit/core/keras/keras_implementation.py,sha256=
|
158
|
+
model_compression_toolkit/core/keras/keras_implementation.py,sha256=Tn4_rkcx9bH3x-pEoUbGu94S7_nj3Hl3BfvL8SPIL3g,30957
|
158
159
|
model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
|
159
160
|
model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
|
160
161
|
model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=s56UIgiPipUQRNd2sd1xW6GFfYNMBmrocRCNtvpYLbY,4977
|
161
|
-
model_compression_toolkit/core/keras/tf_tensor_numpy.py,sha256=
|
162
|
+
model_compression_toolkit/core/keras/tf_tensor_numpy.py,sha256=jzD8FGEEa8ZD7w8IpTRdp-Udf1MwOTgjg2XTS1Givic,2696
|
162
163
|
model_compression_toolkit/core/keras/back2framework/__init__.py,sha256=rhIiXg_nBgUZ-baE3M6SzCuQbcnq4iebY1jtJBvKHOM,808
|
163
164
|
model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha256=UIQgOOdexycrSKombTMJVvTthR7MlrCihoqM8Kg-rnE,2293
|
164
165
|
model_compression_toolkit/core/keras/back2framework/float_model_builder.py,sha256=9SFHhX-JnkB8PvYIIHRYlReBDI_RkZY9LditzW_ElLk,2444
|
@@ -216,18 +217,19 @@ model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_c
|
|
216
217
|
model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
217
218
|
model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
218
219
|
model_compression_toolkit/core/pytorch/constants.py,sha256=YwD_joIF0vK8UG2vW1NVvg36pCNWA0vHOXjAgy_XWn0,2794
|
220
|
+
model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
|
219
221
|
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=-Vls1P_8Ckm_18nnOsmQkZ71SmzHwtQLbQ383Z4Rb-U,4365
|
220
222
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
|
221
|
-
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=
|
223
|
+
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=TWA5Eu_85TIoCii1Owx2yx_ECckOnGg7xgQkiueuZPE,28245
|
222
224
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
|
223
225
|
model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=xpKj99OZKT9NT0vKIl_cOe8d89d2gef1gKoNT6PFElE,4989
|
224
|
-
model_compression_toolkit/core/pytorch/utils.py,sha256=
|
226
|
+
model_compression_toolkit/core/pytorch/utils.py,sha256=7VbgcLwtQvdEEc_AJgSOQ3U3KRKCICFPaBirN1fIQxg,3940
|
225
227
|
model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN0elVWf3exgGYsi58imPoYDj5eYPeh6x4yfug,813
|
226
228
|
model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,sha256=bwppTPRs6gL96nm7qPiKrNcBj4Krr0yEsOWjRF0aXmQ,2339
|
227
229
|
model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
|
228
230
|
model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
|
229
231
|
model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
|
230
|
-
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=
|
232
|
+
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=duQkaURjJv2FbDtX8udDRZbvPZSsIujLAL9Oa40dMK0,19934
|
231
233
|
model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
|
232
234
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
233
235
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
|
@@ -255,9 +257,9 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transfo
|
|
255
257
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=WmEa8Xjji-_tIbthDxlLAGSr69nWk-YKcHNaVqLa7sg,1375
|
256
258
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/weights_activation_split.py,sha256=tp78axmUQc0Zpj3KwVmV0PGYHvCf7sAW_sRmXXw7gsY,1616
|
257
259
|
model_compression_toolkit/core/pytorch/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
|
258
|
-
model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py,sha256=
|
259
|
-
model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py,sha256=
|
260
|
-
model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py,sha256=
|
260
|
+
model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py,sha256=Tt1oDiTWDROQP2b5OtlV3VGM-jInjGoGjEV-OVwW2lI,9854
|
261
|
+
model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py,sha256=8f_XlM8ZFVQPNGr1iECr1hv8QusYDrNU_vTkLQZE9RU,2477
|
262
|
+
model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py,sha256=UzWxWDbr8koKZatEcPn8RCb0Zjm_7fKTvIGb98sp18k,8487
|
261
263
|
model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
262
264
|
model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=-6oep2WJ85-JmIxZa-e2AmBpbORoKe4Xdduz2ZidwvM,4871
|
263
265
|
model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py,sha256=KVZTKCYzJqqzF5nFEiuGMv_sNeVuBTxhmxWMFacKOxE,6337
|
@@ -344,15 +346,15 @@ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantiz
|
|
344
346
|
model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
|
345
347
|
model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
|
346
348
|
model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
347
|
-
model_compression_toolkit/gptq/common/gptq_config.py,sha256=
|
349
|
+
model_compression_toolkit/gptq/common/gptq_config.py,sha256=Z6T5B3q4k2Tlr2bBWvC6TAF3d2opyA7ZT_D_mz6D1_0,6297
|
348
350
|
model_compression_toolkit/gptq/common/gptq_constants.py,sha256=QSm6laLkIV0LYmU0BLtmKp3Fi3SqDfbncFQWOGA1cGU,611
|
349
351
|
model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
|
350
352
|
model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
|
351
|
-
model_compression_toolkit/gptq/common/gptq_training.py,sha256=
|
353
|
+
model_compression_toolkit/gptq/common/gptq_training.py,sha256=tt4O8PjSChquzl4c6NojvQWZmvCdTxcMLtmEVIGx1ns,13252
|
352
354
|
model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
353
355
|
model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
|
354
356
|
model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
|
355
|
-
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=
|
357
|
+
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=tFHucF7YHKtHmYGkdMpqSf14H9c7x60Il7ZTMNXSesE,19751
|
356
358
|
model_compression_toolkit/gptq/keras/graph_info.py,sha256=MKIfrRTRH3zCuxCR1g9ZVIFyuSSr0e0sDybqh4LDM7E,4672
|
357
359
|
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=iSHnMEdoIqHYqLCTsdK8uxhKbZuuaDOu_BeQ10Z492U,15715
|
358
360
|
model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
|
@@ -367,11 +369,11 @@ model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quanti
|
|
367
369
|
model_compression_toolkit/gptq/keras/quantizer/ste_rounding/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
368
370
|
model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=pgZADwaNWUwm9QTrYaW6yXE3-zfedPZSa9TKBVedNd4,8356
|
369
371
|
model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
370
|
-
model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=
|
372
|
+
model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=_07Zx_43bnNokwR5S8phIqeu5-_7_5VBT4DT-FCw7Do,3892
|
371
373
|
model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
|
372
|
-
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=
|
374
|
+
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=QBxTnwVvLyZDTdpkR81wjj9o5aGtmp9qiBt5FR8ImJ0,21777
|
373
375
|
model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
|
374
|
-
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=
|
376
|
+
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=yv2DWPWpFVRmtB_FhcRwnLUumyPPHC_hHaMxeQBTQ1k,16333
|
375
377
|
model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
|
376
378
|
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
|
377
379
|
model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py,sha256=nngu2TeXjngkqt_6-wciFmCvo-dbpeh_tJJxBV_cfHk,3686
|
@@ -379,7 +381,7 @@ model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl
|
|
379
381
|
model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=5EyAzvlU01vLyXmMwY_8dNyb7GwYktXmnrvUON8n8WI,4696
|
380
382
|
model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=H6pARLK-jq3cKoaipY0SK9wMGrqy6CSEZTk14KdrKA0,2105
|
381
383
|
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
|
382
|
-
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=
|
384
|
+
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=f7B95Bx-MX-HKheqAUn1GG8cVHFI2ldFReXrUPwk2tY,3002
|
383
385
|
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=kLVQC1hXzDpP4Jx7AwnA764oGnY5AMEuvUUhAvhz09M,12347
|
384
386
|
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py,sha256=FgPSKoV8p8y-gLNz359XdOPD6w_wpDvcJFtTNLWqYb0,9099
|
385
387
|
model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
@@ -552,14 +554,19 @@ model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3
|
|
552
554
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
553
555
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
554
556
|
tests_pytest/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
557
|
+
tests_pytest/keras/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
558
|
+
tests_pytest/keras/core/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
559
|
+
tests_pytest/keras/core/test_data_util.py,sha256=XSoPu_ci1xy2EtK-3OWGpESr-Meg1GDaxuSvcj3yt-w,3915
|
555
560
|
tests_pytest/pytorch/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
561
|
+
tests_pytest/pytorch/core/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
562
|
+
tests_pytest/pytorch/core/test_data_util.py,sha256=Bg3c21YVfXE1SAUlTao553gXcITTKF4CPeKtl3peBTE,5604
|
556
563
|
tests_pytest/pytorch/gptq/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
557
564
|
tests_pytest/pytorch/gptq/test_annealing_cfg.py,sha256=hGC7L6mp3N1ygcJ3OctgS_Fz2JY75q5aswolJkbHkZM,2208
|
558
565
|
tests_pytest/pytorch/gptq/test_gradual_act_quantization.py,sha256=tI01aFIUaiCILL5Qn--p1E_rLBUelxLdSY3k52lwcx0,4594
|
559
566
|
tests_pytest/pytorch/trainable_infrastructure/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
560
567
|
tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py,sha256=eNOpSp0GoLxtEdiRypBp8jaujXfdNxBwKh5Rd-P7WLs,1786
|
561
|
-
mct_nightly-2.2.0.
|
562
|
-
mct_nightly-2.2.0.
|
563
|
-
mct_nightly-2.2.0.
|
564
|
-
mct_nightly-2.2.0.
|
565
|
-
mct_nightly-2.2.0.
|
568
|
+
mct_nightly-2.2.0.20241024.501.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
569
|
+
mct_nightly-2.2.0.20241024.501.dist-info/METADATA,sha256=5WlF1OFmMSVFGGpe1Co0cav94dCyyPdoBtXZ3NbKiMo,20830
|
570
|
+
mct_nightly-2.2.0.20241024.501.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
571
|
+
mct_nightly-2.2.0.20241024.501.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
|
572
|
+
mct_nightly-2.2.0.20241024.501.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.2.0.
|
30
|
+
__version__ = "2.2.0.20241024.000501"
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
from abc import ABC, abstractmethod
|
16
|
-
from typing import Callable, Any, List, Tuple, Dict
|
16
|
+
from typing import Callable, Any, List, Tuple, Dict, Generator
|
17
17
|
|
18
18
|
import numpy as np
|
19
19
|
|
@@ -46,7 +46,7 @@ class FrameworkImplementation(ABC):
|
|
46
46
|
Returns: Module of the framework constants.
|
47
47
|
|
48
48
|
"""
|
49
|
-
raise
|
49
|
+
raise NotImplementedError(f'{self.__class__.__name__} did not supply a constants module.') # pragma: no cover
|
50
50
|
|
51
51
|
@abstractmethod
|
52
52
|
def get_hessian_scores_calculator(self,
|
@@ -64,7 +64,7 @@ class FrameworkImplementation(ABC):
|
|
64
64
|
|
65
65
|
Returns: HessianScoresCalculator to use for the hessian approximation scores computation for this request.
|
66
66
|
"""
|
67
|
-
raise
|
67
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
68
68
|
f'framework\'s get_hessian_scores_calculator method.') # pragma: no cover
|
69
69
|
|
70
70
|
@abstractmethod
|
@@ -77,7 +77,7 @@ class FrameworkImplementation(ABC):
|
|
77
77
|
Returns:
|
78
78
|
Numpy array converted from the input tensor.
|
79
79
|
"""
|
80
|
-
raise
|
80
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
81
81
|
f'framework\'s to_numpy method.') # pragma: no cover
|
82
82
|
|
83
83
|
@abstractmethod
|
@@ -90,7 +90,7 @@ class FrameworkImplementation(ABC):
|
|
90
90
|
Returns:
|
91
91
|
Framework's tensor converted from the input Numpy array.
|
92
92
|
"""
|
93
|
-
raise
|
93
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
94
94
|
f'framework\'s to_tensor method.') # pragma: no cover
|
95
95
|
|
96
96
|
@abstractmethod
|
@@ -106,7 +106,7 @@ class FrameworkImplementation(ABC):
|
|
106
106
|
Returns:
|
107
107
|
Graph representing the input model.
|
108
108
|
"""
|
109
|
-
raise
|
109
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
110
110
|
f'framework\'s model_reader method.') # pragma: no cover
|
111
111
|
|
112
112
|
@abstractmethod
|
@@ -131,7 +131,7 @@ class FrameworkImplementation(ABC):
|
|
131
131
|
Returns:
|
132
132
|
A tuple with the model and additional relevant supporting objects.
|
133
133
|
"""
|
134
|
-
raise
|
134
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
135
135
|
f'framework\'s model_builder method.') # pragma: no cover
|
136
136
|
|
137
137
|
@abstractmethod
|
@@ -148,7 +148,7 @@ class FrameworkImplementation(ABC):
|
|
148
148
|
Returns:
|
149
149
|
The frameworks model's output.
|
150
150
|
"""
|
151
|
-
raise
|
151
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
152
152
|
f'framework\'s run_model_inference method.') # pragma: no cover
|
153
153
|
|
154
154
|
@abstractmethod
|
@@ -167,7 +167,7 @@ class FrameworkImplementation(ABC):
|
|
167
167
|
Returns:
|
168
168
|
Graph after SNC.
|
169
169
|
"""
|
170
|
-
raise
|
170
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
171
171
|
f'framework\'s apply_shift_negative_correction method.') # pragma: no cover
|
172
172
|
|
173
173
|
@abstractmethod
|
@@ -184,7 +184,7 @@ class FrameworkImplementation(ABC):
|
|
184
184
|
Returns:
|
185
185
|
A list of the framework substitutions used after we collect statistics.
|
186
186
|
"""
|
187
|
-
raise
|
187
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
188
188
|
f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover
|
189
189
|
|
190
190
|
@abstractmethod
|
@@ -194,7 +194,7 @@ class FrameworkImplementation(ABC):
|
|
194
194
|
Returns: A list of the framework substitutions used to prepare the graph.
|
195
195
|
|
196
196
|
"""
|
197
|
-
raise
|
197
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
198
198
|
f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover
|
199
199
|
|
200
200
|
@abstractmethod
|
@@ -208,7 +208,7 @@ class FrameworkImplementation(ABC):
|
|
208
208
|
Returns: A list of the framework substitutions used before we collect statistics.
|
209
209
|
|
210
210
|
"""
|
211
|
-
raise
|
211
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
212
212
|
f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover
|
213
213
|
|
214
214
|
@abstractmethod
|
@@ -216,7 +216,7 @@ class FrameworkImplementation(ABC):
|
|
216
216
|
"""
|
217
217
|
Returns: linear collapsing substitution
|
218
218
|
"""
|
219
|
-
raise
|
219
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
220
220
|
f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover
|
221
221
|
|
222
222
|
@abstractmethod
|
@@ -224,7 +224,7 @@ class FrameworkImplementation(ABC):
|
|
224
224
|
"""
|
225
225
|
Returns: conv2d add const collapsing substitution
|
226
226
|
"""
|
227
|
-
raise
|
227
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
228
228
|
f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover
|
229
229
|
|
230
230
|
@abstractmethod
|
@@ -239,7 +239,7 @@ class FrameworkImplementation(ABC):
|
|
239
239
|
Returns:
|
240
240
|
A list of the framework substitutions used for statistics correction.
|
241
241
|
"""
|
242
|
-
raise
|
242
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
243
243
|
f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover
|
244
244
|
|
245
245
|
@abstractmethod
|
@@ -247,7 +247,7 @@ class FrameworkImplementation(ABC):
|
|
247
247
|
"""
|
248
248
|
Returns: A list of the framework substitutions used for residual collapsing
|
249
249
|
"""
|
250
|
-
raise
|
250
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
251
251
|
f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover
|
252
252
|
|
253
253
|
|
@@ -263,7 +263,7 @@ class FrameworkImplementation(ABC):
|
|
263
263
|
Returns:
|
264
264
|
A list of the framework substitutions used after we collect statistics.
|
265
265
|
"""
|
266
|
-
raise
|
266
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
267
267
|
f'framework\'s get_substitutions_post_statistics_collection method.') # pragma: no cover
|
268
268
|
|
269
269
|
@abstractmethod
|
@@ -272,7 +272,7 @@ class FrameworkImplementation(ABC):
|
|
272
272
|
Returns: A list of Keras substitutions used to build a virtual graph with composed activation-weights pairs.
|
273
273
|
"""
|
274
274
|
|
275
|
-
raise
|
275
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
276
276
|
f'framework\'s get_substitutions_virtual_weights_activation_coupling '
|
277
277
|
f'method.') # pragma: no cover
|
278
278
|
|
@@ -288,7 +288,7 @@ class FrameworkImplementation(ABC):
|
|
288
288
|
Returns:
|
289
289
|
A list of the framework substitutions used after we apply second moment statistics.
|
290
290
|
"""
|
291
|
-
raise
|
291
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
292
292
|
f'framework\'s get_substitutions_after_second_moment_correction '
|
293
293
|
f'method.') # pragma: no cover
|
294
294
|
|
@@ -316,7 +316,7 @@ class FrameworkImplementation(ABC):
|
|
316
316
|
A function that computes the metric.
|
317
317
|
"""
|
318
318
|
|
319
|
-
raise
|
319
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
320
320
|
f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover
|
321
321
|
|
322
322
|
def get_node_prior_info(self, node: BaseNode,
|
@@ -334,7 +334,7 @@ class FrameworkImplementation(ABC):
|
|
334
334
|
NodePriorInfo with information about the node.
|
335
335
|
"""
|
336
336
|
|
337
|
-
raise
|
337
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
338
338
|
f'framework\'s get_node_prior_info method.') # pragma: no cover
|
339
339
|
|
340
340
|
def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
|
@@ -345,7 +345,7 @@ class FrameworkImplementation(ABC):
|
|
345
345
|
Returns: True if the node should be considered an interest point, False otherwise.
|
346
346
|
"""
|
347
347
|
|
348
|
-
raise
|
348
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
349
349
|
f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover
|
350
350
|
|
351
351
|
def get_mp_node_distance_fn(self, n: BaseNode,
|
@@ -364,7 +364,7 @@ class FrameworkImplementation(ABC):
|
|
364
364
|
Returns: A distance function between two tensors and a axis on which the distance is computed (if exists).
|
365
365
|
"""
|
366
366
|
|
367
|
-
raise
|
367
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
368
368
|
f'framework\'s get_mp_node_distance_fn method.') # pragma: no cover
|
369
369
|
|
370
370
|
|
@@ -381,7 +381,7 @@ class FrameworkImplementation(ABC):
|
|
381
381
|
|
382
382
|
"""
|
383
383
|
|
384
|
-
raise
|
384
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
385
385
|
f'framework\'s is_output_node_compatible_for_hessian_score_computation method.') # pragma: no cover
|
386
386
|
|
387
387
|
@abstractmethod
|
@@ -398,7 +398,7 @@ class FrameworkImplementation(ABC):
|
|
398
398
|
Returns: The MAC count of the operation
|
399
399
|
"""
|
400
400
|
|
401
|
-
raise
|
401
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
402
402
|
f'framework\'s get_node_mac_operations method.') # pragma: no cover
|
403
403
|
|
404
404
|
@abstractmethod
|
@@ -419,7 +419,7 @@ class FrameworkImplementation(ABC):
|
|
419
419
|
Returns:
|
420
420
|
A Graph after second moment correction.
|
421
421
|
"""
|
422
|
-
raise
|
422
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
423
423
|
f'framework\'s apply_second_moment_correction method.') # pragma: no cover
|
424
424
|
|
425
425
|
@abstractmethod
|
@@ -436,7 +436,7 @@ class FrameworkImplementation(ABC):
|
|
436
436
|
Returns:
|
437
437
|
The output of the model inference on the given input.
|
438
438
|
"""
|
439
|
-
raise
|
439
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
440
440
|
f'framework\'s sensitivity_eval_inference method.') # pragma: no cover
|
441
441
|
|
442
442
|
def get_inferable_quantizers(self, node: BaseNode):
|
@@ -452,5 +452,19 @@ class FrameworkImplementation(ABC):
|
|
452
452
|
|
453
453
|
"""
|
454
454
|
|
455
|
-
raise
|
456
|
-
f'framework\'s get_inferable_quantizers method.') # pragma: no cover
|
455
|
+
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
|
456
|
+
f'framework\'s get_inferable_quantizers method.') # pragma: no cover
|
457
|
+
|
458
|
+
@staticmethod
|
459
|
+
def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
|
460
|
+
"""
|
461
|
+
Create DataLoader based on samples yielded by data_gen.
|
462
|
+
|
463
|
+
Args:
|
464
|
+
data_gen_fn: data generator factory.
|
465
|
+
batch_size: target batch size.
|
466
|
+
|
467
|
+
Returns:
|
468
|
+
Framework dataloader.
|
469
|
+
"""
|
470
|
+
raise NotImplementedError() # pragma: no cover
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
from model_compression_toolkit.core.common.hessian.hessian_scores_request import (
|
16
|
-
HessianScoresRequest, HessianMode, HessianScoresGranularity
|
16
|
+
HessianScoresRequest, HessianMode, HessianScoresGranularity
|
17
17
|
)
|
18
18
|
from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
|
19
19
|
import model_compression_toolkit.core.common.hessian.hessian_info_utils as hessian_utils
|