mct-nightly 2.1.0.20240708.453__py3-none-any.whl → 2.1.0.20240710.440__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240708.453.dist-info → mct_nightly-2.1.0.20240710.440.dist-info}/METADATA +1 -1
- {mct_nightly-2.1.0.20240708.453.dist-info → mct_nightly-2.1.0.20240710.440.dist-info}/RECORD +31 -31
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +12 -12
- model_compression_toolkit/core/common/hessian/__init__.py +1 -1
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +74 -69
- model_compression_toolkit/core/common/hessian/hessian_info_utils.py +1 -1
- model_compression_toolkit/core/common/hessian/{trace_hessian_calculator.py → hessian_scores_calculator.py} +11 -11
- model_compression_toolkit/core/common/hessian/{trace_hessian_request.py → hessian_scores_request.py} +15 -15
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +8 -8
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -5
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +4 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -5
- model_compression_toolkit/core/keras/hessian/{activation_trace_hessian_calculator_keras.py → activation_hessian_scores_calculator_keras.py} +26 -26
- model_compression_toolkit/core/keras/hessian/{trace_hessian_calculator_keras.py → hessian_scores_calculator_keras.py} +14 -14
- model_compression_toolkit/core/keras/hessian/{weights_trace_hessian_calculator_keras.py → weights_hessian_scores_calculator_keras.py} +27 -27
- model_compression_toolkit/core/keras/keras_implementation.py +30 -30
- model_compression_toolkit/core/pytorch/hessian/{activation_trace_hessian_calculator_pytorch.py → activation_hessian_scores_calculator_pytorch.py} +25 -25
- model_compression_toolkit/core/pytorch/hessian/{trace_hessian_calculator_pytorch.py → hessian_scores_calculator_pytorch.py} +14 -14
- model_compression_toolkit/core/pytorch/hessian/{weights_trace_hessian_calculator_pytorch.py → weights_hessian_scores_calculator_pytorch.py} +25 -25
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +30 -30
- model_compression_toolkit/core/quantization_prep_runner.py +1 -1
- model_compression_toolkit/gptq/common/gptq_training.py +30 -30
- model_compression_toolkit/gptq/keras/gptq_training.py +1 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
- model_compression_toolkit/gptq/runner.py +2 -2
- model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
- {mct_nightly-2.1.0.20240708.453.dist-info → mct_nightly-2.1.0.20240710.440.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240708.453.dist-info → mct_nightly-2.1.0.20240710.440.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240708.453.dist-info → mct_nightly-2.1.0.20240710.440.dist-info}/top_level.txt +0 -0
{mct_nightly-2.1.0.20240708.453.dist-info → mct_nightly-2.1.0.20240710.440.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=3GQ9gfo5RfzoWI0wjlyOHA6G4dz31vXoA1cBJdGk3Mw,1573
|
2
2
|
model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
@@ -6,11 +6,11 @@ model_compression_toolkit/metadata.py,sha256=IyoON37lBv3TI0rZGCP4K5t3oYI4TOmYy-L
|
|
6
6
|
model_compression_toolkit/core/__init__.py,sha256=TrRgkWpT1AN2Faw1M_1HXyJkJnbxfn9p-RigDZl7pg0,1982
|
7
7
|
model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
|
8
8
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=kM70wmNG3yMFiGQc0uO0wn9j4ZbSWxUEykpxDK55doc,10567
|
9
|
-
model_compression_toolkit/core/quantization_prep_runner.py,sha256=
|
9
|
+
model_compression_toolkit/core/quantization_prep_runner.py,sha256=K9eJ7VbB_rpeyxX4yEnorOmSxFW3DkvofzxS6QI8Hp8,6454
|
10
10
|
model_compression_toolkit/core/runner.py,sha256=4TtOgyNb4cXr52dOlDqYxLm3rnLR6uHPDNoZiEFL9XA,12655
|
11
11
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
12
12
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
13
|
-
model_compression_toolkit/core/common/framework_implementation.py,sha256=
|
13
|
+
model_compression_toolkit/core/common/framework_implementation.py,sha256=zg2Sznd-uqpN8hy9C22jah_27qPYXNajYx3xMR3AlMw,20938
|
14
14
|
model_compression_toolkit/core/common/framework_info.py,sha256=1ZMMGS9ip-kSflqkartyNRt9aQ5ub1WepuTRcTy-YSQ,6337
|
15
15
|
model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
|
16
16
|
model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
|
@@ -44,11 +44,11 @@ model_compression_toolkit/core/common/graph/memory_graph/cut.py,sha256=aPdXJPP5a
|
|
44
44
|
model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256=crV2NCLVO8jx9MlryZBYuJKFe_G9HfM7rUR64fDymlw,17045
|
45
45
|
model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=gRmBEFRmyJsNKezQfiwDwQu1cmbGd2wgKCRTH6iw8mw,3961
|
46
46
|
model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=gw4av_rzn_3oEAPpD3B7PHZDqnxHMjIESevl6ppPnkk,7175
|
47
|
-
model_compression_toolkit/core/common/hessian/__init__.py,sha256=
|
48
|
-
model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=
|
49
|
-
model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=
|
50
|
-
model_compression_toolkit/core/common/hessian/
|
51
|
-
model_compression_toolkit/core/common/hessian/
|
47
|
+
model_compression_toolkit/core/common/hessian/__init__.py,sha256=6216QgHl7h4DXGn5ForP9Tija-wrBSONNtQ769ikP2s,1025
|
48
|
+
model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=DHbZqFDuDir1QWN-YkYBzaoGDujgYam1hT2ea6uL3yM,21009
|
49
|
+
model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=1axmN0tjJSo_7hUr2d2KMv4y1pBi19cqWSQpi4BbdsA,1458
|
50
|
+
model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py,sha256=Pe4uKerx-MeDQPJ7Slr8fvFUHfv02q33w3gbQK5kBKs,4186
|
51
|
+
model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=atGJgJBL9uwYRC3t9NnzGgHYxV4XJj4Ai_xPpQH0rhY,3229
|
52
52
|
model_compression_toolkit/core/common/matchers/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
53
53
|
model_compression_toolkit/core/common/matchers/base_graph_filter.py,sha256=mTk54z0mIbFmPOb4h0xfLtLDookcFyNh8H0pIN5js_M,3091
|
54
54
|
model_compression_toolkit/core/common/matchers/base_matcher.py,sha256=JCj-NLAXOJa-GcSX-94PVUTWjooQUd0NemiyNg5uKGQ,2210
|
@@ -62,9 +62,9 @@ model_compression_toolkit/core/common/mixed_precision/configurable_quant_id.py,s
|
|
62
62
|
model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py,sha256=7dKMi5S0zQZ16m8NWn1XIuoXsKuZUg64G4-uK8-j1PQ,5177
|
63
63
|
model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=H8qYkJsk88OszUJo-Zde7vTmWiypLTg9KbbzIZ-hhvM,2812
|
64
64
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=rppRZJdSCQGiZsd93QxoUIhj51eETvQbuI5JiC2TUeA,4963
|
65
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=
|
65
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=pk8HRoShDhiUprBC4m1AFQv1SacS4hOrj0MRdbq-5gY,7556
|
66
66
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=TTTux4YiOnQqt-2h7Y38959XaDwNZc0eufLMx_yws5U,37578
|
67
|
-
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=
|
67
|
+
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=PNGWP0bPkQNP1jCORP6FQgGIr616Kg1YSe_hy-BDg0I,27546
|
68
68
|
model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
|
69
69
|
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=KifDMbm7qkSfvSl6pcZzQ82naIXzeKL6aT-VsvWZYyc,7901
|
70
70
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
@@ -92,7 +92,7 @@ model_compression_toolkit/core/common/pruning/pruning_section.py,sha256=I4vxh5iP
|
|
92
92
|
model_compression_toolkit/core/common/pruning/importance_metrics/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
|
93
93
|
model_compression_toolkit/core/common/pruning/importance_metrics/base_importance_metric.py,sha256=qMAtLWs5fjbSco8nhbig5TkuacdhnDW7cy3avMHRGX4,1988
|
94
94
|
model_compression_toolkit/core/common/pruning/importance_metrics/importance_metric_factory.py,sha256=E-fKuRfrNYlN3nNcRAbnkJkFNwClvyrL_Js1qDPxIKA,1999
|
95
|
-
model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py,sha256=
|
95
|
+
model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py,sha256=sJGsR8vTs43SV7Wfz929pCwLM-_7aXUyO5nBUig9K9s,14055
|
96
96
|
model_compression_toolkit/core/common/pruning/mask/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
97
97
|
model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=APY8BsM9B7ZxVCH6n1xs9fSCTB_A9ou9gHrCQl1DOdI,5131
|
98
98
|
model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=4ohJrJHNzZk5uMnZEYkwLx2TDGzkh5kRhLGNVYNC6dc,5978
|
@@ -109,12 +109,12 @@ model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha
|
|
109
109
|
model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
|
110
110
|
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=O4qFJw3nBYUD4cGbO8haGXZ2-piSqoRpDKDD74iXSxw,12417
|
111
111
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
|
112
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=
|
112
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=Fd_gxr5js-mqEwucaRR1CQAZ1W_wna19L1gAPeOzxRQ,23610
|
113
113
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=t0XSwjfOxcq2Sj2PGzccntz1GGv2eqVn9oR3OI0t9wo,8533
|
114
114
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
|
115
115
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=HfnhQ4MxGpb95gOWXD1vnroTxxjFt9VFd4jIdo-rvAQ,10623
|
116
116
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=noEdvGiyyW7acgQ2OFWLedCODibTGYJifC9qo8YIU5U,4558
|
117
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=
|
117
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=oME8T6Slgl1SJNpXV4oY3UhuX0YmKYbcWDsLiCYq7oE,8651
|
118
118
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=o2XNY_0pUUyId02TUVQBtkux_i40NCcnzuobSeQLy3E,42863
|
119
119
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=zSNda0jN8cP41m6g5TOv5WvATwIhV8z6AVM1Es6rq1s,4419
|
120
120
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=4TP41wPYC0azIzFxUt-lNlKUPIIXQeE4H1SYHkON75k,11875
|
@@ -150,7 +150,7 @@ model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7V
|
|
150
150
|
model_compression_toolkit/core/keras/constants.py,sha256=Uv3c0UdW55pIVQNW_1HQlgl-dHXREkltOLyzp8G1mTQ,3163
|
151
151
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
152
152
|
model_compression_toolkit/core/keras/default_framework_info.py,sha256=HcHplb7IcnOTyK2p6uhp3OVG4-RV3RDo9C_4evaIzkQ,4981
|
153
|
-
model_compression_toolkit/core/keras/keras_implementation.py,sha256=
|
153
|
+
model_compression_toolkit/core/keras/keras_implementation.py,sha256=qrmFzs749gVCEanw-paEsbM50zECtMV8qDMxKFTmyj0,29906
|
154
154
|
model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
|
155
155
|
model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
|
156
156
|
model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=Xmk2ZL5CaYdb7iG62HdtZ1F64vap7ffnrsuR3e3G5hc,4851
|
@@ -184,9 +184,9 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/softmax_s
|
|
184
184
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=wH9ocMLL725-uUPU-zCxdd8NwT5nyd0ZShmI7iuTwF8,1462
|
185
185
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/weights_activation_split.py,sha256=rjIheZW7LbSPv9bzMSmC8wl6UUxaTkd4J2IHinObT-Y,1814
|
186
186
|
model_compression_toolkit/core/keras/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
|
187
|
-
model_compression_toolkit/core/keras/hessian/
|
188
|
-
model_compression_toolkit/core/keras/hessian/
|
189
|
-
model_compression_toolkit/core/keras/hessian/
|
187
|
+
model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculator_keras.py,sha256=p0eM-EO5ltXYjSkd7B3h9BWBcuRZvjxEcA8WaNvdyqc,8901
|
188
|
+
model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py,sha256=Cep-bQEwLyqLYfLxM0ByOQd_oAIT-uXjr3dFUd8T9CY,3954
|
189
|
+
model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py,sha256=970C-8J4HtUalNWvZAKlWFZVfw5r6SBdt5RQU_mZ7M0,12261
|
190
190
|
model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
191
191
|
model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=aW8wR13fK6P6xzbU9XGU60IO1yYzXSo_Hk4qeq486kg,5137
|
192
192
|
model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=Ziydik2j-LvNBXP3TSfUD6rEezPAikzQGib0_IXkmGM,6729
|
@@ -213,7 +213,7 @@ model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKW
|
|
213
213
|
model_compression_toolkit/core/pytorch/constants.py,sha256=AguUnAsNlj41gwuKIP_7nos3FcJHsIAjewLXSQdrDQM,2624
|
214
214
|
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
|
215
215
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
|
216
|
-
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=
|
216
|
+
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=XL_RZcfnb_ZY2jdCjOxxz7SbRBzMokbOWsTuYOSjyRU,27569
|
217
217
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
|
218
218
|
model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=E6ifk1HdO60k4IRH2EFBzAYWtwUlrGqJoQ66nknpHoQ,4983
|
219
219
|
model_compression_toolkit/core/pytorch/utils.py,sha256=OT_mrNEJqPgWLdtQuivKMQVjtJY49cmoIVvbRhANl1w,3004
|
@@ -249,9 +249,9 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transfo
|
|
249
249
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=WmEa8Xjji-_tIbthDxlLAGSr69nWk-YKcHNaVqLa7sg,1375
|
250
250
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/weights_activation_split.py,sha256=tp78axmUQc0Zpj3KwVmV0PGYHvCf7sAW_sRmXXw7gsY,1616
|
251
251
|
model_compression_toolkit/core/pytorch/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
|
252
|
-
model_compression_toolkit/core/pytorch/hessian/
|
253
|
-
model_compression_toolkit/core/pytorch/hessian/
|
254
|
-
model_compression_toolkit/core/pytorch/hessian/
|
252
|
+
model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py,sha256=xc_-utc9_Hq915X02VbT8zXxGqxE4fFz6dhiiZwU3ok,8578
|
253
|
+
model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py,sha256=vXluX-awgavv7DGihG9HrlvLhak8qIHy837PPTOd4jg,3471
|
254
|
+
model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py,sha256=C4-7naBQUh8TN6fEwkyKY6rlY_nvHSAmCnWT4iMBs8E,8497
|
255
255
|
model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
256
256
|
model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=-6oep2WJ85-JmIxZa-e2AmBpbORoKe4Xdduz2ZidwvM,4871
|
257
257
|
model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py,sha256=KVZTKCYzJqqzF5nFEiuGMv_sNeVuBTxhmxWMFacKOxE,6337
|
@@ -332,17 +332,17 @@ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py,sha
|
|
332
332
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=YT9IVdpKaJbAW3msYRoQNIgqRSEVwSarRy6qlWCrBfk,5389
|
333
333
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=4sN5z-6BXrTE5Dp2FX_jKO9ty5iZ2r4RM7XvXtDVLSI,9348
|
334
334
|
model_compression_toolkit/gptq/__init__.py,sha256=YKg-tMj9D4Yd0xW9VRD5EN1J5JrmlRbNEF2fOSgodqA,1228
|
335
|
-
model_compression_toolkit/gptq/runner.py,sha256=
|
335
|
+
model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
|
336
336
|
model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
337
337
|
model_compression_toolkit/gptq/common/gptq_config.py,sha256=U-NiVEedkOsVaFq-iXU2Xcqp99Rgf0f2I3oANdVMhMY,5672
|
338
338
|
model_compression_toolkit/gptq/common/gptq_constants.py,sha256=QSm6laLkIV0LYmU0BLtmKp3Fi3SqDfbncFQWOGA1cGU,611
|
339
339
|
model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
|
340
340
|
model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
|
341
|
-
model_compression_toolkit/gptq/common/gptq_training.py,sha256=
|
341
|
+
model_compression_toolkit/gptq/common/gptq_training.py,sha256=nI_XVa7WbfCcbgHMFgnPtnD77m5ezAB306z7VE0XFvU,16527
|
342
342
|
model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
343
343
|
model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
|
344
344
|
model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
|
345
|
-
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=
|
345
|
+
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=NXTNsVrO9DTh0uvc8V7rFaM0fYg2OA18ZrYd-cKZ7Z4,19159
|
346
346
|
model_compression_toolkit/gptq/keras/graph_info.py,sha256=MKIfrRTRH3zCuxCR1g9ZVIFyuSSr0e0sDybqh4LDM7E,4672
|
347
347
|
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=SjmBTuSwki4JTPVhxvJMFK9uAsmEm2c6VV11NnM6eEo,15117
|
348
348
|
model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
|
@@ -359,7 +359,7 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
|
|
359
359
|
model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
360
360
|
model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa94rpoodf82DksgjQCnL7nBc,2719
|
361
361
|
model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
|
362
|
-
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256
|
362
|
+
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=-daninmlPGfKsBNPB2C3gT6rK0G5YeyJsuOLA0JlfBU,16633
|
363
363
|
model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
|
364
364
|
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=jcvRKBuMkrerNE8oWIJFp802pyFO0dnA-4hRnclKbWE,13569
|
365
365
|
model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
|
@@ -400,7 +400,7 @@ model_compression_toolkit/qat/keras/quantizer/ste_rounding/__init__.py,sha256=cc
|
|
400
400
|
model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=I4KlaGv17k71IyjuSG9M0OlXlD5P0pfvKa6oCyRQ5FE,13517
|
401
401
|
model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py,sha256=EED6LfqhX_OhDRJ9e4GwbpgNC9vq7hoXyJS2VPvG2qc,10789
|
402
402
|
model_compression_toolkit/qat/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
403
|
-
model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=
|
403
|
+
model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=pRy2B5OsaLi33p4hozjr0rzAooT8Gic3_qxTl66J900,13375
|
404
404
|
model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=xYa4C8pr9cG1f3mQQcBXO_u3IdJN-zl7leZxuXDs86w,1003
|
405
405
|
model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py,sha256=soxP4moJxQaziq5ccP3Du34fSIVSFyZq6hD8YuaDn88,2187
|
406
406
|
model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py,sha256=sFWGu76PZ9dSRf3L0uZI6YwLIs0biBND1tl76I1piBQ,5721
|
@@ -517,8 +517,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
517
517
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=yrZNVRm2IRU7r7R-hjS2lOQ6wvEEvbeunvf2jKoWjXk,3277
|
518
518
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
519
519
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=eyMoXt5o5EnMr6d-rpCwQdX5mAiYiymvbgKv4tf7-a0,4576
|
520
|
-
mct_nightly-2.1.0.
|
521
|
-
mct_nightly-2.1.0.
|
522
|
-
mct_nightly-2.1.0.
|
523
|
-
mct_nightly-2.1.0.
|
524
|
-
mct_nightly-2.1.0.
|
520
|
+
mct_nightly-2.1.0.20240710.440.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
521
|
+
mct_nightly-2.1.0.20240710.440.dist-info/METADATA,sha256=MHsFULGCUhQiOAsGEFxkud6ktNEV5C_KQSlzCVbHCSE,19719
|
522
|
+
mct_nightly-2.1.0.20240710.440.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
523
|
+
mct_nightly-2.1.0.20240710.440.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
524
|
+
mct_nightly-2.1.0.20240710.440.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.1.0.
|
30
|
+
__version__ = "2.1.0.20240710.000440"
|
@@ -24,7 +24,7 @@ from model_compression_toolkit.core.common import BaseNode
|
|
24
24
|
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
|
25
25
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
26
26
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
27
|
-
from model_compression_toolkit.core.common.hessian import
|
27
|
+
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianInfoService
|
28
28
|
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
|
29
29
|
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
|
30
30
|
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
|
@@ -49,23 +49,23 @@ class FrameworkImplementation(ABC):
|
|
49
49
|
raise NotImplemented(f'{self.__class__.__name__} did not supply a constants module.') # pragma: no cover
|
50
50
|
|
51
51
|
@abstractmethod
|
52
|
-
def
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
52
|
+
def get_hessian_scores_calculator(self,
|
53
|
+
graph: Graph,
|
54
|
+
input_images: List[Any],
|
55
|
+
hessian_scores_request: HessianScoresRequest,
|
56
|
+
num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
|
57
57
|
"""
|
58
|
-
Get framework
|
58
|
+
Get framework hessian-approximation scores calculator based on the hessian scores request.
|
59
59
|
Args:
|
60
60
|
input_images: Images to use for computation.
|
61
61
|
graph: Float graph to compute the approximation of its different nodes.
|
62
|
-
|
63
|
-
num_iterations_for_approximation: Number of iterations to use when approximating the Hessian
|
62
|
+
hessian_scores_request: HessianScoresRequest to search for the desired calculator.
|
63
|
+
num_iterations_for_approximation: Number of iterations to use when approximating the Hessian-approximation scores.
|
64
64
|
|
65
|
-
Returns:
|
65
|
+
Returns: HessianScoresCalculator to use for the hessian approximation scores computation for this request.
|
66
66
|
"""
|
67
67
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
68
|
-
f'framework\'s
|
68
|
+
f'framework\'s get_hessian_scores_calculator method.') # pragma: no cover
|
69
69
|
|
70
70
|
@abstractmethod
|
71
71
|
def to_numpy(self, tensor: Any) -> np.ndarray:
|
@@ -310,7 +310,7 @@ class FrameworkImplementation(ABC):
|
|
310
310
|
representative_data_gen: Dataset to use for retrieving images for the models inputs.
|
311
311
|
fw_info: FrameworkInfo object with information about the specific framework's model.
|
312
312
|
disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
|
313
|
-
hessian_info_service: HessianInfoService to fetch
|
313
|
+
hessian_info_service: HessianInfoService to fetch information based on Hessian-approximation.
|
314
314
|
|
315
315
|
Returns:
|
316
316
|
A function that computes the metric.
|
@@ -12,6 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from model_compression_toolkit.core.common.hessian.
|
15
|
+
from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, HessianMode, HessianScoresGranularity
|
16
16
|
from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
|
17
17
|
import model_compression_toolkit.core.common.hessian.hessian_info_utils as hessian_utils
|
@@ -19,24 +19,24 @@ from tqdm import tqdm
|
|
19
19
|
from typing import Callable, List, Dict, Any, Tuple
|
20
20
|
|
21
21
|
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
|
22
|
-
from model_compression_toolkit.core.common.hessian.
|
23
|
-
|
22
|
+
from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, \
|
23
|
+
HessianScoresGranularity, HessianMode
|
24
24
|
from model_compression_toolkit.logger import Logger
|
25
25
|
|
26
26
|
|
27
27
|
class HessianInfoService:
|
28
28
|
"""
|
29
|
-
A service to manage, store, and compute
|
29
|
+
A service to manage, store, and compute information based on the Hessian matrix approximation.
|
30
30
|
|
31
|
-
This class provides functionalities to compute
|
32
|
-
on the different parameters (such as number of iterations for approximating the
|
31
|
+
This class provides functionalities to compute information based on the Hessian matrix approximation
|
32
|
+
based on the different parameters (such as number of iterations for approximating the scores)
|
33
33
|
and input images (using representative_dataset_gen).
|
34
34
|
It also offers cache management capabilities for efficient computation and retrieval.
|
35
35
|
|
36
36
|
Note:
|
37
37
|
- The Hessian provides valuable information about the curvature of the loss function.
|
38
38
|
- Computation can be computationally heavy and time-consuming.
|
39
|
-
- The computed
|
39
|
+
- The computed information is based on Hessian approximation (and not the precise Hessian matrix).
|
40
40
|
"""
|
41
41
|
|
42
42
|
def __init__(self,
|
@@ -49,7 +49,7 @@ class HessianInfoService:
|
|
49
49
|
Args:
|
50
50
|
graph: Float graph.
|
51
51
|
representative_dataset_gen: A callable that provides a dataset for sampling.
|
52
|
-
fw_impl: Framework-specific implementation for
|
52
|
+
fw_impl: Framework-specific implementation for Hessian approximation scores computation.
|
53
53
|
"""
|
54
54
|
self.graph = graph
|
55
55
|
|
@@ -58,7 +58,7 @@ class HessianInfoService:
|
|
58
58
|
self.fw_impl = fw_impl
|
59
59
|
self.num_iterations_for_approximation = num_iterations_for_approximation
|
60
60
|
|
61
|
-
self.
|
61
|
+
self.hessian_scores_request_to_scores_list = {}
|
62
62
|
|
63
63
|
def _sample_batch_representative_dataset(self,
|
64
64
|
representative_dataset: Any,
|
@@ -142,11 +142,11 @@ class HessianInfoService:
|
|
142
142
|
|
143
143
|
def _clear_saved_hessian_info(self):
|
144
144
|
"""Clears the saved info approximations."""
|
145
|
-
self.
|
145
|
+
self.hessian_scores_request_to_scores_list={}
|
146
146
|
|
147
|
-
def
|
147
|
+
def count_saved_scores_of_request(self, hessian_request: HessianScoresRequest) -> Dict:
|
148
148
|
"""
|
149
|
-
Counts the saved approximations of Hessian
|
149
|
+
Counts the saved approximations of Hessian scores for a specific request.
|
150
150
|
If some approximations were computed for this request before, the amount of approximations (per image)
|
151
151
|
will be returned. If not, zero is returned.
|
152
152
|
|
@@ -166,55 +166,58 @@ class HessianInfoService:
|
|
166
166
|
Logger.critical(f"Expecting the Hessian request to include only non-reused nodes at this point, "
|
167
167
|
f"but found node {n.name} with 'reuse' status.")
|
168
168
|
# Check if the request for this node is in the saved info and store its count, otherwise store 0
|
169
|
-
per_node_counter[n] = len(self.
|
169
|
+
per_node_counter[n] = len(self.hessian_scores_request_to_scores_list.get(hessian_request, []))
|
170
170
|
|
171
171
|
return per_node_counter
|
172
172
|
|
173
|
-
def compute(self,
|
173
|
+
def compute(self,
|
174
|
+
hessian_scores_request: HessianScoresRequest,
|
175
|
+
representative_dataset_gen,
|
176
|
+
num_hessian_samples: int,
|
174
177
|
last_iter_remain_samples: List[List[np.ndarray]] = None):
|
175
178
|
"""
|
176
|
-
Computes
|
179
|
+
Computes scores based on the Hessian matrix approximation according to the
|
177
180
|
provided request configuration and stores it in the cache.
|
178
181
|
|
179
182
|
Args:
|
180
|
-
|
183
|
+
hessian_scores_request: Configuration for which to compute the approximation.
|
181
184
|
representative_dataset_gen: A callable that provides a dataset for sampling.
|
182
185
|
num_hessian_samples: Number of requested samples to compute batch Hessian approximation scores.
|
183
186
|
last_iter_remain_samples: A list of input samples (for each input layer) with remaining samples from
|
184
187
|
previous iterations.
|
185
188
|
"""
|
186
|
-
Logger.debug(f"Computing Hessian-
|
189
|
+
Logger.debug(f"Computing Hessian-scores approximations for nodes {hessian_scores_request.target_nodes}.")
|
187
190
|
|
188
191
|
images, next_iter_remain_samples = representative_dataset_gen(num_hessian_samples=num_hessian_samples,
|
189
192
|
last_iter_remain_samples=last_iter_remain_samples)
|
190
193
|
|
191
194
|
# Compute and store the computed approximation in the saved info
|
192
195
|
topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()]
|
193
|
-
|
196
|
+
hessian_scores_request.target_nodes.sort(key=lambda x: topo_sorted_nodes_names.index(x.name))
|
194
197
|
|
195
|
-
# Get the framework-specific calculator
|
196
|
-
fw_hessian_calculator = self.fw_impl.
|
197
|
-
|
198
|
-
|
199
|
-
|
198
|
+
# Get the framework-specific calculator Hessian-approximation scores
|
199
|
+
fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(graph=self.graph,
|
200
|
+
input_images=images,
|
201
|
+
hessian_scores_request=hessian_scores_request,
|
202
|
+
num_iterations_for_approximation=self.num_iterations_for_approximation)
|
200
203
|
|
201
|
-
|
204
|
+
hessian_scores = fw_hessian_calculator.compute()
|
202
205
|
|
203
|
-
for node, hessian in zip(
|
204
|
-
single_node_request = self._construct_single_node_request(
|
205
|
-
|
206
|
+
for node, hessian in zip(hessian_scores_request.target_nodes, hessian_scores):
|
207
|
+
single_node_request = self._construct_single_node_request(hessian_scores_request.mode,
|
208
|
+
hessian_scores_request.granularity,
|
206
209
|
node)
|
207
210
|
|
208
211
|
# The hessian for each node is expected to be a tensor where the first axis represents the number of
|
209
212
|
# images in the batch on which the approximation was computed.
|
210
213
|
# We collect the results as a list of a result for images, which is combined across batches.
|
211
|
-
# After conversion,
|
214
|
+
# After conversion, hessian_scores_request_to_scores_list for a request of a single node should be a list of
|
212
215
|
# results of all images, where each result is a tensor of the shape depending on the granularity.
|
213
|
-
if single_node_request in self.
|
214
|
-
self.
|
216
|
+
if single_node_request in self.hessian_scores_request_to_scores_list:
|
217
|
+
self.hessian_scores_request_to_scores_list[single_node_request] += (
|
215
218
|
self._convert_tensor_to_list_of_appx_results(hessian))
|
216
219
|
else:
|
217
|
-
self.
|
220
|
+
self.hessian_scores_request_to_scores_list[single_node_request] = (
|
218
221
|
self._convert_tensor_to_list_of_appx_results(hessian))
|
219
222
|
|
220
223
|
# In case that we are required to return a number of scores that is larger that the computation batch size
|
@@ -226,15 +229,15 @@ class HessianInfoService:
|
|
226
229
|
and len(next_iter_remain_samples[0]) > 0 else None
|
227
230
|
|
228
231
|
def fetch_hessian(self,
|
229
|
-
|
232
|
+
hessian_scores_request: HessianScoresRequest,
|
230
233
|
required_size: int,
|
231
234
|
batch_size: int = 1) -> List[List[np.ndarray]]:
|
232
235
|
"""
|
233
|
-
Fetches the computed approximations of the
|
236
|
+
Fetches the computed approximations of the Hessian-based scores for the given
|
234
237
|
request and required size.
|
235
238
|
|
236
239
|
Args:
|
237
|
-
|
240
|
+
hessian_scores_request: Configuration for which to fetch the approximation.
|
238
241
|
required_size: Number of approximations required.
|
239
242
|
batch_size: The Hessian computation batch size.
|
240
243
|
|
@@ -245,28 +248,28 @@ class HessianInfoService:
|
|
245
248
|
OC for per-output-channel when the requested node has OC output-channels, etc.)
|
246
249
|
"""
|
247
250
|
|
248
|
-
if len(
|
251
|
+
if len(hessian_scores_request.target_nodes) == 0:
|
249
252
|
return []
|
250
253
|
|
251
254
|
if required_size == 0:
|
252
|
-
return [[] for _ in
|
255
|
+
return [[] for _ in hessian_scores_request.target_nodes]
|
253
256
|
|
254
|
-
Logger.info(f"\nEnsuring {required_size} Hessian-
|
255
|
-
f"{
|
257
|
+
Logger.info(f"\nEnsuring {required_size} Hessian-approximation scores for nodes "
|
258
|
+
f"{hessian_scores_request.target_nodes}.")
|
256
259
|
|
257
260
|
# Replace node in reused target nodes with a representing node from the 'reuse group'.
|
258
|
-
for n in
|
261
|
+
for n in hessian_scores_request.target_nodes:
|
259
262
|
if n.reuse_group:
|
260
263
|
rep_node = self._get_representing_of_reuse_group(n)
|
261
|
-
|
262
|
-
if rep_node not in
|
263
|
-
|
264
|
+
hessian_scores_request.target_nodes.remove(n)
|
265
|
+
if rep_node not in hessian_scores_request.target_nodes:
|
266
|
+
hessian_scores_request.target_nodes.append(rep_node)
|
264
267
|
|
265
268
|
# Ensure the saved info has the required number of approximations
|
266
|
-
self._populate_saved_info_to_size(
|
269
|
+
self._populate_saved_info_to_size(hessian_scores_request, required_size, batch_size)
|
267
270
|
|
268
271
|
# Return the saved approximations for the given request
|
269
|
-
return self._collect_saved_hessians_for_request(
|
272
|
+
return self._collect_saved_hessians_for_request(hessian_scores_request, required_size)
|
270
273
|
|
271
274
|
def _get_representing_of_reuse_group(self, node) -> Any:
|
272
275
|
"""
|
@@ -286,20 +289,20 @@ class HessianInfoService:
|
|
286
289
|
return father_nodes[0]
|
287
290
|
|
288
291
|
def _populate_saved_info_to_size(self,
|
289
|
-
|
292
|
+
hessian_scores_request: HessianScoresRequest,
|
290
293
|
required_size: int,
|
291
294
|
batch_size: int = 1):
|
292
295
|
"""
|
293
|
-
Ensures that the saved info has the required size of
|
296
|
+
Ensures that the saved info has the required size of Hessian approximation scores for the given request.
|
294
297
|
|
295
298
|
Args:
|
296
|
-
|
297
|
-
required_size: Required number of
|
299
|
+
hessian_scores_request: Configuration of the request to ensure the saved info size.
|
300
|
+
required_size: Required number of Hessian-approximation scores.
|
298
301
|
batch_size: The Hessian computation batch size.
|
299
302
|
"""
|
300
303
|
|
301
304
|
# Get the current number of saved approximations for each node in the request
|
302
|
-
current_existing_hessians = self.
|
305
|
+
current_existing_hessians = self.count_saved_scores_of_request(hessian_scores_request)
|
303
306
|
|
304
307
|
# Compute the required number of approximations to meet the required size.
|
305
308
|
# Since we allow batch and multi-nodes computation, we take the node with the maximal number of missing
|
@@ -308,9 +311,9 @@ class HessianInfoService:
|
|
308
311
|
max_remaining_hessians = required_size - min_exist_hessians
|
309
312
|
|
310
313
|
Logger.info(
|
311
|
-
f"Running Hessian approximation computation for {len(
|
312
|
-
f"The node with minimal existing Hessian-
|
313
|
-
f"
|
314
|
+
f"Running Hessian approximation computation for {len(hessian_scores_request.target_nodes)} nodes.\n "
|
315
|
+
f"The node with minimal existing Hessian-approximation scores has {min_exist_hessians} "
|
316
|
+
f"approximated scores computed.\n"
|
314
317
|
f"{max_remaining_hessians} approximations left to compute...")
|
315
318
|
|
316
319
|
hessian_representative_dataset = partial(self._sample_batch_representative_dataset,
|
@@ -325,30 +328,31 @@ class HessianInfoService:
|
|
325
328
|
pbar.update(1)
|
326
329
|
size_to_compute = min(max_remaining_hessians, batch_size)
|
327
330
|
next_iter_remaining_samples = (
|
328
|
-
self.compute(
|
331
|
+
self.compute(hessian_scores_request, hessian_representative_dataset, size_to_compute,
|
329
332
|
last_iter_remain_samples=next_iter_remaining_samples))
|
330
333
|
max_remaining_hessians -= size_to_compute
|
331
334
|
|
332
|
-
def _collect_saved_hessians_for_request(self,
|
333
|
-
|
335
|
+
def _collect_saved_hessians_for_request(self,
|
336
|
+
hessian_scores_request: HessianScoresRequest,
|
337
|
+
required_size: int) -> List[List[np.ndarray]]:
|
334
338
|
"""
|
335
339
|
Collects Hessian approximation for the nodes in the given request.
|
336
340
|
|
337
341
|
Args:
|
338
|
-
|
339
|
-
required_size: Required number of
|
342
|
+
hessian_scores_request: Configuration for which to fetch the approximation.
|
343
|
+
required_size: Required number of Hessian-approximation scores.
|
340
344
|
|
341
345
|
Returns: A list with List of computed Hessian approximation (a tensor for each score) for each node
|
342
346
|
in the request.
|
343
347
|
|
344
348
|
"""
|
345
349
|
collected_results = []
|
346
|
-
for node in
|
347
|
-
single_node_request = self._construct_single_node_request(
|
348
|
-
|
350
|
+
for node in hessian_scores_request.target_nodes:
|
351
|
+
single_node_request = self._construct_single_node_request(hessian_scores_request.mode,
|
352
|
+
hessian_scores_request.granularity,
|
349
353
|
node)
|
350
354
|
|
351
|
-
res_for_node = self.
|
355
|
+
res_for_node = self.hessian_scores_request_to_scores_list.get(single_node_request)
|
352
356
|
if res_for_node is None: # pragma: no cover
|
353
357
|
Logger.critical(f"Couldn't find saved Hessian approximations for node {node.name}.")
|
354
358
|
if len(res_for_node) < required_size: # pragma: no cover
|
@@ -362,22 +366,23 @@ class HessianInfoService:
|
|
362
366
|
return collected_results
|
363
367
|
|
364
368
|
@staticmethod
|
365
|
-
def _construct_single_node_request(mode: HessianMode,
|
366
|
-
|
369
|
+
def _construct_single_node_request(mode: HessianMode,
|
370
|
+
granularity: HessianScoresGranularity,
|
371
|
+
target_nodes: List) -> HessianScoresRequest:
|
367
372
|
"""
|
368
373
|
Constructs a Hessian request with for a single node. Used for retrieving and maintaining cached results.
|
369
374
|
|
370
375
|
Args:
|
371
|
-
mode (HessianMode): Mode of Hessian's
|
372
|
-
granularity (
|
373
|
-
target_nodes (List[BaseNode]): The node in the float graph for which the Hessian's
|
376
|
+
mode (HessianMode): Mode of Hessian's approximation (w.r.t weights or activations).
|
377
|
+
granularity (HessianScoresGranularity): Granularity level for the approximation.
|
378
|
+
target_nodes (List[BaseNode]): The node in the float graph for which the Hessian's approximation scores is targeted.
|
374
379
|
|
375
|
-
Returns: A
|
380
|
+
Returns: A HessianScoresRequest with the given details for the requested node.
|
376
381
|
|
377
382
|
"""
|
378
|
-
return
|
379
|
-
|
380
|
-
|
383
|
+
return HessianScoresRequest(mode,
|
384
|
+
granularity,
|
385
|
+
target_nodes=[target_nodes])
|
381
386
|
|
382
387
|
@staticmethod
|
383
388
|
def _convert_tensor_to_list_of_appx_results(t: Any) -> List:
|
@@ -19,7 +19,7 @@ from model_compression_toolkit.constants import EPS
|
|
19
19
|
|
20
20
|
def normalize_scores(hessian_approximations: List) -> List[np.ndarray]:
|
21
21
|
"""
|
22
|
-
Normalize Hessian
|
22
|
+
Normalize Hessian scores approximations by dividing their value by the sum of all
|
23
23
|
other values.
|
24
24
|
|
25
25
|
Args:
|