mct-nightly 2.0.0.20240417.406__py3-none-any.whl → 2.0.0.20240418.439__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {mct_nightly-2.0.0.20240417.406.dist-info → mct_nightly-2.0.0.20240418.439.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.0.0.20240417.406.dist-info → mct_nightly-2.0.0.20240418.439.dist-info}/RECORD +32 -29
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/constants.py +2 -0
  5. model_compression_toolkit/core/common/graph/base_node.py +1 -1
  6. model_compression_toolkit/core/common/hessian/hessian_info_service.py +2 -3
  7. model_compression_toolkit/core/common/hessian/trace_hessian_request.py +1 -3
  8. model_compression_toolkit/core/common/quantization/quantization_config.py +5 -2
  9. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +67 -4
  10. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +10 -3
  11. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +14 -4
  12. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +30 -3
  13. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +17 -7
  14. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +14 -3
  15. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +13 -3
  16. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +16 -3
  17. model_compression_toolkit/core/common/similarity_analyzer.py +14 -2
  18. model_compression_toolkit/core/common/substitutions/remove_identity.py +48 -0
  19. model_compression_toolkit/core/graph_prep_runner.py +10 -4
  20. model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_identity.py +51 -0
  21. model_compression_toolkit/core/keras/keras_implementation.py +3 -1
  22. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_identity.py +50 -0
  23. model_compression_toolkit/core/pytorch/pytorch_implementation.py +3 -1
  24. model_compression_toolkit/core/quantization_prep_runner.py +6 -2
  25. model_compression_toolkit/core/runner.py +5 -2
  26. model_compression_toolkit/gptq/keras/quantization_facade.py +2 -1
  27. model_compression_toolkit/gptq/pytorch/quantization_facade.py +3 -1
  28. model_compression_toolkit/gptq/runner.py +1 -0
  29. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +5 -5
  30. {mct_nightly-2.0.0.20240417.406.dist-info → mct_nightly-2.0.0.20240418.439.dist-info}/LICENSE.md +0 -0
  31. {mct_nightly-2.0.0.20240417.406.dist-info → mct_nightly-2.0.0.20240418.439.dist-info}/WHEEL +0 -0
  32. {mct_nightly-2.0.0.20240417.406.dist-info → mct_nightly-2.0.0.20240418.439.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.0.0.20240417.406
3
+ Version: 2.0.0.20240418.439
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,13 +1,13 @@
1
- model_compression_toolkit/__init__.py,sha256=HSq5ybA5NctJln9ucs7HnIcj00pgOGdhjVxEY-2w5dY,1573
2
- model_compression_toolkit/constants.py,sha256=f9at1H_-vb5nvdHRmAHUco4ja4_QermK6yu0N9qbRGE,3723
1
+ model_compression_toolkit/__init__.py,sha256=aO8E_DhwQy12oAxKxqXFskaEwaq_icpSqsisZn6UyZM,1573
2
+ model_compression_toolkit/constants.py,sha256=yIJyJ-e1WrDeKD9kG15qkqfYnoj7J1J2CxnJDt008ik,3756
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
5
5
  model_compression_toolkit/metadata.py,sha256=IyoON37lBv3TI0rZGCP4K5t3oYI4TOmYy-LRXOwHGpE,1136
6
6
  model_compression_toolkit/core/__init__.py,sha256=TrRgkWpT1AN2Faw1M_1HXyJkJnbxfn9p-RigDZl7pg0,1982
7
7
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
8
- model_compression_toolkit/core/graph_prep_runner.py,sha256=Ftqm59hT5TGWmSNkY9bFZkVfCacpGyZfCe-6yZR5WY0,10100
9
- model_compression_toolkit/core/quantization_prep_runner.py,sha256=hFhDkS8GwzXZ7Ho_9qbbb8DAAWs3OONOfMSD5OU_b0o,6153
10
- model_compression_toolkit/core/runner.py,sha256=NKSC6ujfQPy6dKtJVwxyK2zNDd64eyR5csYy9lBrCPA,11836
8
+ model_compression_toolkit/core/graph_prep_runner.py,sha256=kM70wmNG3yMFiGQc0uO0wn9j4ZbSWxUEykpxDK55doc,10567
9
+ model_compression_toolkit/core/quantization_prep_runner.py,sha256=0ga95vh_ZXO79r8FB26L5GIZKHkG98wq1hMsNH1bIeU,6453
10
+ model_compression_toolkit/core/runner.py,sha256=E_gXj95Az3C3swsv7v1zeKZx25keWjnD30uhI7ONZkY,12028
11
11
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
12
12
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
13
13
  model_compression_toolkit/core/common/framework_implementation.py,sha256=pOT9ZmRFL9FY92uUtigrO3sbWGiyVDhHAM1fbA4b5yo,20752
@@ -17,7 +17,7 @@ model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3u
17
17
  model_compression_toolkit/core/common/model_collector.py,sha256=ofcepKtxc3j2Ouz6BpAKXTzPgjABnpRP47ndmJCXAkk,8352
18
18
  model_compression_toolkit/core/common/model_validation.py,sha256=LaG8wd6aZl0OJgieE3SeiVDEPxtk8IHq9-3wSnmWhY4,1214
19
19
  model_compression_toolkit/core/common/node_prior_info.py,sha256=WXX_PrGVG9M9I_REG5ZzFBohwmV4yf356sZnrja_FLo,2832
20
- model_compression_toolkit/core/common/similarity_analyzer.py,sha256=98l9ttnXHf6VYxBW4852h2CPJKg3A6nLOovpHn-tnKs,8560
20
+ model_compression_toolkit/core/common/similarity_analyzer.py,sha256=5av6qDKNDJDHg0p387oOxemxvp2xkfjzB_QNaSHN6po,9199
21
21
  model_compression_toolkit/core/common/user_info.py,sha256=dSRMnT-oewmdOziIpEuW-s9K7vTSeyUBxT4z9neXurI,1648
22
22
  model_compression_toolkit/core/common/back2framework/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
23
23
  model_compression_toolkit/core/common/back2framework/base_model_builder.py,sha256=V1oShKzbSkdcTvREn8VnQQBzvm-tTHkWMXqMkYozF2s,2023
@@ -31,7 +31,7 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
31
31
  model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=lOubqpc18TslhXZijWUJQAa1c3jIB2S-M-5HK78wJPQ,5548
32
32
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
33
33
  model_compression_toolkit/core/common/graph/base_graph.py,sha256=06mvCb_HHA5iIOdQ31a-nimhrpSA-jYnuV1Ir76QGa8,38259
34
- model_compression_toolkit/core/common/graph/base_node.py,sha256=jPYpf6sci8LswatxTyygD8ZM5OvsCnxBEWsSl-g64wI,28492
34
+ model_compression_toolkit/core/common/graph/base_node.py,sha256=38-4iyOdiuWBD3eZtP7T74NYtLuqLaEj_cQZbAFHpG0,28499
35
35
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
36
36
  model_compression_toolkit/core/common/graph/functional_node.py,sha256=RgwWAoMX7YV5c2gZdTBSX-ziTh3OLbebZXr3jitkxDs,3173
37
37
  model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
@@ -45,10 +45,10 @@ model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256
45
45
  model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=gRmBEFRmyJsNKezQfiwDwQu1cmbGd2wgKCRTH6iw8mw,3961
46
46
  model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=gw4av_rzn_3oEAPpD3B7PHZDqnxHMjIESevl6ppPnkk,7175
47
47
  model_compression_toolkit/core/common/hessian/__init__.py,sha256=bxPVbkIlHFJMiOgTdWMVCqcD9JKV5kb2bVdWUTeLpj8,1021
48
- model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=8B-B5G_0ukNq6ICQNyMUuopSD8viWa72mUPXF3zFlFM,9721
48
+ model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=wUmyekByJIMjupAb4qttVQHsv2pJ1ydDg17U8d5azWE,9660
49
49
  model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=FpXQvJmhiF6PAWX9M_0XZ2Qe8Wv8bXcv0Sj3si5YIjQ,1325
50
50
  model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py,sha256=bWxavhwDrSHTQPQclUzzW_Q3FVgKEtwrnD7a9lmHNbo,4379
51
- model_compression_toolkit/core/common/hessian/trace_hessian_request.py,sha256=EvdZFWlpkN9pBqWZ7jReWHIN0FTUy-9x5KgAErXWwSw,3321
51
+ model_compression_toolkit/core/common/hessian/trace_hessian_request.py,sha256=lgZZgkpCURkMNaipFoRqwsONU74OWmMXSZvh4Dc4aMk,3251
52
52
  model_compression_toolkit/core/common/matchers/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
53
53
  model_compression_toolkit/core/common/matchers/base_graph_filter.py,sha256=mTk54z0mIbFmPOb4h0xfLtLDookcFyNh8H0pIN5js_M,3091
54
54
  model_compression_toolkit/core/common/matchers/base_matcher.py,sha256=JCj-NLAXOJa-GcSX-94PVUTWjooQUd0NemiyNg5uKGQ,2210
@@ -102,23 +102,23 @@ model_compression_toolkit/core/common/quantization/core_config.py,sha256=KYdyfSm
102
102
  model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
103
103
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
104
104
  model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=TCgpvtfyzFUedv4sZ6sKzsTyikaVl2ixLj_aHPSC2r0,27014
105
- model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=BieZDv9oc-Mc78S_LRMGo-s_2acbqiLE0ewaSE1v2VY,6818
105
+ model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=Y76BZ-X2vE_PXeM9r7D93VsFnbC_evoHhN7zYuvFdzw,7041
106
106
  model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=T1nVWdRJfBQ_iuMQYQSIkjfkR-2n3lAOKGAz_rUZZN0,2190
107
107
  model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
108
108
  model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
109
109
  model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
110
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=9BEv2l0z2trDEsr40VB8tO3ToBA_b2sd_jH9uqZ5Wo8,11503
110
+ model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=O4qFJw3nBYUD4cGbO8haGXZ2-piSqoRpDKDD74iXSxw,12417
111
111
  model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
112
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=TUJuSpX8pcsIPbJ6z_YGWgD_uafqlKRJcpsTIFpjMKU,19936
113
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=HSbAlDKXZMn8BtQQGL8TnlXvO2f_2oTLXAK1khraX7g,7410
112
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=4x6rgQ5bCz2kysVkjBXxbb2dNEC9N1S2TE46kOFXU_c,23305
113
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=AROE8pZEHmzGNCRoxr5QH2QFYvu1kefSVk6is3fsifI,8027
114
114
  model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
115
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=BiwDqt5CeU6CW0Qusy3LwWhFtf2J9BvSuGMsTsG6rSw,8538
115
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=ejc_obamUndJsv3F1FuOGMrIibS__qDUbAia1H9vwUM,9487
116
116
  model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=noEdvGiyyW7acgQ2OFWLedCODibTGYJifC9qo8YIU5U,4558
117
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=H2D9rdChIviL_j0mF6zy8Qeu_ZXKRu-hLqckSAT1MR8,4352
117
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=7ITrOw5ykncpHNghlPNTaDZExFYrPmhRck4oW0GaPe0,6213
118
118
  model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=7kt0JB8PQE0SW9kg8fCwZ5mBkHNgiRrn0of4ZQYQN2A,41524
119
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=nug6XgsywxYf57XF_Tnt2xwdf0zLLsajiZKEblo4lFc,3882
120
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=QtSAtdAb7sTgtoe9L6DnMFO7rjkOtpzE9kD9xmG7eYM,9743
121
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py,sha256=nsaM-AJ6WMUBT31jFIJ2wkYAiGM8qqm9lleMS8AwINI,7933
119
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=kAqVKZYu6FHWlC_PUiytsmXdTX1GzO_S5DWrTXuJBjs,4894
120
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=_ULwlPvzVL_UcYVlUPjDIeXz_99eW26l9FwGzaUu-_M,10789
121
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py,sha256=VG0UqFOQk_7ALdJsUl1wwwFLjE38DxN6-NRZx161XiY,8902
122
122
  model_compression_toolkit/core/common/quantization/quantizers/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
123
123
  model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py,sha256=P0x_y18LypBxP2tV9OWizheYfILqvaMC8RwHo04sUpQ,2761
124
124
  model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py,sha256=CCFhi5LUIcHCCIzDyORvm0FDZLknrctdNwNlPphOQgI,14245
@@ -135,6 +135,7 @@ model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,
135
135
  model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py,sha256=YqLKiO5gFBEvI6noAWeMME1JHaYUaGFMglVFg8AqGjc,10028
136
136
  model_compression_toolkit/core/common/substitutions/linear_collapsing.py,sha256=iEtzbWCDXP6EDkTZCtREQ0rpMxhQ2kM9zlcP_0KLq9I,12367
137
137
  model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py,sha256=uoauhmncQqUBNvD-qCLIXsIbl_IzrbxSKdxiMig-5W4,2406
138
+ model_compression_toolkit/core/common/substitutions/remove_identity.py,sha256=LjkedR5fnXy4LCEQ7rnVTBI-cTkdDxXtufge5Llj2J0,2038
138
139
  model_compression_toolkit/core/common/substitutions/residual_collapsing.py,sha256=doErjlMq-uSObYMSjA6IywSHb3Hz3QCc0HKU68ccrQ4,4767
139
140
  model_compression_toolkit/core/common/substitutions/scale_equalization.py,sha256=p57u25qdW2pimxzGwgMXEBV4S-LzXuTVAlIM7830WfU,10966
140
141
  model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=cyy4qnlD-v1Gou62oHNDsf1hWLWkYfcjVv1otFrUltY,29865
@@ -149,7 +150,7 @@ model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7V
149
150
  model_compression_toolkit/core/keras/constants.py,sha256=Uv3c0UdW55pIVQNW_1HQlgl-dHXREkltOLyzp8G1mTQ,3163
150
151
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
151
152
  model_compression_toolkit/core/keras/default_framework_info.py,sha256=Ha4HTHuiw_KTS5Po1Xnv6GyK9eprpDhYWf-eooS62Ys,4961
152
- model_compression_toolkit/core/keras/keras_implementation.py,sha256=RS2UEtZ_anZeDxz7Zv6sNv7v9tFVct6d9KVrUlxTGpo,29309
153
+ model_compression_toolkit/core/keras/keras_implementation.py,sha256=7RBALls_V0z18WtkWhVEpjAYmaTZvhMxQaDm4J7nkDc,29457
153
154
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
154
155
  model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=Aqh31wOPaiZcJIOm-uJwzev0eTMdJyXaOk97rs4z7BU,3879
155
156
  model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=Xmk2ZL5CaYdb7iG62HdtZ1F64vap7ffnrsuR3e3G5hc,4851
@@ -174,6 +175,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_co
174
175
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py,sha256=kjwlKtm5yhNgWVVcW6mN-hn7enwAnn_8-TUZvxZBiQs,4112
175
176
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=l9PUREBf4aRwWILiybdteveeUbh7js-i-hLt8Ma0e4c,26771
176
177
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py,sha256=IdKOg6AWZWMcmDbOuNdxetS5_zTarXIIffdYL7JTdvk,3872
178
+ model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_identity.py,sha256=z2J2Xk7b_w_fEgJmK87lwwBmEoAZpGxPmsBrR24IkZs,2035
177
179
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py,sha256=gSqUYh76tP7NcZfqFSnuPIrUpyBh6UjjcPJtJxZtOZk,3181
178
180
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py,sha256=ryes9y1ie-vjBGso2TeO4EXxVk69Ew3iSAhshPz1Ou4,5542
179
181
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/separableconv_decomposition.py,sha256=TEaHlIbXj_ZjIdT5TmAICD3WLD3u_7g0fLWQcNzTJuM,7941
@@ -211,7 +213,7 @@ model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKW
211
213
  model_compression_toolkit/core/pytorch/constants.py,sha256=NI-J7REuxn06oEIHsmJ4GqtNC3TbV8xlkJjt5Ar-c4U,2626
212
214
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
213
215
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=IoMvTch5awAEPvB6Tg6ANhFGXvfSgv7JLsUBlxpMwk4,4330
214
- model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=mT4jd8E1saCpAgrsClufQbnVJ0eYn1xaTQ3teALu4jk,27117
216
+ model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=pDA2hL84XrO0zwAsFxM5a92BO_C2bBEtC9GEo4QaKyM,27267
215
217
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=n_B4a6FMwM9D2w8kzy3oenBWZgXNZuIZgTJC6JEuTy0,3250
216
218
  model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=E6ifk1HdO60k4IRH2EFBzAYWtwUlrGqJoQ66nknpHoQ,4983
217
219
  model_compression_toolkit/core/pytorch/utils.py,sha256=dRPiteBg2dBNsHwZyYzXiCIAjnelSoeZZsDXlsTw5JQ,2880
@@ -238,6 +240,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_
238
240
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=VNg-VgzCxSyqy2J3neEPl6U0SPO8UIVU_T47bGhz4FE,38459
239
241
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/permute_call_method.py,sha256=EMCviyFyJFLEKuAUz3rZHLfB9MAU1kywSBL2XQNzLlg,1953
240
242
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py,sha256=9tI14dWDQkTCgLwVZdqmHxEek5KgYPL3x5fnJWWq7bg,5667
243
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_identity.py,sha256=joHjwiUxccypMHkTy46rI91VyapLn9yJ2YRo5ISnOH4,1987
241
244
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py,sha256=jOqlelGhADEZiYUEyYj9oJZ5YLXx8jWNUlVTG6Td79Y,4919
242
245
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py,sha256=dwRy3ZZ0qShBEQLknkYUVPtgZsk6rjJ4IXf553mcch8,2902
243
246
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py,sha256=XFtU9yuBmoZlX0f0mS6otMPWMk-RcWs94XdvvTNhW8Y,3303
@@ -329,7 +332,7 @@ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py,sha
329
332
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=D_mEUK1sb4kY5946oErfw3RC5mfBTVaw3LZRIKWYKcE,4918
330
333
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=4sN5z-6BXrTE5Dp2FX_jKO9ty5iZ2r4RM7XvXtDVLSI,9348
331
334
  model_compression_toolkit/gptq/__init__.py,sha256=YKg-tMj9D4Yd0xW9VRD5EN1J5JrmlRbNEF2fOSgodqA,1228
332
- model_compression_toolkit/gptq/runner.py,sha256=MIg-oBtR1nbHkexySdCJD_XfjRoHSknLotmGBMuD5qM,5924
335
+ model_compression_toolkit/gptq/runner.py,sha256=PQoLK3WhdRuUwZMd1VbtA7KZ9c-zWig_0ShmTtvJSHY,5970
333
336
  model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
334
337
  model_compression_toolkit/gptq/common/gptq_config.py,sha256=6xP99B-lK1bwGv3AdqxnW1V51z2VdzQcjvoSgJOmygA,5288
335
338
  model_compression_toolkit/gptq/common/gptq_constants.py,sha256=QSm6laLkIV0LYmU0BLtmKp3Fi3SqDfbncFQWOGA1cGU,611
@@ -341,7 +344,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
341
344
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
342
345
  model_compression_toolkit/gptq/keras/gptq_training.py,sha256=zyVcEQzdnNsrIz32U1pqqoi08hzxRdJ2CumaPFGwbDM,19123
343
346
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=5IvgGlJlgOmQYmldjdCBv7tuzAoY0HazatG5Pedrg0Q,4639
344
- model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=CCV9uyaq-qUGDeXL5OgEWFXSiUkerXrNwFVyA1brrKM,14663
347
+ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=L5yqjkzw_oszL--dV9EjGoXUYmqM9GmDP7kS7_k96xw,14748
345
348
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
346
349
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=2YU-x4-Q5f6hkUJf0tw6vcwdNwRMHdefrFjhhyHYsvA,4782
347
350
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -358,7 +361,7 @@ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa9
358
361
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
359
362
  model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=xkDa62AdIRwv8dEshffALW9Ri66eseEpyUF9taMUKns,16509
360
363
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=yXJzDd24zfGs2_vfMovxD1WSh1RxXoPxN4GztOf3P5c,3967
361
- model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=iBLEbLgde6JQNPhJysfT2rl_Sc7-wyoIZnXRAXQWnR0,13065
364
+ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=bZvrMKN2jFJH9fodtbCCAtKNVXIvlOAnIaxcGov320o,13154
362
365
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
363
366
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=TCA1hAc7raPnrjl06sjFtVM4XUtLtuwAhCGX4U3KGZo,4137
364
367
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
@@ -429,7 +432,7 @@ model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py,sh
429
432
  model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py,sha256=aHoAu5Iye9YVn2HLwNb4X9cUDX1WJt20R5GsNGIAk9E,3337
430
433
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
431
434
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py,sha256=fPOzybGECCWPkAD1hmJryWZrf9vd5Od-UOH6PE0lH94,3820
432
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py,sha256=v1eush7kGZ_Pdl8iyIVkKIqCmix2afiuPZDMgm6kBrE,1522
435
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py,sha256=F5RG4MnuAwKcNXbfVbPFLQu30-lNax-7knqu20B6udQ,1522
433
436
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/__init__.py,sha256=1mMOREEMoNHu_KTMGDp4crN61opKWX6aFn1DrDLvqcc,717
434
437
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py,sha256=S-GwMI-JiuPpbtOdd6TSOEjiUFiIs6M2RAiJNJ3O950,10883
435
438
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_keras.py,sha256=bPBWxopMUHFgiaJjaAfoompwShvfH2wHAouN56PQn0A,6484
@@ -480,8 +483,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
480
483
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
481
484
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
482
485
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=7bbzqJN8ZAycVDvZr_5xC-niTAR5df8f03Kooev_pfg,3047
483
- mct_nightly-2.0.0.20240417.406.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
484
- mct_nightly-2.0.0.20240417.406.dist-info/METADATA,sha256=wDYGNbzlScIweXxmrfcYA9RSLM_OaB2fYaIsx28fm-Y,18795
485
- mct_nightly-2.0.0.20240417.406.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
486
- mct_nightly-2.0.0.20240417.406.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
487
- mct_nightly-2.0.0.20240417.406.dist-info/RECORD,,
486
+ mct_nightly-2.0.0.20240418.439.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
487
+ mct_nightly-2.0.0.20240418.439.dist-info/METADATA,sha256=ES0claumrC9y2bX7XAFj8RD6nZNBClpiLEVCOphlRxE,18795
488
+ mct_nightly-2.0.0.20240418.439.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
489
+ mct_nightly-2.0.0.20240418.439.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
490
+ mct_nightly-2.0.0.20240418.439.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.0.0.20240417.000406"
30
+ __version__ = "2.0.0.20240418.000439"
@@ -97,6 +97,8 @@ UPPER_FACTOR = 1.2
97
97
  DEC_RANGE_BOTTOM = 0.97
98
98
  DEC_RANGE_UPPER = 1.03
99
99
 
100
+ NUM_QPARAM_HESSIAN_SAMPLES = 16
101
+
100
102
  # Resource utilization computation parameters
101
103
  BITS_TO_BYTES = 8.0
102
104
 
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  import copy
17
- from typing import Dict, Any, Tuple, List, Type
17
+ from typing import Dict, Any, Tuple, List, Type, Union
18
18
 
19
19
  import numpy as np
20
20
 
@@ -17,7 +17,6 @@ from functools import partial
17
17
  from typing import Callable, List
18
18
 
19
19
  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
20
- from model_compression_toolkit.core.common import Graph
21
20
  from model_compression_toolkit.core.common.hessian.trace_hessian_request import TraceHessianRequest
22
21
  from model_compression_toolkit.logger import Logger
23
22
 
@@ -38,7 +37,7 @@ class HessianInfoService:
38
37
  """
39
38
 
40
39
  def __init__(self,
41
- graph: Graph,
40
+ graph,
42
41
  representative_dataset: Callable,
43
42
  fw_impl,
44
43
  num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS
@@ -151,7 +150,7 @@ class HessianInfoService:
151
150
  if required_size==0:
152
151
  return []
153
152
 
154
- Logger.info(f"Ensuring {required_size} Hessian-trace approximation for node {trace_hessian_request.target_node}.")
153
+ Logger.info(f"\nEnsuring {required_size} Hessian-trace approximation for node {trace_hessian_request.target_node}.")
155
154
 
156
155
  # Replace request of a reused target node with a request of the 'reuse group'.
157
156
  if trace_hessian_request.target_node.reuse_group:
@@ -16,8 +16,6 @@ from typing import List
16
16
 
17
17
  from enum import Enum
18
18
 
19
- from model_compression_toolkit.core.common import BaseNode
20
-
21
19
 
22
20
  class HessianMode(Enum):
23
21
  """
@@ -54,7 +52,7 @@ class TraceHessianRequest:
54
52
  def __init__(self,
55
53
  mode: HessianMode,
56
54
  granularity: HessianInfoGranularity,
57
- target_node: BaseNode,
55
+ target_node,
58
56
  ):
59
57
  """
60
58
  Attributes:
@@ -26,14 +26,16 @@ class QuantizationErrorMethod(Enum):
26
26
 
27
27
  NOCLIPPING - Use min/max values as thresholds.
28
28
 
29
- MSE - Use min square error for minimizing quantization noise.
29
+ MSE - Use mean square error for minimizing quantization noise.
30
30
 
31
- MAE - Use min absolute error for minimizing quantization noise.
31
+ MAE - Use mean absolute error for minimizing quantization noise.
32
32
 
33
33
  KL - Use KL-divergence to make signals distributions to be similar as possible.
34
34
 
35
35
  Lp - Use Lp-norm to minimizing quantization noise.
36
36
 
37
+ HMSE - Use Hessian-based mean squared error for minimizing quantization noise. This method is using Hessian scores to factorize more valuable parameters when computing the error induced by quantization.
38
+
37
39
  """
38
40
 
39
41
  NOCLIPPING = 0
@@ -41,6 +43,7 @@ class QuantizationErrorMethod(Enum):
41
43
  MAE = 2
42
44
  KL = 4
43
45
  LP = 5
46
+ HMSE = 6
44
47
 
45
48
 
46
49
  class QuantizationConfig:
@@ -13,13 +13,16 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from copy import deepcopy
16
- from typing import Tuple, Callable
16
+ from typing import Tuple, Callable, List
17
17
  import numpy as np
18
18
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
19
+ from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, HessianInfoGranularity, \
20
+ HessianInfoService
19
21
  from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_mae, compute_lp_norm
20
22
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
- from model_compression_toolkit.constants import FLOAT_32
22
- from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor
23
+ from model_compression_toolkit.constants import FLOAT_32, NUM_QPARAM_HESSIAN_SAMPLES
24
+ from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor, \
25
+ reshape_tensor_for_per_channel_search
23
26
 
24
27
 
25
28
  def _mse_error_histogram(q_bins: np.ndarray,
@@ -371,13 +374,63 @@ def _get_sliced_histogram(bins: np.ndarray,
371
374
  return bins_subset, counts_subset
372
375
 
373
376
 
377
+ def _compute_hessian_for_hmse(node,
378
+ hessian_info_service: HessianInfoService,
379
+ num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> List[np.ndarray]:
380
+ """
381
+ Compute and retrieve Hessian-based scores for use during HMSE error computation.
382
+
383
+ Args:
384
+ node: The node to compute Hessian-based scores for.
385
+ hessian_info_service: HessianInfoService object for retrieving Hessian-based scores.
386
+ num_hessian_samples: Number of samples to approximate Hessian-based scores on.
387
+
388
+ Returns: A list with computed Hessian-based scores tensors for the given node.
389
+
390
+ """
391
+ _request = TraceHessianRequest(mode=HessianMode.WEIGHTS,
392
+ granularity=HessianInfoGranularity.PER_ELEMENT,
393
+ target_node=node)
394
+ _scores_for_node = hessian_info_service.fetch_hessian(_request,
395
+ required_size=num_hessian_samples)
396
+
397
+ return _scores_for_node
398
+
399
+
400
+ def _hmse_error_function_wrapper(float_tensor: np.ndarray,
401
+ fxp_tensor: np.ndarray,
402
+ axis: int,
403
+ norm: bool,
404
+ hessian_scores: np.ndarray):
405
+ """
406
+ This function wraps the HMSE error method to enable using it during parameters selection.
407
+
408
+ Args:
409
+ float_tensor: Float tensor.
410
+ fxp_tensor: Quantized tensor.
411
+ axis: Axis along which the operation has been performed. If not None, then per-channel computation is expected.
412
+ norm: Indicates whether to normalize the result of the error function.
413
+ hessian_scores: A tensor with Hessian-based scores to use for Hessian-based MSE (HMSE) error computation.
414
+
415
+ Returns: The HMSE error between the float and fixed-point tensors.
416
+
417
+ """
418
+ if axis is not None:
419
+ hessian_scores = reshape_tensor_for_per_channel_search(hessian_scores, 0)
420
+
421
+ return compute_mse(float_tensor, fxp_tensor, axis, norm, weights=hessian_scores)
422
+
423
+
374
424
  def get_threshold_selection_tensor_error_function(quantization_method: QuantizationMethod,
375
425
  quant_error_method: qc.QuantizationErrorMethod,
376
426
  p: int,
377
427
  axis: int = None,
378
428
  norm: bool = False,
379
429
  n_bits: int = 8,
380
- signed: bool = True) -> Callable:
430
+ signed: bool = True,
431
+ node=None,
432
+ hessian_info_service: HessianInfoService = None,
433
+ num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Callable:
381
434
  """
382
435
  Returns the error function compatible to the provided threshold method,
383
436
  to be used in the threshold optimization search for tensor quantization.
@@ -389,6 +442,9 @@ def get_threshold_selection_tensor_error_function(quantization_method: Quantizat
389
442
  norm: Indicates whether to normalize the result of the error function.
390
443
  n_bits: Number of bits used to quantize the tensor.
391
444
  signed: Indicates whether the input is signed.
445
+ node: The node for which the quantization error is computed (used only with HMSE error method).
446
+ hessian_info_service: HessianInfoService object for retrieving Hessian-based scores (used only with HMSE error method).
447
+ num_hessian_samples: Number of samples to approximate Hessian-based scores on (used only with HMSE error method).
392
448
 
393
449
  Returns: a Callable method that calculates the error between a tensor and a quantized tensor.
394
450
  """
@@ -418,6 +474,13 @@ def get_threshold_selection_tensor_error_function(quantization_method: Quantizat
418
474
  n_bits=n_bits,
419
475
  per_channel=True)
420
476
 
477
+ if quant_error_method == qc.QuantizationErrorMethod.HMSE:
478
+ node_hessian_scores = _compute_hessian_for_hmse(node, hessian_info_service, num_hessian_samples)
479
+ node_hessian_scores = np.sqrt(np.mean(node_hessian_scores, axis=0))
480
+
481
+ return lambda x, y, threshold: _hmse_error_function_wrapper(x, y, norm=norm, axis=axis,
482
+ hessian_scores=node_hessian_scores)
483
+
421
484
  quant_method_error_function_mapping = {
422
485
  qc.QuantizationErrorMethod.MSE: lambda x, y, threshold: compute_mse(x, y, norm=norm, axis=axis),
423
486
  qc.QuantizationErrorMethod.MAE: lambda x, y, threshold: compute_mae(x, y, norm=norm, axis=axis),
@@ -18,7 +18,8 @@ from sklearn.cluster import KMeans
18
18
 
19
19
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
20
20
  from model_compression_toolkit.constants import LUT_VALUES, MIN_THRESHOLD, SCALE_PER_CHANNEL, \
21
- LUT_VALUES_BITWIDTH, THRESHOLD
21
+ LUT_VALUES_BITWIDTH, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES
22
+ from model_compression_toolkit.core.common.hessian import HessianInfoService
22
23
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import \
23
24
  max_power_of_two, int_quantization_with_threshold
24
25
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import \
@@ -37,7 +38,10 @@ def lut_kmeans_tensor(tensor_data: np.ndarray,
37
38
  n_iter: int = 10,
38
39
  min_threshold: float = MIN_THRESHOLD,
39
40
  quant_error_method: qc.QuantizationErrorMethod = None,
40
- is_symmetric=False) -> dict:
41
+ is_symmetric=False,
42
+ node=None,
43
+ hessian_info_service: HessianInfoService = None,
44
+ num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> dict:
41
45
  """
42
46
  The quantizer first finds the closest max value per channel of tensor_data.
43
47
  Now, we divide tensor_data with the threshold vector per channel. In addition, we scale the result to the range
@@ -53,7 +57,10 @@ def lut_kmeans_tensor(tensor_data: np.ndarray,
53
57
  n_iter: Number of iterations to search for the optimal threshold.
54
58
  min_threshold: Minimal threshold to choose when the computed one is smaller.
55
59
  quant_error_method: an error function to optimize the parameters' selection accordingly (not used for this method).
56
- is_symmetric (bool): Whether to apply symmetric weight quantization (default is False, meaning power of 2 quantization)
60
+ is_symmetric (bool): Whether to apply symmetric weight quantization (default is False, meaning power of 2 quantization).
61
+ node: The node for which the quantization error is computed (not used for this method).
62
+ hessian_info_service: HessianInfoService object for retrieving Hessian-based scores (not used for this method).
63
+ num_hessian_samples: Number of samples to approximate Hessian-based scores on (not used for this method).
57
64
 
58
65
  Returns:
59
66
  A dictionary containing the cluster assignments according to the k-means algorithm,
@@ -15,7 +15,8 @@
15
15
  import numpy as np
16
16
 
17
17
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
18
- from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD
18
+ from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES
19
+ from model_compression_toolkit.core.common.hessian import HessianInfoService
19
20
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_search import \
20
21
  qparams_selection_tensor_search, qparams_selection_histogram_search
21
22
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import max_power_of_two, get_tensor_max
@@ -31,7 +32,11 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
31
32
  channel_axis: int = 1,
32
33
  n_iter: int = 10,
33
34
  min_threshold: float = MIN_THRESHOLD,
34
- quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE) -> dict:
35
+ quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE,
36
+ node=None,
37
+ hessian_info_service: HessianInfoService = None,
38
+ num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES,
39
+ ) -> dict:
35
40
  """
36
41
  Compute the power of two threshold based on the provided QuantizationErrorMethod to quantize the tensor.
37
42
  Different search is applied, depends on the value of the selected QuantizationErrorMethod.
@@ -45,6 +50,9 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
45
50
  n_iter: Number of iterations to search for the optimal threshold (not used for this method).
46
51
  min_threshold: Minimal threshold to use if threshold is too small (not used for this method).
47
52
  quant_error_method: an error function to optimize the parameters' selection accordingly.
53
+ node: The node for which the quantization error is computed (used only with HMSE error method).
54
+ hessian_info_service: HessianInfoService object for retrieving Hessian-based scores (used only with HMSE error method).
55
+ num_hessian_samples: Number of samples to approximate Hessian-based scores on (used only with HMSE error method).
48
56
 
49
57
  Returns:
50
58
  Power of two threshold to quantize the tensor in a power of 2 manner.
@@ -57,8 +65,10 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
57
65
  signed = True # weights are always signed
58
66
  axis = -1 if per_channel else None
59
67
  error_function = get_threshold_selection_tensor_error_function(QuantizationMethod.POWER_OF_TWO,
60
- quant_error_method, p, axis=axis, norm=False, n_bits=n_bits,
61
- signed=signed)
68
+ quant_error_method, p, axis=axis, norm=False,
69
+ n_bits=n_bits, signed=signed, node=node,
70
+ hessian_info_service=hessian_info_service,
71
+ num_hessian_samples=num_hessian_samples)
62
72
  threshold = qparams_selection_tensor_search(error_function,
63
73
  tensor_data,
64
74
  n_bits,
@@ -12,10 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import copy
16
+
15
17
  from tqdm import tqdm
16
18
  from typing import List
17
19
 
20
+ from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
21
+ from model_compression_toolkit.core import QuantizationErrorMethod
18
22
  from model_compression_toolkit.core.common import Graph, BaseNode
23
+ from model_compression_toolkit.core.common.hessian import HessianInfoService
19
24
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
20
25
  import get_activations_qparams
21
26
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
@@ -25,7 +30,9 @@ from model_compression_toolkit.logger import Logger
25
30
 
26
31
  def calculate_quantization_params(graph: Graph,
27
32
  nodes: List[BaseNode] = [],
28
- specific_nodes: bool = False):
33
+ specific_nodes: bool = False,
34
+ hessian_info_service: HessianInfoService = None,
35
+ num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES):
29
36
  """
30
37
  For a graph, go over its nodes, compute quantization params (for both weights and activations according
31
38
  to the given framework info), and create and attach a NodeQuantizationConfig to each node (containing the
@@ -39,6 +46,8 @@ def calculate_quantization_params(graph: Graph,
39
46
  graph: Graph to compute its nodes' thresholds.
40
47
  nodes: List of nodes to compute their thresholds instead of computing it for all nodes in the graph.
41
48
  specific_nodes: Flag to compute thresholds for only specific nodes.
49
+ hessian_info_service: HessianInfoService object for retrieving Hessian-based scores (used only with HMSE error method).
50
+ num_hessian_samples: Number of samples to approximate Hessian-based scores on (used only with HMSE error method).
42
51
 
43
52
  """
44
53
 
@@ -60,10 +69,28 @@ def calculate_quantization_params(graph: Graph,
60
69
  output_channels_axis = channels_axis[0]
61
70
  else:
62
71
  output_channels_axis = None
72
+
73
+ mod_attr_cfg = attr_cfg
74
+
75
+ if attr_cfg.weights_error_method == QuantizationErrorMethod.HMSE:
76
+ kernel_attr_name = graph.fw_info.get_kernel_op_attributes(n.type)
77
+ if len(kernel_attr_name) > 0:
78
+ kernel_attr_name = kernel_attr_name[0]
79
+
80
+ if kernel_attr_name is None or kernel_attr_name not in attr:
81
+ Logger.warning(f"The HMSE error method for parameters selection is only supported for "
82
+ f"kernel weights attributes. Running parameters selection for attribute "
83
+ f"'{attr}' in node '{n.name}' with the default MSE error method instead.")
84
+ mod_attr_cfg = copy.deepcopy(attr_cfg)
85
+ mod_attr_cfg.weights_error_method = QuantizationErrorMethod.MSE
86
+
63
87
  weights_params = get_weights_qparams(n.get_weights_by_keys(attr),
64
88
  candidate_qc.weights_quantization_cfg,
65
- attr_cfg,
66
- output_channels_axis)
89
+ mod_attr_cfg,
90
+ output_channels_axis,
91
+ node=n,
92
+ hessian_info_service=hessian_info_service,
93
+ num_hessian_samples=num_hessian_samples)
67
94
  attr_cfg.set_weights_quantization_param(weights_params)
68
95
 
69
96
  if n.is_activation_quantization_enabled():
@@ -12,11 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Dict, Any, Tuple
15
+ from typing import Dict, Any
16
16
 
17
17
  import numpy as np
18
18
 
19
- from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
20
+ from model_compression_toolkit.core.common.hessian import HessianInfoService
20
21
  from model_compression_toolkit.defaultdict import DefaultDict
21
22
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
23
  from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
@@ -27,31 +28,40 @@ from model_compression_toolkit.core.common.quantization.node_quantization_config
27
28
  dummy_channel_mapping = DefaultDict(default_value=(None, None))
28
29
 
29
30
 
30
- def get_weights_qparams(kernel: np.ndarray,
31
+ def get_weights_qparams(weights_attr_values: np.ndarray,
31
32
  weights_quant_config: NodeWeightsQuantizationConfig,
32
33
  attr_quant_config: WeightsAttrQuantizationConfig,
33
- output_channels_axis: int) -> Dict[Any, Any]:
34
+ output_channels_axis: int,
35
+ node=None,
36
+ hessian_info_service: HessianInfoService = None,
37
+ num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Dict[Any, Any]:
34
38
  """
35
39
  Compute thresholds to quantize a kernel according to a NodeWeightsQuantizationConfig
36
40
  instance.
37
41
 
38
42
  Args:
39
- kernel: Kernel to compute the quantization thresholds to.
43
+ weights_attr_values: Weights attribute parameter to compute the quantization thresholds for.
40
44
  weights_quant_config: Weights quantization configuration to define how the thresholds are computed.
41
45
  attr_quant_config: A specific weights attribute quantization configuration to get its params.
42
46
  output_channels_axis: Index of the kernel output channels dimension.
47
+ node: The node for which the quantization error is computed (used only with HMSE error method).
48
+ hessian_info_service: HessianInfoService object for retrieving Hessian-based scores (used only with HMSE error method).
49
+ num_hessian_samples: Number of samples to approximate Hessian-based scores on (used only with HMSE error method).
43
50
 
44
51
  Returns:
45
52
  A dictionary with the quantization threshold of the kernel.
46
53
  """
47
54
  if attr_quant_config.weights_quantization_params_fn is not None:
48
- weights_params = attr_quant_config.weights_quantization_params_fn(kernel,
55
+ weights_params = attr_quant_config.weights_quantization_params_fn(weights_attr_values,
49
56
  p=attr_quant_config.l_p_value,
50
57
  n_bits=attr_quant_config.weights_n_bits,
51
58
  per_channel=attr_quant_config.weights_per_channel_threshold and output_channels_axis is not None,
52
59
  channel_axis=output_channels_axis,
53
60
  min_threshold=weights_quant_config.min_threshold,
54
- quant_error_method=attr_quant_config.weights_error_method)
61
+ quant_error_method=attr_quant_config.weights_error_method,
62
+ node=node,
63
+ hessian_info_service=hessian_info_service,
64
+ num_hessian_samples=num_hessian_samples)
55
65
  else:
56
66
  weights_params = {}
57
67
 
@@ -15,7 +15,8 @@
15
15
  import numpy as np
16
16
 
17
17
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
18
- from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD
18
+ from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES
19
+ from model_compression_toolkit.core.common.hessian import HessianInfoService
19
20
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
20
21
  get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function, _kl_error_histogram
21
22
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_search import \
@@ -33,7 +34,10 @@ def symmetric_selection_tensor(tensor_data: np.ndarray,
33
34
  channel_axis: int = 1,
34
35
  n_iter: int = 10,
35
36
  min_threshold: float = MIN_THRESHOLD,
36
- quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE) -> dict:
37
+ quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE,
38
+ node=None,
39
+ hessian_info_service: HessianInfoService = None,
40
+ num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> dict:
37
41
  """
38
42
  Compute the optimal threshold based on the provided QuantizationErrorMethod to quantize the tensor.
39
43
  Different search is applied, depends on the value of the selected QuantizationErrorMethod.
@@ -47,6 +51,9 @@ def symmetric_selection_tensor(tensor_data: np.ndarray,
47
51
  n_iter: Number of iterations to search for the optimal threshold (not used for this method).
48
52
  min_threshold: Minimal threshold to use if threshold is too small (not used for this method).
49
53
  quant_error_method: an error function to optimize the parameters' selection accordingly.
54
+ node: The node for which the quantization error is computed (used only with HMSE error method).
55
+ hessian_info_service: HessianInfoService object for retrieving Hessian-based scores (used only with HMSE error method).
56
+ num_hessian_samples: Number of samples to approximate Hessian-based scores on (used only with HMSE error method).
50
57
 
51
58
  Returns:
52
59
  Optimal threshold to quantize the tensor in a symmetric manner.
@@ -59,7 +66,11 @@ def symmetric_selection_tensor(tensor_data: np.ndarray,
59
66
  else:
60
67
  signed = True # weights are always signed
61
68
  axis = -1 if per_channel else None
62
- error_function = get_threshold_selection_tensor_error_function(QuantizationMethod.SYMMETRIC, quant_error_method, p, axis=axis, norm=False, n_bits=n_bits, signed=signed)
69
+ error_function = get_threshold_selection_tensor_error_function(QuantizationMethod.SYMMETRIC, quant_error_method,
70
+ p, axis=axis, norm=False, n_bits=n_bits,
71
+ signed=signed, node=node,
72
+ hessian_info_service=hessian_info_service,
73
+ num_hessian_samples=num_hessian_samples)
63
74
  threshold = qparams_symmetric_selection_tensor_search(error_function,
64
75
  tensor_data,
65
76
  tensor_max,
@@ -15,7 +15,8 @@
15
15
  import numpy as np
16
16
 
17
17
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
18
- from model_compression_toolkit.constants import MIN_THRESHOLD, RANGE_MIN, RANGE_MAX
18
+ from model_compression_toolkit.constants import MIN_THRESHOLD, RANGE_MIN, RANGE_MAX, NUM_QPARAM_HESSIAN_SAMPLES
19
+ from model_compression_toolkit.core.common.hessian import HessianInfoService
19
20
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_search import \
20
21
  qparams_uniform_selection_tensor_search, qparams_uniform_selection_histogram_search
21
22
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
@@ -31,7 +32,10 @@ def uniform_selection_tensor(tensor_data: np.ndarray,
31
32
  channel_axis: int = 1,
32
33
  n_iter: int = 10,
33
34
  min_threshold: float = MIN_THRESHOLD,
34
- quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE) -> dict:
35
+ quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE,
36
+ node=None,
37
+ hessian_info_service: HessianInfoService = None,
38
+ num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> dict:
35
39
  """
36
40
  Compute the optimal quantization range based on the provided QuantizationErrorMethod
37
41
  to uniformly quantize the tensor.
@@ -46,6 +50,9 @@ def uniform_selection_tensor(tensor_data: np.ndarray,
46
50
  n_iter: Number of iterations to search for the optimal threshold (not used for this method).
47
51
  min_threshold: Minimal threshold to use if threshold is too small (not used for this method).
48
52
  quant_error_method: an error function to optimize the range parameters' selection accordingly.
53
+ node: The node for which the quantization error is computed (used only with HMSE error method).
54
+ hessian_info_service: HessianInfoService object for retrieving Hessian-based scores (used only with HMSE error method).
55
+ num_hessian_samples: Number of samples to approximate Hessian-based scores on (used only with HMSE error method).
49
56
 
50
57
  Returns:
51
58
  Optimal quantization range to quantize the tensor uniformly.
@@ -57,7 +64,10 @@ def uniform_selection_tensor(tensor_data: np.ndarray,
57
64
  mm = tensor_min, tensor_max
58
65
  else:
59
66
  axis = -1 if per_channel else None
60
- error_function = get_threshold_selection_tensor_error_function(QuantizationMethod.UNIFORM, quant_error_method, p, axis=axis, norm=False)
67
+ error_function = get_threshold_selection_tensor_error_function(QuantizationMethod.UNIFORM, quant_error_method,
68
+ p, axis=axis, norm=False, node=node,
69
+ hessian_info_service=hessian_info_service,
70
+ num_hessian_samples=num_hessian_samples)
61
71
  mm = qparams_uniform_selection_tensor_search(error_function,
62
72
  tensor_data,
63
73
  tensor_min,
@@ -24,7 +24,8 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph
24
24
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
25
25
  CandidateNodeQuantizationConfig
26
26
  from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
27
- from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
27
+ from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
28
+ QuantizationErrorMethod
28
29
  from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
29
30
  get_activation_quantization_params_fn, get_weights_quantization_params_fn
30
31
  from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
@@ -36,19 +37,31 @@ from model_compression_toolkit.target_platform_capabilities.target_platform.op_q
36
37
 
37
38
  def set_quantization_configuration_to_graph(graph: Graph,
38
39
  quant_config: QuantizationConfig,
39
- mixed_precision_enable: bool = False) -> Graph:
40
+ mixed_precision_enable: bool = False,
41
+ running_gptq: bool = False) -> Graph:
40
42
  """
41
43
  Add quantization configuration for each graph node.
42
44
 
43
45
  Args:
44
46
  graph: Graph for which to add quantization info to each node.
45
47
  quant_config: Quantization configuration containing parameters for how the graph should be quantized.
46
- mixed_precision_enable: is mixed precision enabled
48
+ mixed_precision_enable: is mixed precision enabled.
49
+ running_gptq: Whether or not a GPTQ optimization is planned to run after the PTQ process.
47
50
 
48
51
  Returns:
49
52
  The graph with quantization configurations attached to each node in it.
50
53
  """
51
54
 
55
+ if quant_config.weights_error_method == QuantizationErrorMethod.HMSE:
56
+ if not running_gptq:
57
+ Logger.warning(f"The HMSE error method for parameters selection is only supported when running GPTQ "
58
+ f"optimization due to long execution time that is not suitable for basic PTQ. "
59
+ f"Using the default MSE error method instead.")
60
+ quant_config.weights_error_method = QuantizationErrorMethod.MSE
61
+ else:
62
+ Logger.warning("Using the HMSE error method for weights quantization parameters search. "
63
+ "Note: This method may significantly increase runtime during the parameter search process.")
64
+
52
65
  for n in graph.nodes:
53
66
  set_quantization_configs_to_node(node=n,
54
67
  quant_config=quant_config,
@@ -18,6 +18,8 @@ from typing import Any
18
18
  import numpy as np
19
19
 
20
20
  from model_compression_toolkit.constants import EPS
21
+ from model_compression_toolkit.logger import Logger
22
+
21
23
 
22
24
  #########################
23
25
  # Helpful functions
@@ -87,7 +89,8 @@ def compute_mse(float_tensor: np.ndarray,
87
89
  norm: bool = False,
88
90
  norm_eps: float = 1e-8,
89
91
  batch: bool = False,
90
- axis: int = None) -> float:
92
+ axis: int = None,
93
+ weights: np.ndarray = None) -> float:
91
94
  """
92
95
  Compute the mean square error between two numpy arrays.
93
96
 
@@ -98,6 +101,7 @@ def compute_mse(float_tensor: np.ndarray,
98
101
  norm_eps: epsilon value for error normalization stability.
99
102
  batch: Whether to run batch similarity analysis or not.
100
103
  axis: Axis along which the operator has been computed.
104
+ weights: Weights tensor to use for the Weighted-MSE error computation.
101
105
 
102
106
  Returns:
103
107
  The MSE distance between the two tensors.
@@ -107,7 +111,15 @@ def compute_mse(float_tensor: np.ndarray,
107
111
  float_flat = flatten_tensor(float_tensor, batch, axis)
108
112
  fxp_flat = flatten_tensor(fxp_tensor, batch, axis)
109
113
 
110
- error = ((float_flat - fxp_flat) ** 2).mean(axis=-1)
114
+ if weights is not None:
115
+ w_flat = flatten_tensor(weights, batch, axis)
116
+ if w_flat.shape != float_flat.shape:
117
+ Logger.critical(f"Shape mismatch: The shape of the weights tensor {weights.shape} does not match the shape "
118
+ f"of the input tensors {float_flat.shape} for Weighted-MSE computation.") # pragma: no cover
119
+ error = ((w_flat * (float_flat - fxp_flat)) ** 2).mean(axis=-1)
120
+ else:
121
+ error = ((float_flat - fxp_flat) ** 2).mean(axis=-1)
122
+
111
123
  if norm:
112
124
  error /= ((float_flat ** 2).mean(axis=-1) + norm_eps)
113
125
 
@@ -0,0 +1,48 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from model_compression_toolkit.core.common.graph.base_graph import Graph
17
+ from model_compression_toolkit.core.common.graph.base_node import BaseNode
18
+
19
+
20
+ def remove_identity_node(graph: Graph,
21
+ node: BaseNode) -> Graph:
22
+ """
23
+ The method to perform the substitution of the identity node by
24
+ reconnecting its input directly to its output, effectively removing the node
25
+ from the graph.
26
+
27
+ Args:
28
+ graph: The current graph of operations where the node resides.
29
+ node: The specific `BaseNode` that is matched to be an Identity operation.
30
+
31
+ Returns:
32
+ Graph: The updated graph after removing the identity node.
33
+ """
34
+ # Retrieve the predecessor nodes of the identity node.
35
+ prev_identity_nodes = graph.get_prev_nodes(node)
36
+ # Ensure there is exactly one predecessor; otherwise, do nothing.
37
+ if len(prev_identity_nodes) != 1:
38
+ return graph
39
+
40
+ # Reconnect the output edges of the identity node to its predecessor,
41
+ # effectively bypassing the identity node.
42
+ graph.reconnect_out_edges(current_node=node, new_node=prev_identity_nodes[0])
43
+ # Remove the edge from the predecessor to the identity node.
44
+ graph.remove_edge(prev_identity_nodes[0], node)
45
+ # Remove the identity node from the graph.
46
+ graph.remove_node(node_to_remove=node)
47
+
48
+ return graph
@@ -39,7 +39,8 @@ def graph_preparation_runner(in_model: Any,
39
39
  fw_impl: FrameworkImplementation,
40
40
  tpc: TargetPlatformCapabilities,
41
41
  tb_w: TensorboardWriter = None,
42
- mixed_precision_enable: bool = False) -> Graph:
42
+ mixed_precision_enable: bool = False,
43
+ running_gptq: bool = False) -> Graph:
43
44
  """
44
45
  Runs all required preparations in order to build a quantization graph from the given model,
45
46
  quantization configuration and target platform specifications.
@@ -59,6 +60,7 @@ def graph_preparation_runner(in_model: Any,
59
60
  the attached framework operator's information.
60
61
  tb_w: TensorboardWriter object for logging.
61
62
  mixed_precision_enable: is mixed precision enabled.
63
+ running_gptq: Whether or not a GPTQ optimization is planned to run after the PTQ process.
62
64
 
63
65
  Returns:
64
66
  An internal graph representation of the input model.
@@ -79,7 +81,8 @@ def graph_preparation_runner(in_model: Any,
79
81
  fw_info,
80
82
  tb_w,
81
83
  fw_impl,
82
- mixed_precision_enable=mixed_precision_enable)
84
+ mixed_precision_enable=mixed_precision_enable,
85
+ running_gptq=running_gptq)
83
86
 
84
87
  return transformed_graph
85
88
 
@@ -90,7 +93,8 @@ def get_finalized_graph(initial_graph: Graph,
90
93
  fw_info: FrameworkInfo = None,
91
94
  tb_w: TensorboardWriter = None,
92
95
  fw_impl: FrameworkImplementation = None,
93
- mixed_precision_enable: bool = False) -> Graph:
96
+ mixed_precision_enable: bool = False,
97
+ running_gptq: bool = False) -> Graph:
94
98
  """
95
99
  Applies all edit operations (edit, substitutions, etc.) on the model's graph, to prepare it for the quantization
96
100
  process. All future graph substitutions and operations that change the graph should be added to this method.
@@ -105,6 +109,7 @@ def get_finalized_graph(initial_graph: Graph,
105
109
  tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
106
110
  fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
107
111
  mixed_precision_enable: is mixed precision enabled.
112
+ running_gptq: Whether or not a GPTQ optimization is planned to run after the PTQ process.
108
113
 
109
114
  Returns: Graph object that represents the model, after applying all required modifications to it.
110
115
  """
@@ -142,7 +147,8 @@ def get_finalized_graph(initial_graph: Graph,
142
147
  ######################################
143
148
  transformed_graph = set_quantization_configuration_to_graph(graph=transformed_graph,
144
149
  quant_config=quant_config,
145
- mixed_precision_enable=mixed_precision_enable)
150
+ mixed_precision_enable=mixed_precision_enable,
151
+ running_gptq=running_gptq)
146
152
 
147
153
  ######################################
148
154
  # Layer fusing
@@ -0,0 +1,51 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import keras
17
+ import tensorflow as tf
18
+
19
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
20
+ from model_compression_toolkit.core import common
21
+ from model_compression_toolkit.core.common.graph.base_graph import Graph
22
+ from model_compression_toolkit.core.common.graph.base_node import BaseNode
23
+ from model_compression_toolkit.core.common.substitutions.remove_identity import remove_identity_node
24
+
25
+
26
+ class RemoveIdentity(common.BaseSubstitution):
27
+ """
28
+ Remove Identity layers from the graph.
29
+ """
30
+
31
+ def __init__(self):
32
+ nodes = NodeOperationMatcher(keras.layers.Identity) | NodeOperationMatcher(tf.identity)
33
+ super().__init__(matcher_instance=nodes)
34
+
35
+ def substitute(self,
36
+ graph: Graph,
37
+ node: BaseNode) -> Graph:
38
+ """
39
+ Performs the substitution of the Keras identity node by
40
+ reconnecting its input directly to its output, effectively removing the node
41
+ from the graph.
42
+
43
+ Args:
44
+ graph: The current graph of operations where the node resides.
45
+ node: The specific `BaseNode` that is matched to be an Identity operation.
46
+
47
+ Returns:
48
+ Graph: The updated graph after removing the identity node.
49
+ """
50
+ return remove_identity_node(graph, node)
51
+
@@ -22,6 +22,7 @@ from tensorflow.keras.models import Model
22
22
 
23
23
  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
24
24
  from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, HessianInfoService
25
+ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.remove_identity import RemoveIdentity
25
26
  from model_compression_toolkit.core.keras.hessian.activation_trace_hessian_calculator_keras import \
26
27
  ActivationTraceHessianCalculatorKeras
27
28
  from model_compression_toolkit.core.keras.hessian.weights_trace_hessian_calculator_keras import WeightsTraceHessianCalculatorKeras
@@ -246,7 +247,8 @@ class KerasImplementation(FrameworkImplementation):
246
247
  MatmulToDenseSubstitution(),
247
248
  MultiHeadAttentionDecomposition(),
248
249
  ActivationDecomposition(),
249
- DwconvToConv()]
250
+ DwconvToConv(),
251
+ RemoveIdentity()]
250
252
 
251
253
  def get_substitutions_pre_statistics_collection(self, quant_config: QuantizationConfig) -> \
252
254
  List[common.BaseSubstitution]:
@@ -0,0 +1,50 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import torch
16
+
17
+ from model_compression_toolkit.core.common.substitutions.remove_identity import remove_identity_node
18
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
19
+ from model_compression_toolkit.core import common
20
+ from model_compression_toolkit.core.common.graph.base_graph import Graph
21
+ from model_compression_toolkit.core.common.graph.base_node import BaseNode
22
+
23
+
24
+ class RemoveIdentity(common.BaseSubstitution):
25
+ """
26
+ Remove `torch.nn.Identity` layers from the graph.
27
+ """
28
+
29
+ def __init__(self):
30
+ nodes = NodeOperationMatcher(torch.nn.Identity)
31
+ super().__init__(matcher_instance=nodes)
32
+
33
+ def substitute(self,
34
+ graph: Graph,
35
+ node: BaseNode) -> Graph:
36
+ """
37
+ Performs the substitution of the `torch.nn.Identity` node by
38
+ reconnecting its input directly to its output, effectively removing the node
39
+ from the graph.
40
+
41
+ Args:
42
+ graph: The current graph of operations where the node resides.
43
+ node: The specific `BaseNode` that is matched to be an Identity operation.
44
+
45
+ Returns:
46
+ Graph: The updated graph after removing the identity node.
47
+ """
48
+ return remove_identity_node(graph, node)
49
+
50
+
@@ -58,6 +58,7 @@ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.co
58
58
  FunctionalConvSubstitution
59
59
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
60
60
  ReLUBoundToPowerOfTwo
61
+ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.remove_identity import RemoveIdentity
61
62
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.reshape_with_static_shapes import \
62
63
  ReshapeWithStaticShapes
63
64
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.residual_collapsing import \
@@ -238,7 +239,8 @@ class PytorchImplementation(FrameworkImplementation):
238
239
  PermuteCallMethod(),
239
240
  FunctionalConvSubstitution(fw_info),
240
241
  FunctionalBatchNorm(),
241
- FunctionalLayerNorm()]
242
+ FunctionalLayerNorm(),
243
+ RemoveIdentity()]
242
244
 
243
245
  def get_substitutions_pre_statistics_collection(self,
244
246
  quant_config: QuantizationConfig
@@ -21,6 +21,7 @@ from tqdm import tqdm
21
21
  from model_compression_toolkit.core.common import FrameworkInfo
22
22
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
23
23
  from model_compression_toolkit.core.common.graph.base_graph import Graph
24
+ from model_compression_toolkit.core.common.hessian import HessianInfoService
24
25
  from model_compression_toolkit.core.common.model_collector import ModelCollector
25
26
  from model_compression_toolkit.core.common.network_editors.edit_network import edit_network_graph
26
27
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
@@ -38,7 +39,8 @@ def quantization_preparation_runner(graph: Graph,
38
39
  core_config: CoreConfig,
39
40
  fw_info: FrameworkInfo,
40
41
  fw_impl: FrameworkImplementation,
41
- tb_w: TensorboardWriter = None) -> Graph:
42
+ tb_w: TensorboardWriter = None,
43
+ hessian_info_service: HessianInfoService = None,) -> Graph:
42
44
  """
43
45
  Prepares a trained model for post-training quantization.
44
46
  First, the model graph is optimized using several transformations (e.g. folding BatchNormalization to preceding layers).
@@ -55,6 +57,7 @@ def quantization_preparation_runner(graph: Graph,
55
57
  groups of layers by how they should be quantized, etc.).
56
58
  fw_impl: FrameworkImplementation object with a specific framework's method implementations.
57
59
  tb_w: TensorboardWriter object for logging
60
+ hessian_info_service: HessianInfoService object for retrieving Hessian-based scores.
58
61
 
59
62
  Returns:
60
63
  Graph object that represents the model, contains thresholds, and ready for quantization.
@@ -86,7 +89,8 @@ def quantization_preparation_runner(graph: Graph,
86
89
  ######################################
87
90
  # Calculate quantization params
88
91
  ######################################
89
- calculate_quantization_params(graph)
92
+
93
+ calculate_quantization_params(graph, hessian_info_service=hessian_info_service)
90
94
 
91
95
  if tb_w is not None:
92
96
  tb_w.add_graph(graph, 'thresholds_selection')
@@ -48,6 +48,7 @@ def core_runner(in_model: Any,
48
48
  fw_impl: FrameworkImplementation,
49
49
  tpc: TargetPlatformCapabilities,
50
50
  target_resource_utilization: ResourceUtilization = None,
51
+ running_gptq: bool = False,
51
52
  tb_w: TensorboardWriter = None):
52
53
  """
53
54
  Quantize a trained model using post-training quantization.
@@ -97,7 +98,8 @@ def core_runner(in_model: Any,
97
98
  fw_impl,
98
99
  tpc,
99
100
  tb_w,
100
- mixed_precision_enable=core_config.mixed_precision_enable)
101
+ mixed_precision_enable=core_config.mixed_precision_enable,
102
+ running_gptq=running_gptq)
101
103
 
102
104
  hessian_info_service = HessianInfoService(graph=graph,
103
105
  representative_dataset=representative_data_gen,
@@ -108,7 +110,8 @@ def core_runner(in_model: Any,
108
110
  core_config=core_config,
109
111
  fw_info=fw_info,
110
112
  fw_impl=fw_impl,
111
- tb_w=tb_w)
113
+ tb_w=tb_w,
114
+ hessian_info_service=hessian_info_service)
112
115
 
113
116
  ######################################
114
117
  # Finalize bit widths
@@ -212,7 +212,8 @@ if FOUND_TF:
212
212
  fw_impl=fw_impl,
213
213
  tpc=target_platform_capabilities,
214
214
  target_resource_utilization=target_resource_utilization,
215
- tb_w=tb_w)
215
+ tb_w=tb_w,
216
+ running_gptq=True)
216
217
 
217
218
  float_graph = copy.deepcopy(tg)
218
219
 
@@ -180,7 +180,9 @@ if FOUND_TORCH:
180
180
  fw_impl=fw_impl,
181
181
  tpc=target_platform_capabilities,
182
182
  target_resource_utilization=target_resource_utilization,
183
- tb_w=tb_w)
183
+ tb_w=tb_w,
184
+ running_gptq=True)
185
+
184
186
  float_graph = copy.deepcopy(graph)
185
187
 
186
188
  # ---------------------- #
@@ -111,6 +111,7 @@ def gptq_runner(tg: Graph,
111
111
  #############################################
112
112
  # Gradient Based Post Training Quantization
113
113
  #############################################
114
+ Logger.info("Running GPTQ optimization.")
114
115
  tg_gptq = _apply_gptq(gptq_config,
115
116
  gptq_representative_data_gen,
116
117
  tb_w,
@@ -13,12 +13,12 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH
16
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tp_model import get_tp_model, generate_tp_model, \
16
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model, generate_tp_model, \
17
17
  get_op_quantization_configs
18
18
  if FOUND_TF:
19
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import get_keras_tpc as get_keras_tpc_latest
20
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import generate_keras_tpc
19
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_latest
20
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import generate_keras_tpc
21
21
  if FOUND_TORCH:
22
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import get_pytorch_tpc as \
22
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import get_pytorch_tpc as \
23
23
  get_pytorch_tpc_latest
24
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import generate_pytorch_tpc
24
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import generate_pytorch_tpc