mct-nightly 2.1.0.20240707.450__py3-none-any.whl → 2.1.0.20240709.429__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. {mct_nightly-2.1.0.20240707.450.dist-info → mct_nightly-2.1.0.20240709.429.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.1.0.20240707.450.dist-info → mct_nightly-2.1.0.20240709.429.dist-info}/RECORD +31 -31
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +12 -12
  5. model_compression_toolkit/core/common/hessian/__init__.py +1 -1
  6. model_compression_toolkit/core/common/hessian/hessian_info_service.py +74 -69
  7. model_compression_toolkit/core/common/hessian/hessian_info_utils.py +1 -1
  8. model_compression_toolkit/core/common/hessian/{trace_hessian_calculator.py → hessian_scores_calculator.py} +11 -11
  9. model_compression_toolkit/core/common/hessian/{trace_hessian_request.py → hessian_scores_request.py} +15 -15
  10. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  11. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +8 -8
  12. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -5
  13. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +4 -4
  14. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -5
  15. model_compression_toolkit/core/keras/hessian/{activation_trace_hessian_calculator_keras.py → activation_hessian_scores_calculator_keras.py} +26 -26
  16. model_compression_toolkit/core/keras/hessian/{trace_hessian_calculator_keras.py → hessian_scores_calculator_keras.py} +14 -14
  17. model_compression_toolkit/core/keras/hessian/{weights_trace_hessian_calculator_keras.py → weights_hessian_scores_calculator_keras.py} +27 -27
  18. model_compression_toolkit/core/keras/keras_implementation.py +30 -30
  19. model_compression_toolkit/core/pytorch/hessian/{activation_trace_hessian_calculator_pytorch.py → activation_hessian_scores_calculator_pytorch.py} +25 -25
  20. model_compression_toolkit/core/pytorch/hessian/{trace_hessian_calculator_pytorch.py → hessian_scores_calculator_pytorch.py} +14 -14
  21. model_compression_toolkit/core/pytorch/hessian/{weights_trace_hessian_calculator_pytorch.py → weights_hessian_scores_calculator_pytorch.py} +25 -25
  22. model_compression_toolkit/core/pytorch/pytorch_implementation.py +30 -30
  23. model_compression_toolkit/core/quantization_prep_runner.py +1 -1
  24. model_compression_toolkit/gptq/common/gptq_training.py +30 -30
  25. model_compression_toolkit/gptq/keras/gptq_training.py +1 -1
  26. model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
  27. model_compression_toolkit/gptq/runner.py +2 -2
  28. model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
  29. {mct_nightly-2.1.0.20240707.450.dist-info → mct_nightly-2.1.0.20240709.429.dist-info}/LICENSE.md +0 -0
  30. {mct_nightly-2.1.0.20240707.450.dist-info → mct_nightly-2.1.0.20240709.429.dist-info}/WHEEL +0 -0
  31. {mct_nightly-2.1.0.20240707.450.dist-info → mct_nightly-2.1.0.20240709.429.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.1.0.20240707.450
3
+ Version: 2.1.0.20240709.429
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=MsAgtOGbbXnfVTRoVbyLqFcxa5nJo3wYrSstXzIuImY,1573
1
+ model_compression_toolkit/__init__.py,sha256=LCoESLcx_LQQJ0jfylbC-ofTwTQJkdKSMaLglyeTbG4,1573
2
2
  model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -6,11 +6,11 @@ model_compression_toolkit/metadata.py,sha256=IyoON37lBv3TI0rZGCP4K5t3oYI4TOmYy-L
6
6
  model_compression_toolkit/core/__init__.py,sha256=TrRgkWpT1AN2Faw1M_1HXyJkJnbxfn9p-RigDZl7pg0,1982
7
7
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
8
8
  model_compression_toolkit/core/graph_prep_runner.py,sha256=kM70wmNG3yMFiGQc0uO0wn9j4ZbSWxUEykpxDK55doc,10567
9
- model_compression_toolkit/core/quantization_prep_runner.py,sha256=0ga95vh_ZXO79r8FB26L5GIZKHkG98wq1hMsNH1bIeU,6453
9
+ model_compression_toolkit/core/quantization_prep_runner.py,sha256=K9eJ7VbB_rpeyxX4yEnorOmSxFW3DkvofzxS6QI8Hp8,6454
10
10
  model_compression_toolkit/core/runner.py,sha256=4TtOgyNb4cXr52dOlDqYxLm3rnLR6uHPDNoZiEFL9XA,12655
11
11
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
12
12
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
13
- model_compression_toolkit/core/common/framework_implementation.py,sha256=8b6M1GcUR9bDgoxwqyNP8C6KSU9OTQ5hIk20Y74eLPo,20896
13
+ model_compression_toolkit/core/common/framework_implementation.py,sha256=zg2Sznd-uqpN8hy9C22jah_27qPYXNajYx3xMR3AlMw,20938
14
14
  model_compression_toolkit/core/common/framework_info.py,sha256=1ZMMGS9ip-kSflqkartyNRt9aQ5ub1WepuTRcTy-YSQ,6337
15
15
  model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
16
16
  model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
@@ -44,11 +44,11 @@ model_compression_toolkit/core/common/graph/memory_graph/cut.py,sha256=aPdXJPP5a
44
44
  model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256=crV2NCLVO8jx9MlryZBYuJKFe_G9HfM7rUR64fDymlw,17045
45
45
  model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=gRmBEFRmyJsNKezQfiwDwQu1cmbGd2wgKCRTH6iw8mw,3961
46
46
  model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=gw4av_rzn_3oEAPpD3B7PHZDqnxHMjIESevl6ppPnkk,7175
47
- model_compression_toolkit/core/common/hessian/__init__.py,sha256=bxPVbkIlHFJMiOgTdWMVCqcD9JKV5kb2bVdWUTeLpj8,1021
48
- model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=0Ziwyzv6H5mIG5ptW6uC_w1gmxZIdffCuK8cg0STmJQ,20731
49
- model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=JepOjcyX1XyiC1UblqM3zdKv2xuUvU3HKWjlE1Bnq_U,1490
50
- model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py,sha256=EIV4NVUfvkefqMAFrrjNhQq7cvT3hljHpGz_gpVaFtY,4135
51
- model_compression_toolkit/core/common/hessian/trace_hessian_request.py,sha256=uvnaYtJRRmj_CfnYAO6oehnhDqdalW0NgETWJvSzCxc,3245
47
+ model_compression_toolkit/core/common/hessian/__init__.py,sha256=6216QgHl7h4DXGn5ForP9Tija-wrBSONNtQ769ikP2s,1025
48
+ model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=DHbZqFDuDir1QWN-YkYBzaoGDujgYam1hT2ea6uL3yM,21009
49
+ model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=1axmN0tjJSo_7hUr2d2KMv4y1pBi19cqWSQpi4BbdsA,1458
50
+ model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py,sha256=Pe4uKerx-MeDQPJ7Slr8fvFUHfv02q33w3gbQK5kBKs,4186
51
+ model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=atGJgJBL9uwYRC3t9NnzGgHYxV4XJj4Ai_xPpQH0rhY,3229
52
52
  model_compression_toolkit/core/common/matchers/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
53
53
  model_compression_toolkit/core/common/matchers/base_graph_filter.py,sha256=mTk54z0mIbFmPOb4h0xfLtLDookcFyNh8H0pIN5js_M,3091
54
54
  model_compression_toolkit/core/common/matchers/base_matcher.py,sha256=JCj-NLAXOJa-GcSX-94PVUTWjooQUd0NemiyNg5uKGQ,2210
@@ -62,9 +62,9 @@ model_compression_toolkit/core/common/mixed_precision/configurable_quant_id.py,s
62
62
  model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py,sha256=7dKMi5S0zQZ16m8NWn1XIuoXsKuZUg64G4-uK8-j1PQ,5177
63
63
  model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=H8qYkJsk88OszUJo-Zde7vTmWiypLTg9KbbzIZ-hhvM,2812
64
64
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=rppRZJdSCQGiZsd93QxoUIhj51eETvQbuI5JiC2TUeA,4963
65
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=JmHopRNpHjxnoyeqXRVO0t-DdqEOm-jOZI06w5aAl9k,7550
65
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=pk8HRoShDhiUprBC4m1AFQv1SacS4hOrj0MRdbq-5gY,7556
66
66
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=TTTux4YiOnQqt-2h7Y38959XaDwNZc0eufLMx_yws5U,37578
67
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=C29Qg3oO5d4lslB7uZh8JwcEgLKqyKox9B1Ss9mZDLQ,27536
67
+ model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=PNGWP0bPkQNP1jCORP6FQgGIr616Kg1YSe_hy-BDg0I,27546
68
68
  model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
69
69
  model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=KifDMbm7qkSfvSl6pcZzQ82naIXzeKL6aT-VsvWZYyc,7901
70
70
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -92,7 +92,7 @@ model_compression_toolkit/core/common/pruning/pruning_section.py,sha256=I4vxh5iP
92
92
  model_compression_toolkit/core/common/pruning/importance_metrics/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
93
93
  model_compression_toolkit/core/common/pruning/importance_metrics/base_importance_metric.py,sha256=qMAtLWs5fjbSco8nhbig5TkuacdhnDW7cy3avMHRGX4,1988
94
94
  model_compression_toolkit/core/common/pruning/importance_metrics/importance_metric_factory.py,sha256=E-fKuRfrNYlN3nNcRAbnkJkFNwClvyrL_Js1qDPxIKA,1999
95
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py,sha256=SA1lqFNWhyXAnEyT_ROd3a-9gDYAgoCusk13US2l_QE,14047
95
+ model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py,sha256=sJGsR8vTs43SV7Wfz929pCwLM-_7aXUyO5nBUig9K9s,14055
96
96
  model_compression_toolkit/core/common/pruning/mask/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
97
97
  model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=APY8BsM9B7ZxVCH6n1xs9fSCTB_A9ou9gHrCQl1DOdI,5131
98
98
  model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=4ohJrJHNzZk5uMnZEYkwLx2TDGzkh5kRhLGNVYNC6dc,5978
@@ -109,12 +109,12 @@ model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha
109
109
  model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
110
110
  model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=O4qFJw3nBYUD4cGbO8haGXZ2-piSqoRpDKDD74iXSxw,12417
111
111
  model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
112
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=4XH-qSo-zG7XkVTx1J0DFNHEklLOhkhxXeEWnXNJ7z8,23602
112
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=Fd_gxr5js-mqEwucaRR1CQAZ1W_wna19L1gAPeOzxRQ,23610
113
113
  model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=t0XSwjfOxcq2Sj2PGzccntz1GGv2eqVn9oR3OI0t9wo,8533
114
114
  model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
115
115
  model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=HfnhQ4MxGpb95gOWXD1vnroTxxjFt9VFd4jIdo-rvAQ,10623
116
116
  model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=noEdvGiyyW7acgQ2OFWLedCODibTGYJifC9qo8YIU5U,4558
117
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=JS1nhQUMBVBtEjXbevFbbzHsXM0QLKVTG_3DRhdTAa0,8643
117
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=oME8T6Slgl1SJNpXV4oY3UhuX0YmKYbcWDsLiCYq7oE,8651
118
118
  model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=o2XNY_0pUUyId02TUVQBtkux_i40NCcnzuobSeQLy3E,42863
119
119
  model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=zSNda0jN8cP41m6g5TOv5WvATwIhV8z6AVM1Es6rq1s,4419
120
120
  model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=4TP41wPYC0azIzFxUt-lNlKUPIIXQeE4H1SYHkON75k,11875
@@ -150,7 +150,7 @@ model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7V
150
150
  model_compression_toolkit/core/keras/constants.py,sha256=Uv3c0UdW55pIVQNW_1HQlgl-dHXREkltOLyzp8G1mTQ,3163
151
151
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
152
152
  model_compression_toolkit/core/keras/default_framework_info.py,sha256=HcHplb7IcnOTyK2p6uhp3OVG4-RV3RDo9C_4evaIzkQ,4981
153
- model_compression_toolkit/core/keras/keras_implementation.py,sha256=bRH39d4lW7Ngm8xi7v9JQd9gNfGlB_lb-bolbzTYUcc,29881
153
+ model_compression_toolkit/core/keras/keras_implementation.py,sha256=qrmFzs749gVCEanw-paEsbM50zECtMV8qDMxKFTmyj0,29906
154
154
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
155
155
  model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
156
156
  model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=Xmk2ZL5CaYdb7iG62HdtZ1F64vap7ffnrsuR3e3G5hc,4851
@@ -184,9 +184,9 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/softmax_s
184
184
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=wH9ocMLL725-uUPU-zCxdd8NwT5nyd0ZShmI7iuTwF8,1462
185
185
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/weights_activation_split.py,sha256=rjIheZW7LbSPv9bzMSmC8wl6UUxaTkd4J2IHinObT-Y,1814
186
186
  model_compression_toolkit/core/keras/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
187
- model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py,sha256=4eJKq_Fx4mm_VuBDeeti0fTcUk1lL2yjebxCugJhvrA,8871
188
- model_compression_toolkit/core/keras/hessian/trace_hessian_calculator_keras.py,sha256=hRfAjgZakDaIMuERmTVjJSa_Ww6FmEudYPO9R7SuYuQ,3914
189
- model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py,sha256=Xogd90kPZPvKbplZQv5B77Dq_m4aW5-bL6Jxh33VZWs,12213
187
+ model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculator_keras.py,sha256=p0eM-EO5ltXYjSkd7B3h9BWBcuRZvjxEcA8WaNvdyqc,8901
188
+ model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py,sha256=Cep-bQEwLyqLYfLxM0ByOQd_oAIT-uXjr3dFUd8T9CY,3954
189
+ model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py,sha256=970C-8J4HtUalNWvZAKlWFZVfw5r6SBdt5RQU_mZ7M0,12261
190
190
  model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
191
191
  model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=aW8wR13fK6P6xzbU9XGU60IO1yYzXSo_Hk4qeq486kg,5137
192
192
  model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=Ziydik2j-LvNBXP3TSfUD6rEezPAikzQGib0_IXkmGM,6729
@@ -213,7 +213,7 @@ model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKW
213
213
  model_compression_toolkit/core/pytorch/constants.py,sha256=AguUnAsNlj41gwuKIP_7nos3FcJHsIAjewLXSQdrDQM,2624
214
214
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
215
215
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
216
- model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=7CFt1Y3fiDaKkEVvlDd76ZmucCuVp6OZNQwwqJezKbU,27547
216
+ model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=XL_RZcfnb_ZY2jdCjOxxz7SbRBzMokbOWsTuYOSjyRU,27569
217
217
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
218
218
  model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=E6ifk1HdO60k4IRH2EFBzAYWtwUlrGqJoQ66nknpHoQ,4983
219
219
  model_compression_toolkit/core/pytorch/utils.py,sha256=OT_mrNEJqPgWLdtQuivKMQVjtJY49cmoIVvbRhANl1w,3004
@@ -249,9 +249,9 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transfo
249
249
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=WmEa8Xjji-_tIbthDxlLAGSr69nWk-YKcHNaVqLa7sg,1375
250
250
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/weights_activation_split.py,sha256=tp78axmUQc0Zpj3KwVmV0PGYHvCf7sAW_sRmXXw7gsY,1616
251
251
  model_compression_toolkit/core/pytorch/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
252
- model_compression_toolkit/core/pytorch/hessian/activation_trace_hessian_calculator_pytorch.py,sha256=eDiTiKVvH5NBgUFV6oBe7QeowJRo6tOQbcXx9t9k2S0,8522
253
- model_compression_toolkit/core/pytorch/hessian/trace_hessian_calculator_pytorch.py,sha256=Gat9aobUOQEWGt02x30vVm04mdi3gchdz2Bmmw5p91w,3445
254
- model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_pytorch.py,sha256=-B446KhtZHPU_5Ixtm9v_v-3qDQ05NoIj2iyq5DlgR4,8460
252
+ model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py,sha256=xc_-utc9_Hq915X02VbT8zXxGqxE4fFz6dhiiZwU3ok,8578
253
+ model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py,sha256=vXluX-awgavv7DGihG9HrlvLhak8qIHy837PPTOd4jg,3471
254
+ model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py,sha256=C4-7naBQUh8TN6fEwkyKY6rlY_nvHSAmCnWT4iMBs8E,8497
255
255
  model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
256
256
  model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=-6oep2WJ85-JmIxZa-e2AmBpbORoKe4Xdduz2ZidwvM,4871
257
257
  model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py,sha256=KVZTKCYzJqqzF5nFEiuGMv_sNeVuBTxhmxWMFacKOxE,6337
@@ -332,17 +332,17 @@ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py,sha
332
332
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=YT9IVdpKaJbAW3msYRoQNIgqRSEVwSarRy6qlWCrBfk,5389
333
333
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=4sN5z-6BXrTE5Dp2FX_jKO9ty5iZ2r4RM7XvXtDVLSI,9348
334
334
  model_compression_toolkit/gptq/__init__.py,sha256=YKg-tMj9D4Yd0xW9VRD5EN1J5JrmlRbNEF2fOSgodqA,1228
335
- model_compression_toolkit/gptq/runner.py,sha256=PQoLK3WhdRuUwZMd1VbtA7KZ9c-zWig_0ShmTtvJSHY,5970
335
+ model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
336
336
  model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
337
337
  model_compression_toolkit/gptq/common/gptq_config.py,sha256=U-NiVEedkOsVaFq-iXU2Xcqp99Rgf0f2I3oANdVMhMY,5672
338
338
  model_compression_toolkit/gptq/common/gptq_constants.py,sha256=QSm6laLkIV0LYmU0BLtmKp3Fi3SqDfbncFQWOGA1cGU,611
339
339
  model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
340
340
  model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
341
- model_compression_toolkit/gptq/common/gptq_training.py,sha256=efnwgKSGk9wtnirlLR48Q1KtZuXoGcrHoPUYq_9YKxc,16394
341
+ model_compression_toolkit/gptq/common/gptq_training.py,sha256=nI_XVa7WbfCcbgHMFgnPtnD77m5ezAB306z7VE0XFvU,16527
342
342
  model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
343
343
  model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
344
344
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
345
- model_compression_toolkit/gptq/keras/gptq_training.py,sha256=RAUZvve-kUMTfXY-aXQWEM4IejaeVedrKejBNrO6szI,19156
345
+ model_compression_toolkit/gptq/keras/gptq_training.py,sha256=NXTNsVrO9DTh0uvc8V7rFaM0fYg2OA18ZrYd-cKZ7Z4,19159
346
346
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=MKIfrRTRH3zCuxCR1g9ZVIFyuSSr0e0sDybqh4LDM7E,4672
347
347
  model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=SjmBTuSwki4JTPVhxvJMFK9uAsmEm2c6VV11NnM6eEo,15117
348
348
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
@@ -359,7 +359,7 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
359
359
  model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
360
360
  model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa94rpoodf82DksgjQCnL7nBc,2719
361
361
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
362
- model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=2pe_caivE7Fr9zCvmZENKbFTS6AUFbSjHN-TODEhbSY,16631
362
+ model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=-daninmlPGfKsBNPB2C3gT6rK0G5YeyJsuOLA0JlfBU,16633
363
363
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
364
364
  model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=jcvRKBuMkrerNE8oWIJFp802pyFO0dnA-4hRnclKbWE,13569
365
365
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
@@ -400,7 +400,7 @@ model_compression_toolkit/qat/keras/quantizer/ste_rounding/__init__.py,sha256=cc
400
400
  model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=I4KlaGv17k71IyjuSG9M0OlXlD5P0pfvKa6oCyRQ5FE,13517
401
401
  model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py,sha256=EED6LfqhX_OhDRJ9e4GwbpgNC9vq7hoXyJS2VPvG2qc,10789
402
402
  model_compression_toolkit/qat/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
403
- model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=23qsgnXk-U5nYMucqDI9iZVj5sKXMdMf9ceQWd8nfqA,13374
403
+ model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=pRy2B5OsaLi33p4hozjr0rzAooT8Gic3_qxTl66J900,13375
404
404
  model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=xYa4C8pr9cG1f3mQQcBXO_u3IdJN-zl7leZxuXDs86w,1003
405
405
  model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py,sha256=soxP4moJxQaziq5ccP3Du34fSIVSFyZq6hD8YuaDn88,2187
406
406
  model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py,sha256=sFWGu76PZ9dSRf3L0uZI6YwLIs0biBND1tl76I1piBQ,5721
@@ -517,8 +517,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
517
517
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=yrZNVRm2IRU7r7R-hjS2lOQ6wvEEvbeunvf2jKoWjXk,3277
518
518
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
519
519
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=eyMoXt5o5EnMr6d-rpCwQdX5mAiYiymvbgKv4tf7-a0,4576
520
- mct_nightly-2.1.0.20240707.450.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
521
- mct_nightly-2.1.0.20240707.450.dist-info/METADATA,sha256=O1XwRSSNtzfB3pOlMg7X8DuTGL6LzeJ2A6PgXro2hH4,19719
522
- mct_nightly-2.1.0.20240707.450.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
523
- mct_nightly-2.1.0.20240707.450.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
524
- mct_nightly-2.1.0.20240707.450.dist-info/RECORD,,
520
+ mct_nightly-2.1.0.20240709.429.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
521
+ mct_nightly-2.1.0.20240709.429.dist-info/METADATA,sha256=9TI8xFMqUk_zARYw-va6f-94DupOa_Rv6xaHF7Loet8,19719
522
+ mct_nightly-2.1.0.20240709.429.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
523
+ mct_nightly-2.1.0.20240709.429.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
524
+ mct_nightly-2.1.0.20240709.429.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.1.0.20240707.000450"
30
+ __version__ = "2.1.0.20240709.000429"
@@ -24,7 +24,7 @@ from model_compression_toolkit.core.common import BaseNode
24
24
  from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
25
25
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
26
26
  from model_compression_toolkit.core.common.graph.base_graph import Graph
27
- from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianInfoService
27
+ from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianInfoService
28
28
  from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
29
29
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
30
30
  from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
@@ -49,23 +49,23 @@ class FrameworkImplementation(ABC):
49
49
  raise NotImplemented(f'{self.__class__.__name__} did not supply a constants module.') # pragma: no cover
50
50
 
51
51
  @abstractmethod
52
- def get_trace_hessian_calculator(self,
53
- graph: Graph,
54
- input_images: List[Any],
55
- trace_hessian_request: TraceHessianRequest,
56
- num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
52
+ def get_hessian_scores_calculator(self,
53
+ graph: Graph,
54
+ input_images: List[Any],
55
+ hessian_scores_request: HessianScoresRequest,
56
+ num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
57
57
  """
58
- Get framework trace hessian approximations calculator based on the trace hessian request.
58
+ Get framework hessian-approximation scores calculator based on the hessian scores request.
59
59
  Args:
60
60
  input_images: Images to use for computation.
61
61
  graph: Float graph to compute the approximation of its different nodes.
62
- trace_hessian_request: TraceHessianRequest to search for the desired calculator.
63
- num_iterations_for_approximation: Number of iterations to use when approximating the Hessian trace.
62
+ hessian_scores_request: HessianScoresRequest to search for the desired calculator.
63
+ num_iterations_for_approximation: Number of iterations to use when approximating the Hessian-approximation scores.
64
64
 
65
- Returns: TraceHessianCalculator to use for the trace hessian approximation computation for this request.
65
+ Returns: HessianScoresCalculator to use for the hessian approximation scores computation for this request.
66
66
  """
67
67
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
68
- f'framework\'s get_trace_hessian_calculator method.') # pragma: no cover
68
+ f'framework\'s get_hessian_scores_calculator method.') # pragma: no cover
69
69
 
70
70
  @abstractmethod
71
71
  def to_numpy(self, tensor: Any) -> np.ndarray:
@@ -310,7 +310,7 @@ class FrameworkImplementation(ABC):
310
310
  representative_data_gen: Dataset to use for retrieving images for the models inputs.
311
311
  fw_info: FrameworkInfo object with information about the specific framework's model.
312
312
  disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
313
- hessian_info_service: HessianInfoService to fetch Hessian traces approximations.
313
+ hessian_info_service: HessianInfoService to fetch information based on Hessian-approximation.
314
314
 
315
315
  Returns:
316
316
  A function that computes the metric.
@@ -12,6 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from model_compression_toolkit.core.common.hessian.trace_hessian_request import TraceHessianRequest, HessianMode, HessianInfoGranularity
15
+ from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, HessianMode, HessianScoresGranularity
16
16
  from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
17
17
  import model_compression_toolkit.core.common.hessian.hessian_info_utils as hessian_utils
@@ -19,24 +19,24 @@ from tqdm import tqdm
19
19
  from typing import Callable, List, Dict, Any, Tuple
20
20
 
21
21
  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
22
- from model_compression_toolkit.core.common.hessian.trace_hessian_request import TraceHessianRequest, \
23
- HessianInfoGranularity, HessianMode
22
+ from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, \
23
+ HessianScoresGranularity, HessianMode
24
24
  from model_compression_toolkit.logger import Logger
25
25
 
26
26
 
27
27
  class HessianInfoService:
28
28
  """
29
- A service to manage, store, and compute approximation of the Hessian-based information.
29
+ A service to manage, store, and compute information based on the Hessian matrix approximation.
30
30
 
31
- This class provides functionalities to compute approximation based on the Hessian matrix based
32
- on the different parameters (such as number of iterations for approximating the info)
31
+ This class provides functionalities to compute information based on the Hessian matrix approximation
32
+ based on the different parameters (such as number of iterations for approximating the scores)
33
33
  and input images (using representative_dataset_gen).
34
34
  It also offers cache management capabilities for efficient computation and retrieval.
35
35
 
36
36
  Note:
37
37
  - The Hessian provides valuable information about the curvature of the loss function.
38
38
  - Computation can be computationally heavy and time-consuming.
39
- - The computed trace is an approximation.
39
+ - The computed information is based on Hessian approximation (and not the precise Hessian matrix).
40
40
  """
41
41
 
42
42
  def __init__(self,
@@ -49,7 +49,7 @@ class HessianInfoService:
49
49
  Args:
50
50
  graph: Float graph.
51
51
  representative_dataset_gen: A callable that provides a dataset for sampling.
52
- fw_impl: Framework-specific implementation for trace Hessian approximation computation.
52
+ fw_impl: Framework-specific implementation for Hessian approximation scores computation.
53
53
  """
54
54
  self.graph = graph
55
55
 
@@ -58,7 +58,7 @@ class HessianInfoService:
58
58
  self.fw_impl = fw_impl
59
59
  self.num_iterations_for_approximation = num_iterations_for_approximation
60
60
 
61
- self.trace_hessian_request_to_score_list = {}
61
+ self.hessian_scores_request_to_scores_list = {}
62
62
 
63
63
  def _sample_batch_representative_dataset(self,
64
64
  representative_dataset: Any,
@@ -142,11 +142,11 @@ class HessianInfoService:
142
142
 
143
143
  def _clear_saved_hessian_info(self):
144
144
  """Clears the saved info approximations."""
145
- self.trace_hessian_request_to_score_list={}
145
+ self.hessian_scores_request_to_scores_list={}
146
146
 
147
- def count_saved_info_of_request(self, hessian_request: TraceHessianRequest) -> Dict:
147
+ def count_saved_scores_of_request(self, hessian_request: HessianScoresRequest) -> Dict:
148
148
  """
149
- Counts the saved approximations of Hessian info (traces, for now) for a specific request.
149
+ Counts the saved approximations of Hessian scores for a specific request.
150
150
  If some approximations were computed for this request before, the amount of approximations (per image)
151
151
  will be returned. If not, zero is returned.
152
152
 
@@ -166,55 +166,58 @@ class HessianInfoService:
166
166
  Logger.critical(f"Expecting the Hessian request to include only non-reused nodes at this point, "
167
167
  f"but found node {n.name} with 'reuse' status.")
168
168
  # Check if the request for this node is in the saved info and store its count, otherwise store 0
169
- per_node_counter[n] = len(self.trace_hessian_request_to_score_list.get(hessian_request, []))
169
+ per_node_counter[n] = len(self.hessian_scores_request_to_scores_list.get(hessian_request, []))
170
170
 
171
171
  return per_node_counter
172
172
 
173
- def compute(self, trace_hessian_request: TraceHessianRequest, representative_dataset_gen, num_hessian_samples: int,
173
+ def compute(self,
174
+ hessian_scores_request: HessianScoresRequest,
175
+ representative_dataset_gen,
176
+ num_hessian_samples: int,
174
177
  last_iter_remain_samples: List[List[np.ndarray]] = None):
175
178
  """
176
- Computes an approximation of the trace of the Hessian based on the
179
+ Computes scores based on the Hessian matrix approximation according to the
177
180
  provided request configuration and stores it in the cache.
178
181
 
179
182
  Args:
180
- trace_hessian_request: Configuration for which to compute the approximation.
183
+ hessian_scores_request: Configuration for which to compute the approximation.
181
184
  representative_dataset_gen: A callable that provides a dataset for sampling.
182
185
  num_hessian_samples: Number of requested samples to compute batch Hessian approximation scores.
183
186
  last_iter_remain_samples: A list of input samples (for each input layer) with remaining samples from
184
187
  previous iterations.
185
188
  """
186
- Logger.debug(f"Computing Hessian-trace approximation for nodes {trace_hessian_request.target_nodes}.")
189
+ Logger.debug(f"Computing Hessian-scores approximations for nodes {hessian_scores_request.target_nodes}.")
187
190
 
188
191
  images, next_iter_remain_samples = representative_dataset_gen(num_hessian_samples=num_hessian_samples,
189
192
  last_iter_remain_samples=last_iter_remain_samples)
190
193
 
191
194
  # Compute and store the computed approximation in the saved info
192
195
  topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()]
193
- trace_hessian_request.target_nodes.sort(key=lambda x: topo_sorted_nodes_names.index(x.name))
196
+ hessian_scores_request.target_nodes.sort(key=lambda x: topo_sorted_nodes_names.index(x.name))
194
197
 
195
- # Get the framework-specific calculator for trace Hessian approximation
196
- fw_hessian_calculator = self.fw_impl.get_trace_hessian_calculator(graph=self.graph,
197
- input_images=images,
198
- trace_hessian_request=trace_hessian_request,
199
- num_iterations_for_approximation=self.num_iterations_for_approximation)
198
+ # Get the framework-specific calculator Hessian-approximation scores
199
+ fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(graph=self.graph,
200
+ input_images=images,
201
+ hessian_scores_request=hessian_scores_request,
202
+ num_iterations_for_approximation=self.num_iterations_for_approximation)
200
203
 
201
- trace_hessian = fw_hessian_calculator.compute()
204
+ hessian_scores = fw_hessian_calculator.compute()
202
205
 
203
- for node, hessian in zip(trace_hessian_request.target_nodes, trace_hessian):
204
- single_node_request = self._construct_single_node_request(trace_hessian_request.mode,
205
- trace_hessian_request.granularity,
206
+ for node, hessian in zip(hessian_scores_request.target_nodes, hessian_scores):
207
+ single_node_request = self._construct_single_node_request(hessian_scores_request.mode,
208
+ hessian_scores_request.granularity,
206
209
  node)
207
210
 
208
211
  # The hessian for each node is expected to be a tensor where the first axis represents the number of
209
212
  # images in the batch on which the approximation was computed.
210
213
  # We collect the results as a list of a result for images, which is combined across batches.
211
- # After conversion, trace_hessian_request_to_score_list for a request of a single node should be a list of
214
+ # After conversion, hessian_scores_request_to_scores_list for a request of a single node should be a list of
212
215
  # results of all images, where each result is a tensor of the shape depending on the granularity.
213
- if single_node_request in self.trace_hessian_request_to_score_list:
214
- self.trace_hessian_request_to_score_list[single_node_request] += (
216
+ if single_node_request in self.hessian_scores_request_to_scores_list:
217
+ self.hessian_scores_request_to_scores_list[single_node_request] += (
215
218
  self._convert_tensor_to_list_of_appx_results(hessian))
216
219
  else:
217
- self.trace_hessian_request_to_score_list[single_node_request] = (
220
+ self.hessian_scores_request_to_scores_list[single_node_request] = (
218
221
  self._convert_tensor_to_list_of_appx_results(hessian))
219
222
 
220
223
  # In case that we are required to return a number of scores that is larger that the computation batch size
@@ -226,15 +229,15 @@ class HessianInfoService:
226
229
  and len(next_iter_remain_samples[0]) > 0 else None
227
230
 
228
231
  def fetch_hessian(self,
229
- trace_hessian_request: TraceHessianRequest,
232
+ hessian_scores_request: HessianScoresRequest,
230
233
  required_size: int,
231
234
  batch_size: int = 1) -> List[List[np.ndarray]]:
232
235
  """
233
- Fetches the computed approximations of the trace of the Hessian for the given
236
+ Fetches the computed approximations of the Hessian-based scores for the given
234
237
  request and required size.
235
238
 
236
239
  Args:
237
- trace_hessian_request: Configuration for which to fetch the approximation.
240
+ hessian_scores_request: Configuration for which to fetch the approximation.
238
241
  required_size: Number of approximations required.
239
242
  batch_size: The Hessian computation batch size.
240
243
 
@@ -245,28 +248,28 @@ class HessianInfoService:
245
248
  OC for per-output-channel when the requested node has OC output-channels, etc.)
246
249
  """
247
250
 
248
- if len(trace_hessian_request.target_nodes) == 0:
251
+ if len(hessian_scores_request.target_nodes) == 0:
249
252
  return []
250
253
 
251
254
  if required_size == 0:
252
- return [[] for _ in trace_hessian_request.target_nodes]
255
+ return [[] for _ in hessian_scores_request.target_nodes]
253
256
 
254
- Logger.info(f"\nEnsuring {required_size} Hessian-trace approximation for nodes "
255
- f"{trace_hessian_request.target_nodes}.")
257
+ Logger.info(f"\nEnsuring {required_size} Hessian-approximation scores for nodes "
258
+ f"{hessian_scores_request.target_nodes}.")
256
259
 
257
260
  # Replace node in reused target nodes with a representing node from the 'reuse group'.
258
- for n in trace_hessian_request.target_nodes:
261
+ for n in hessian_scores_request.target_nodes:
259
262
  if n.reuse_group:
260
263
  rep_node = self._get_representing_of_reuse_group(n)
261
- trace_hessian_request.target_nodes.remove(n)
262
- if rep_node not in trace_hessian_request.target_nodes:
263
- trace_hessian_request.target_nodes.append(rep_node)
264
+ hessian_scores_request.target_nodes.remove(n)
265
+ if rep_node not in hessian_scores_request.target_nodes:
266
+ hessian_scores_request.target_nodes.append(rep_node)
264
267
 
265
268
  # Ensure the saved info has the required number of approximations
266
- self._populate_saved_info_to_size(trace_hessian_request, required_size, batch_size)
269
+ self._populate_saved_info_to_size(hessian_scores_request, required_size, batch_size)
267
270
 
268
271
  # Return the saved approximations for the given request
269
- return self._collect_saved_hessians_for_request(trace_hessian_request, required_size)
272
+ return self._collect_saved_hessians_for_request(hessian_scores_request, required_size)
270
273
 
271
274
  def _get_representing_of_reuse_group(self, node) -> Any:
272
275
  """
@@ -286,20 +289,20 @@ class HessianInfoService:
286
289
  return father_nodes[0]
287
290
 
288
291
  def _populate_saved_info_to_size(self,
289
- trace_hessian_request: TraceHessianRequest,
292
+ hessian_scores_request: HessianScoresRequest,
290
293
  required_size: int,
291
294
  batch_size: int = 1):
292
295
  """
293
- Ensures that the saved info has the required size of trace Hessian approximations for the given request.
296
+ Ensures that the saved info has the required size of Hessian approximation scores for the given request.
294
297
 
295
298
  Args:
296
- trace_hessian_request: Configuration for which to ensure the saved info size.
297
- required_size: Required number of trace Hessian approximations.
299
+ hessian_scores_request: Configuration of the request to ensure the saved info size.
300
+ required_size: Required number of Hessian-approximation scores.
298
301
  batch_size: The Hessian computation batch size.
299
302
  """
300
303
 
301
304
  # Get the current number of saved approximations for each node in the request
302
- current_existing_hessians = self.count_saved_info_of_request(trace_hessian_request)
305
+ current_existing_hessians = self.count_saved_scores_of_request(hessian_scores_request)
303
306
 
304
307
  # Compute the required number of approximations to meet the required size.
305
308
  # Since we allow batch and multi-nodes computation, we take the node with the maximal number of missing
@@ -308,9 +311,9 @@ class HessianInfoService:
308
311
  max_remaining_hessians = required_size - min_exist_hessians
309
312
 
310
313
  Logger.info(
311
- f"Running Hessian approximation computation for {len(trace_hessian_request.target_nodes)} nodes.\n "
312
- f"The node with minimal existing Hessian-trace approximations has {min_exist_hessians} "
313
- f"approximations computed.\n"
314
+ f"Running Hessian approximation computation for {len(hessian_scores_request.target_nodes)} nodes.\n "
315
+ f"The node with minimal existing Hessian-approximation scores has {min_exist_hessians} "
316
+ f"approximated scores computed.\n"
314
317
  f"{max_remaining_hessians} approximations left to compute...")
315
318
 
316
319
  hessian_representative_dataset = partial(self._sample_batch_representative_dataset,
@@ -325,30 +328,31 @@ class HessianInfoService:
325
328
  pbar.update(1)
326
329
  size_to_compute = min(max_remaining_hessians, batch_size)
327
330
  next_iter_remaining_samples = (
328
- self.compute(trace_hessian_request, hessian_representative_dataset, size_to_compute,
331
+ self.compute(hessian_scores_request, hessian_representative_dataset, size_to_compute,
329
332
  last_iter_remain_samples=next_iter_remaining_samples))
330
333
  max_remaining_hessians -= size_to_compute
331
334
 
332
- def _collect_saved_hessians_for_request(self, trace_hessian_request: TraceHessianRequest, required_size: int
333
- ) -> List[List[np.ndarray]]:
335
+ def _collect_saved_hessians_for_request(self,
336
+ hessian_scores_request: HessianScoresRequest,
337
+ required_size: int) -> List[List[np.ndarray]]:
334
338
  """
335
339
  Collects Hessian approximation for the nodes in the given request.
336
340
 
337
341
  Args:
338
- trace_hessian_request: Configuration for which to fetch the approximation.
339
- required_size: Required number of trace Hessian approximations.
342
+ hessian_scores_request: Configuration for which to fetch the approximation.
343
+ required_size: Required number of Hessian-approximation scores.
340
344
 
341
345
  Returns: A list with List of computed Hessian approximation (a tensor for each score) for each node
342
346
  in the request.
343
347
 
344
348
  """
345
349
  collected_results = []
346
- for node in trace_hessian_request.target_nodes:
347
- single_node_request = self._construct_single_node_request(trace_hessian_request.mode,
348
- trace_hessian_request.granularity,
350
+ for node in hessian_scores_request.target_nodes:
351
+ single_node_request = self._construct_single_node_request(hessian_scores_request.mode,
352
+ hessian_scores_request.granularity,
349
353
  node)
350
354
 
351
- res_for_node = self.trace_hessian_request_to_score_list.get(single_node_request)
355
+ res_for_node = self.hessian_scores_request_to_scores_list.get(single_node_request)
352
356
  if res_for_node is None: # pragma: no cover
353
357
  Logger.critical(f"Couldn't find saved Hessian approximations for node {node.name}.")
354
358
  if len(res_for_node) < required_size: # pragma: no cover
@@ -362,22 +366,23 @@ class HessianInfoService:
362
366
  return collected_results
363
367
 
364
368
  @staticmethod
365
- def _construct_single_node_request(mode: HessianMode, granularity: HessianInfoGranularity, target_nodes: List
366
- ) -> TraceHessianRequest:
369
+ def _construct_single_node_request(mode: HessianMode,
370
+ granularity: HessianScoresGranularity,
371
+ target_nodes: List) -> HessianScoresRequest:
367
372
  """
368
373
  Constructs a Hessian request with for a single node. Used for retrieving and maintaining cached results.
369
374
 
370
375
  Args:
371
- mode (HessianMode): Mode of Hessian's trace approximation (w.r.t weights or activations).
372
- granularity (HessianInfoGranularity): Granularity level for the approximation.
373
- target_nodes (List[BaseNode]): The node in the float graph for which the Hessian's trace approximation is targeted.
376
+ mode (HessianMode): Mode of Hessian's approximation (w.r.t weights or activations).
377
+ granularity (HessianScoresGranularity): Granularity level for the approximation.
378
+ target_nodes (List[BaseNode]): The node in the float graph for which the Hessian's approximation scores is targeted.
374
379
 
375
- Returns: A TraceHessianRequest with the given details for the requested node.
380
+ Returns: A HessianScoresRequest with the given details for the requested node.
376
381
 
377
382
  """
378
- return TraceHessianRequest(mode,
379
- granularity,
380
- target_nodes=[target_nodes])
383
+ return HessianScoresRequest(mode,
384
+ granularity,
385
+ target_nodes=[target_nodes])
381
386
 
382
387
  @staticmethod
383
388
  def _convert_tensor_to_list_of_appx_results(t: Any) -> List:
@@ -19,7 +19,7 @@ from model_compression_toolkit.constants import EPS
19
19
 
20
20
  def normalize_scores(hessian_approximations: List) -> List[np.ndarray]:
21
21
  """
22
- Normalize Hessian information approximations by dividing the trace Hessian approximations value by the sum of all
22
+ Normalize Hessian scores approximations by dividing their value by the sum of all
23
23
  other values.
24
24
 
25
25
  Args: