mct-nightly 2.2.0.20241106.458__py3-none-any.whl → 2.2.0.20241107.459__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241107.459.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241107.459.dist-info}/RECORD +17 -29
  3. {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241107.459.dist-info}/top_level.txt +0 -1
  4. model_compression_toolkit/__init__.py +1 -1
  5. model_compression_toolkit/core/common/framework_implementation.py +46 -27
  6. model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -0
  7. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -0
  8. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +81 -0
  9. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +190 -0
  10. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +14 -2
  11. model_compression_toolkit/core/keras/keras_implementation.py +23 -2
  12. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +67 -0
  13. model_compression_toolkit/core/pytorch/pytorch_implementation.py +21 -0
  14. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +57 -0
  15. model_compression_toolkit/core/runner.py +8 -0
  16. tests_pytest/__init__.py +0 -14
  17. tests_pytest/keras/__init__.py +0 -14
  18. tests_pytest/keras/core/__init__.py +0 -14
  19. tests_pytest/keras/core/test_data_util.py +0 -91
  20. tests_pytest/keras/gptq/__init__.py +0 -14
  21. tests_pytest/keras/gptq/test_gradual_act_quantization.py +0 -102
  22. tests_pytest/keras/trainable_infrastructure/__init__.py +0 -16
  23. tests_pytest/keras/trainable_infrastructure/test_linear_annealing.py +0 -49
  24. tests_pytest/pytorch/__init__.py +0 -14
  25. tests_pytest/pytorch/core/__init__.py +0 -14
  26. tests_pytest/pytorch/core/test_data_util.py +0 -125
  27. tests_pytest/pytorch/gptq/__init__.py +0 -14
  28. tests_pytest/pytorch/gptq/test_annealing_cfg.py +0 -40
  29. tests_pytest/pytorch/gptq/test_gradual_act_quantization.py +0 -100
  30. tests_pytest/pytorch/trainable_infrastructure/__init__.py +0 -14
  31. tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py +0 -49
  32. {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241107.459.dist-info}/LICENSE.md +0 -0
  33. {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241107.459.dist-info}/WHEEL +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.2.0.20241106.458
3
+ Version: 2.2.0.20241107.459
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=GTkn7rteG7LZ-9q9GZqbdyesxj5ZTiAgyOI4bdnUc6A,1573
1
+ model_compression_toolkit/__init__.py,sha256=PruMwc3p_YlsqfDYAgQ3ZUCTOuwfbVcA1bqPzuTsSQM,1573
2
2
  model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -8,10 +8,10 @@ model_compression_toolkit/core/__init__.py,sha256=tnDtL9KmT0vsOU27SsJ19TKDEbIH-t
8
8
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
9
9
  model_compression_toolkit/core/graph_prep_runner.py,sha256=7-b7Jd5jBVaXOWg5nSqbEyzBtdaGDbCxs8aqMV6GZ6I,11287
10
10
  model_compression_toolkit/core/quantization_prep_runner.py,sha256=OtL6g2rTC5mfdKrkzm47EPPW-voGGVYMYxpy2_sfu1U,6547
11
- model_compression_toolkit/core/runner.py,sha256=lahkYyfdsb3HJPJ5Lui7hp4vVWyIOJLXJQ5ATxiIyos,14264
11
+ model_compression_toolkit/core/runner.py,sha256=IavCZRVG9RisEKvFDxz27WDRKrfIG03YKXKv3tcagPo,14700
12
12
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
13
13
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
14
- model_compression_toolkit/core/common/framework_implementation.py,sha256=7k66e0b06eLnFLmu67onWPiM2lJfhWiuyQZPsRJm3lk,21294
14
+ model_compression_toolkit/core/common/framework_implementation.py,sha256=IkMydCj6voau7dwkYLYA_Ka_EFUKP3GKQdpYN6b1fgc,22163
15
15
  model_compression_toolkit/core/common/framework_info.py,sha256=1ZMMGS9ip-kSflqkartyNRt9aQ5ub1WepuTRcTy-YSQ,6337
16
16
  model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
17
17
  model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
@@ -105,8 +105,8 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
105
105
  model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
106
106
  model_compression_toolkit/core/common/quantization/debug_config.py,sha256=zJP2W9apUPX9RstpPWWK71wr9xJsg7j-s7lGV4_bQdc,1510
107
107
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
108
- model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=YycYN8_JMzvSR3pTVm5dT5x4zP3yBHn0Z9agnwrvOKI,26395
109
- model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=BTDa1Izpdd4Z4essxTWP42V87f8mdq9vdKdVhE8vibo,3818
108
+ model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=HmtyIQCQqhay-8oqU3rUHOeK6VhTtH9nuW24HigCUo0,26517
109
+ model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=nBqwNhbDbWQGYbfazLPHrP_ZCCnjbL-k5q58T8yIAcc,3917
110
110
  model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=eyosbVdnCwed7oMQ19tqnh0VoyGZ_UAuD_UnNoXyBpo,2210
111
111
  model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
112
112
  model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
@@ -128,10 +128,12 @@ model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantiz
128
128
  model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py,sha256=iEoWUPFQMcvZXHtLMe2_7L7IK25XcKiY6-d1_gArZs0,11880
129
129
  model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py,sha256=wXExWHf5-0He7L4bpvFpKlx7FG4u3DAfNZiXPpOs_SQ,5521
130
130
  model_compression_toolkit/core/common/statistics_correction/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
131
+ model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py,sha256=Aw2N7FSO7p1Kmh-tUjajV9pqrjMJQtgF5etG0WV9Le8,4440
131
132
  model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py,sha256=xSWVDOODgbN0k4mjJWWtpawilOsqdm4O7Uw2hbA75EA,4669
132
133
  model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py,sha256=C_nwhhitTd1pCto0nHZPn3fjIMOeDD7VIciumTR3s6k,5641
134
+ model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py,sha256=ov9-WYktWKqRquibwyARR81QVT9TfPWAoTTfnKOQSd0,9273
133
135
  model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py,sha256=LaGhYES7HgIDf9Bi2KAG_mBzAWuum0J6AGmAFPC8wwo,10478
134
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=5dzNtzDMmmLETgAU23k8Cu7q0q3z1EyS-46_Yx-aS7s,5519
136
+ model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=E0ZA4edimJwpHh9twI5gafcoJ9fX5F1JX2QUOkUOKEw,6250
135
137
  model_compression_toolkit/core/common/substitutions/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
136
138
  model_compression_toolkit/core/common/substitutions/apply_substitutions.py,sha256=k-bifmakHIYZeZS-4T1QpZ1Et6AwAijMRgAKs7hmMKc,1390
137
139
  model_compression_toolkit/core/common/substitutions/batchnorm_folding.py,sha256=wLlTT7sqUffKHwOrMG2VV5SktQkkP54l8taW1Fq0mh0,13392
@@ -155,7 +157,7 @@ model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiO
155
157
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
156
158
  model_compression_toolkit/core/keras/data_util.py,sha256=JdomIJZfep0QYPtx2jlg0xJ40cd9S_I7BakaWQi0wKw,2681
157
159
  model_compression_toolkit/core/keras/default_framework_info.py,sha256=PYcER89eEXjKtR0T7-2Y4f7cckqoD5OQbpHePoRkMec,5030
158
- model_compression_toolkit/core/keras/keras_implementation.py,sha256=Tn4_rkcx9bH3x-pEoUbGu94S7_nj3Hl3BfvL8SPIL3g,30957
160
+ model_compression_toolkit/core/keras/keras_implementation.py,sha256=Hi8seiFJdFqgYGGC003Y4879JQ7rmVZe8YiJ76T7FDE,32133
159
161
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
160
162
  model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
161
163
  model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=s56UIgiPipUQRNd2sd1xW6GFfYNMBmrocRCNtvpYLbY,4977
@@ -214,13 +216,14 @@ model_compression_toolkit/core/keras/reader/nested_model/nodes_merger.py,sha256=
214
216
  model_compression_toolkit/core/keras/reader/nested_model/outputs_merger.py,sha256=dUzvNVzamauDLjgyjHweWux6T2vRko3anAuPxnaGpX8,2408
215
217
  model_compression_toolkit/core/keras/statistics_correction/__init__.py,sha256=9HIBmj8ROdCA-yvkpA8EcN6RHJe_2vEpLLW_gxOJtak,698
216
218
  model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py,sha256=XNCtT9klMcsO1v5KA3MmCq_WgXOIT5QSzbfTOa9T-04,3060
219
+ model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py,sha256=lq6yw9r1u0ZGA95JFvzsV-HQax66qAkJBmGeKnG9OrM,3409
217
220
  model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
218
221
  model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
219
222
  model_compression_toolkit/core/pytorch/constants.py,sha256=YwD_joIF0vK8UG2vW1NVvg36pCNWA0vHOXjAgy_XWn0,2794
220
223
  model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
221
224
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=-Vls1P_8Ckm_18nnOsmQkZ71SmzHwtQLbQ383Z4Rb-U,4365
222
225
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
223
- model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=TWA5Eu_85TIoCii1Owx2yx_ECckOnGg7xgQkiueuZPE,28245
226
+ model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=4uzO-lXfuitlC3NHx5-k2Fjm8VHa1T7ox9c8DSxYs9M,29437
224
227
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
225
228
  model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=xpKj99OZKT9NT0vKIl_cOe8d89d2gef1gKoNT6PFElE,4989
226
229
  model_compression_toolkit/core/pytorch/utils.py,sha256=7VbgcLwtQvdEEc_AJgSOQ3U3KRKCICFPaBirN1fIQxg,3940
@@ -274,6 +277,7 @@ model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=7XNc7-l1MZP
274
277
  model_compression_toolkit/core/pytorch/reader/reader.py,sha256=GEJE0QX8XJFWbYCkbRBtzttZtmmuoACLx8gw9KyAQCE,6015
275
278
  model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
276
279
  model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py,sha256=VgU24J3jf7QComHH7jonOXSkg6mO4TOch3uFkOthZvM,3261
280
+ model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py,sha256=N-9QaEaQYUsIoya9Lc0ZDoMZ0fkiT2gFpOd4zXHKP34,3096
277
281
  model_compression_toolkit/data_generation/__init__.py,sha256=9xLN7VE3lnYVjoroYfJ24dxK_-kGEbMmMVeS1PPkPEY,1513
278
282
  model_compression_toolkit/data_generation/common/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
279
283
  model_compression_toolkit/data_generation/common/constants.py,sha256=21e3ZX9WVYojexG2acTgklrBk8ZO9DjJnKpP4KHZC44,1018
@@ -554,24 +558,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
554
558
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
555
559
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
556
560
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
557
- tests_pytest/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
558
- tests_pytest/keras/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
559
- tests_pytest/keras/core/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
560
- tests_pytest/keras/core/test_data_util.py,sha256=XSoPu_ci1xy2EtK-3OWGpESr-Meg1GDaxuSvcj3yt-w,3915
561
- tests_pytest/keras/gptq/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
562
- tests_pytest/keras/gptq/test_gradual_act_quantization.py,sha256=iwKaLI7QQ8H3qj6zmwwfd2ZOwRcCr8T-v_4llSh_chM,4804
563
- tests_pytest/keras/trainable_infrastructure/__init__.py,sha256=DvaMXJtJZHAqOm96NdfBiNQsbN2sc9bG2kkyY-mpPh8,710
564
- tests_pytest/keras/trainable_infrastructure/test_linear_annealing.py,sha256=dZjrMHVIiEVRNDYR3a4lZaXF2ElxFx32KAXXQvDz-v8,1793
565
- tests_pytest/pytorch/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
566
- tests_pytest/pytorch/core/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
567
- tests_pytest/pytorch/core/test_data_util.py,sha256=Bg3c21YVfXE1SAUlTao553gXcITTKF4CPeKtl3peBTE,5604
568
- tests_pytest/pytorch/gptq/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
569
- tests_pytest/pytorch/gptq/test_annealing_cfg.py,sha256=hGC7L6mp3N1ygcJ3OctgS_Fz2JY75q5aswolJkbHkZM,2208
570
- tests_pytest/pytorch/gptq/test_gradual_act_quantization.py,sha256=Dg2cg1X8u9Jxm7Y6tlZIGH81EPoW_vYorcdDExdj02w,4630
571
- tests_pytest/pytorch/trainable_infrastructure/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
572
- tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py,sha256=zErt9tOu7oupjpv08cvd1Cxvdk9qvP7GMUP6EhefK0c,1814
573
- mct_nightly-2.2.0.20241106.458.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
574
- mct_nightly-2.2.0.20241106.458.dist-info/METADATA,sha256=OvQyuNyvb2ucuyM03TFlWlAicuXkgODpKoR9u4zQ8NI,20830
575
- mct_nightly-2.2.0.20241106.458.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
576
- mct_nightly-2.2.0.20241106.458.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
577
- mct_nightly-2.2.0.20241106.458.dist-info/RECORD,,
561
+ mct_nightly-2.2.0.20241107.459.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
562
+ mct_nightly-2.2.0.20241107.459.dist-info/METADATA,sha256=kpUYeQnvsDFhkybX7mLdLyEG8PUb32Dgmj_Id-faTyI,20830
563
+ mct_nightly-2.2.0.20241107.459.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
564
+ mct_nightly-2.2.0.20241107.459.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
565
+ mct_nightly-2.2.0.20241107.459.dist-info/RECORD,,
@@ -1,2 +1 @@
1
1
  model_compression_toolkit
2
- tests_pytest
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.2.0.20241106.000458"
30
+ __version__ = "2.2.0.20241107.000459"
@@ -64,7 +64,7 @@ class FrameworkImplementation(ABC):
64
64
 
65
65
  Returns: HessianScoresCalculator to use for the hessian approximation scores computation for this request.
66
66
  """
67
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
67
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
68
68
  f'framework\'s get_hessian_scores_calculator method.') # pragma: no cover
69
69
 
70
70
  @abstractmethod
@@ -77,7 +77,7 @@ class FrameworkImplementation(ABC):
77
77
  Returns:
78
78
  Numpy array converted from the input tensor.
79
79
  """
80
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
80
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
81
81
  f'framework\'s to_numpy method.') # pragma: no cover
82
82
 
83
83
  @abstractmethod
@@ -90,7 +90,7 @@ class FrameworkImplementation(ABC):
90
90
  Returns:
91
91
  Framework's tensor converted from the input Numpy array.
92
92
  """
93
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
93
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
94
94
  f'framework\'s to_tensor method.') # pragma: no cover
95
95
 
96
96
  @abstractmethod
@@ -106,7 +106,7 @@ class FrameworkImplementation(ABC):
106
106
  Returns:
107
107
  Graph representing the input model.
108
108
  """
109
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
109
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
110
110
  f'framework\'s model_reader method.') # pragma: no cover
111
111
 
112
112
  @abstractmethod
@@ -131,7 +131,7 @@ class FrameworkImplementation(ABC):
131
131
  Returns:
132
132
  A tuple with the model and additional relevant supporting objects.
133
133
  """
134
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
134
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
135
135
  f'framework\'s model_builder method.') # pragma: no cover
136
136
 
137
137
  @abstractmethod
@@ -148,7 +148,7 @@ class FrameworkImplementation(ABC):
148
148
  Returns:
149
149
  The frameworks model's output.
150
150
  """
151
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
151
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
152
152
  f'framework\'s run_model_inference method.') # pragma: no cover
153
153
 
154
154
  @abstractmethod
@@ -167,9 +167,28 @@ class FrameworkImplementation(ABC):
167
167
  Returns:
168
168
  Graph after SNC.
169
169
  """
170
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
170
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
171
171
  f'framework\'s apply_shift_negative_correction method.') # pragma: no cover
172
172
 
173
@abstractmethod
def compute_activation_bias_correction(self,
                                       graph: Graph,
                                       quant_config: QuantizationConfig,
                                       fw_info: FrameworkInfo) -> Graph:
    """
    Compute activation bias correction for a graph.

    This is an abstract hook; each framework implementation must override it.

    Args:
        graph: Graph to compute the activation bias correction for.
        quant_config: QuantizationConfig describing how the model should be quantized.
        fw_info: FrameworkInfo object with information about the specific framework's model.

    Returns:
        The graph after the activation bias correction has been computed.

    Raises:
        NotImplementedError: Always, when called on a class that does not override it.
    """
    # type(self) is equivalent to self.__class__ here; the message text is unchanged.
    raise NotImplementedError(f'{type(self).__name__} has to implement the '
                              f'framework\'s compute_activation_bias_correction method.')  # pragma: no cover
191
+
173
192
  @abstractmethod
174
193
  def get_substitutions_channel_equalization(self,
175
194
  quant_config: QuantizationConfig,
@@ -184,7 +203,7 @@ class FrameworkImplementation(ABC):
184
203
  Returns:
185
204
  A list of the framework substitutions used after we collect statistics.
186
205
  """
187
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
206
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
188
207
  f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover
189
208
 
190
209
  @abstractmethod
@@ -194,7 +213,7 @@ class FrameworkImplementation(ABC):
194
213
  Returns: A list of the framework substitutions used to prepare the graph.
195
214
 
196
215
  """
197
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
216
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
198
217
  f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover
199
218
 
200
219
  @abstractmethod
@@ -208,7 +227,7 @@ class FrameworkImplementation(ABC):
208
227
  Returns: A list of the framework substitutions used before we collect statistics.
209
228
 
210
229
  """
211
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
230
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
212
231
  f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover
213
232
 
214
233
  @abstractmethod
@@ -216,7 +235,7 @@ class FrameworkImplementation(ABC):
216
235
  """
217
236
  Returns: linear collapsing substitution
218
237
  """
219
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
238
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
220
239
  f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover
221
240
 
222
241
  @abstractmethod
@@ -224,7 +243,7 @@ class FrameworkImplementation(ABC):
224
243
  """
225
244
  Returns: conv2d add const collapsing substitution
226
245
  """
227
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
246
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
228
247
  f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover
229
248
 
230
249
  @abstractmethod
@@ -239,7 +258,7 @@ class FrameworkImplementation(ABC):
239
258
  Returns:
240
259
  A list of the framework substitutions used for statistics correction.
241
260
  """
242
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
261
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
243
262
  f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover
244
263
 
245
264
  @abstractmethod
@@ -247,7 +266,7 @@ class FrameworkImplementation(ABC):
247
266
  """
248
267
  Returns: A list of the framework substitutions used for residual collapsing
249
268
  """
250
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
269
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
251
270
  f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover
252
271
 
253
272
 
@@ -263,7 +282,7 @@ class FrameworkImplementation(ABC):
263
282
  Returns:
264
283
  A list of the framework substitutions used after we collect statistics.
265
284
  """
266
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
285
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
267
286
  f'framework\'s get_substitutions_post_statistics_collection method.') # pragma: no cover
268
287
 
269
288
  @abstractmethod
@@ -272,7 +291,7 @@ class FrameworkImplementation(ABC):
272
291
  Returns: A list of Keras substitutions used to build a virtual graph with composed activation-weights pairs.
273
292
  """
274
293
 
275
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
294
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
276
295
  f'framework\'s get_substitutions_virtual_weights_activation_coupling '
277
296
  f'method.') # pragma: no cover
278
297
 
@@ -288,7 +307,7 @@ class FrameworkImplementation(ABC):
288
307
  Returns:
289
308
  A list of the framework substitutions used after we apply second moment statistics.
290
309
  """
291
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
310
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
292
311
  f'framework\'s get_substitutions_after_second_moment_correction '
293
312
  f'method.') # pragma: no cover
294
313
 
@@ -316,7 +335,7 @@ class FrameworkImplementation(ABC):
316
335
  A function that computes the metric.
317
336
  """
318
337
 
319
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
338
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
320
339
  f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover
321
340
 
322
341
  def get_node_prior_info(self, node: BaseNode,
@@ -334,7 +353,7 @@ class FrameworkImplementation(ABC):
334
353
  NodePriorInfo with information about the node.
335
354
  """
336
355
 
337
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
356
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
338
357
  f'framework\'s get_node_prior_info method.') # pragma: no cover
339
358
 
340
359
  def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
@@ -345,7 +364,7 @@ class FrameworkImplementation(ABC):
345
364
  Returns: True if the node should be considered an interest point, False otherwise.
346
365
  """
347
366
 
348
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
367
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
349
368
  f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover
350
369
 
351
370
  def get_mp_node_distance_fn(self, n: BaseNode,
@@ -364,7 +383,7 @@ class FrameworkImplementation(ABC):
364
383
  Returns: A distance function between two tensors and a axis on which the distance is computed (if exists).
365
384
  """
366
385
 
367
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
386
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
368
387
  f'framework\'s get_mp_node_distance_fn method.') # pragma: no cover
369
388
 
370
389
 
@@ -381,7 +400,7 @@ class FrameworkImplementation(ABC):
381
400
 
382
401
  """
383
402
 
384
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
403
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
385
404
  f'framework\'s is_output_node_compatible_for_hessian_score_computation method.') # pragma: no cover
386
405
 
387
406
  @abstractmethod
@@ -398,7 +417,7 @@ class FrameworkImplementation(ABC):
398
417
  Returns: The MAC count of the operation
399
418
  """
400
419
 
401
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
420
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
402
421
  f'framework\'s get_node_mac_operations method.') # pragma: no cover
403
422
 
404
423
  @abstractmethod
@@ -419,7 +438,7 @@ class FrameworkImplementation(ABC):
419
438
  Returns:
420
439
  A Graph after second moment correction.
421
440
  """
422
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
441
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
423
442
  f'framework\'s apply_second_moment_correction method.') # pragma: no cover
424
443
 
425
444
  @abstractmethod
@@ -436,7 +455,7 @@ class FrameworkImplementation(ABC):
436
455
  Returns:
437
456
  The output of the model inference on the given input.
438
457
  """
439
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
458
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
440
459
  f'framework\'s sensitivity_eval_inference method.') # pragma: no cover
441
460
 
442
461
  def get_inferable_quantizers(self, node: BaseNode):
@@ -452,9 +471,9 @@ class FrameworkImplementation(ABC):
452
471
 
453
472
  """
454
473
 
455
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
474
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
456
475
  f'framework\'s get_inferable_quantizers method.') # pragma: no cover
457
-
476
+
458
477
  @staticmethod
459
478
  def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
460
479
  """
@@ -95,7 +95,9 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
95
95
  self.activation_error_method = qc.activation_error_method
96
96
  self.activation_n_bits = op_cfg.activation_n_bits
97
97
  self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
98
+ self.activation_bias_correction_term = None
98
99
  self.enable_activation_quantization = op_cfg.enable_activation_quantization
100
+ self.quantization_preserving = op_cfg.quantization_preserving
99
101
  self.signedness = op_cfg.signedness
100
102
  self.activation_channel_equalization = qc.activation_channel_equalization
101
103
  self.input_scaling = qc.input_scaling
@@ -84,6 +84,8 @@ class QuantizationConfig:
84
84
  shift_negative_threshold_recalculation: bool = False
85
85
  shift_negative_params_search: bool = False
86
86
  concat_threshold_update: bool = False
87
+ activation_bias_correction: bool = False
88
+ activation_bias_correction_threshold: float = 0.0
87
89
 
88
90
 
89
91
  # Default quantization configuration the library use.
@@ -0,0 +1,81 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from model_compression_toolkit.core import CoreConfig, QuantizationConfig
17
+ from model_compression_toolkit.core.common import BaseNode, Graph
18
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
19
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig
21
+
22
+
23
def apply_activation_bias_correction_to_graph(graph: Graph,
                                              core_config: CoreConfig,
                                              fw_impl: FrameworkImplementation) -> Graph:
    """
    Apply activation bias correction to every eligible node in the graph.

    Each node is expected to already have a final activation quantization configuration.
    For nodes where an activation bias correction term was computed, that term is folded
    into the node's bias.

    A node is corrected only when all of the following hold:
      * activation bias correction is enabled in the quantization config,
      * the node has a kernel op attribute (according to the graph's FrameworkInfo),
      * the node's final activation quantization config holds a correction term (not None).

    Args:
        graph: Graph to apply activation bias correction to. Nodes are modified in place.
        core_config: CoreConfig containing parameters of how the model should be quantized.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.

    Returns:
        The same graph object, with activation bias correction applied to its nodes.
    """
    qc = core_config.quantization_config

    # The enable flag does not depend on the node, so check it once instead of per node.
    # When disabled, skip the per-node kernel-attribute lookups entirely.
    if not qc.activation_bias_correction:
        return graph

    for n in graph.nodes:
        # Activation bias correction is only relevant for nodes with a kernel op.
        kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
        if kernel_attr is None:
            continue

        # The term stays None unless a correction was calculated for this node
        # during model preparation.
        if n.final_activation_quantization_cfg.activation_bias_correction_term is None:
            continue

        _apply_activation_bias_correction_to_node(n, fw_impl, qc)
    return graph
48
+
49
+
50
def _apply_activation_bias_correction_to_node(node: BaseNode,
                                              fw_impl: FrameworkImplementation,
                                              qc: QuantizationConfig):
    """
    Fold the stored activation bias correction term into the node's bias.

    The correction term is read from the node's final activation quantization configuration.
    If the node already has a bias, the term is subtracted from it. Otherwise the missing
    bias is treated as zero:
      * a bias equal to the negated term is created,
      * the node's use-bias framework attribute is set to True,
      * the new bias attribute gets a weights quantization config with quantization disabled.

    Args:
        node: Node whose bias is updated in place.
        fw_impl: FrameworkImplementation supplying the framework-specific BIAS / USE_BIAS keys.
        qc: QuantizationConfig used to build the bias attribute config when a bias is created.
    """
    bias_key = fw_impl.constants.BIAS
    correction = node.final_activation_quantization_cfg.activation_bias_correction_term
    original_bias = node.get_weights_by_keys(bias_key)  # original bias from the node's weights

    if original_bias is not None:
        # The layer already has a bias: shift it by the correction term.
        node.set_weights_by_keys(bias_key, original_bias - correction)
        return

    # No bias: treat the missing bias as zero, so the new bias is just -correction.
    node.set_weights_by_keys(bias_key, -correction)

    # Mark the node as using a bias.
    node.framework_attr[fw_impl.constants.USE_BIAS] = True

    # The newly created bias is kept unquantized.
    unquantized_attr_cfg = AttributeQuantizationConfig(enable_weights_quantization=False)
    node.final_weights_quantization_cfg.set_attr_config(bias_key,
                                                        WeightsAttrQuantizationConfig(qc, unquantized_attr_cfg))