mct-nightly 2.2.0.20241106.458__py3-none-any.whl → 2.2.0.20241108.459__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/RECORD +17 -29
- {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/top_level.txt +0 -1
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +46 -27
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -0
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -0
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +81 -0
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +190 -0
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +14 -2
- model_compression_toolkit/core/keras/keras_implementation.py +23 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +67 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +21 -0
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +57 -0
- model_compression_toolkit/core/runner.py +8 -0
- tests_pytest/__init__.py +0 -14
- tests_pytest/keras/__init__.py +0 -14
- tests_pytest/keras/core/__init__.py +0 -14
- tests_pytest/keras/core/test_data_util.py +0 -91
- tests_pytest/keras/gptq/__init__.py +0 -14
- tests_pytest/keras/gptq/test_gradual_act_quantization.py +0 -102
- tests_pytest/keras/trainable_infrastructure/__init__.py +0 -16
- tests_pytest/keras/trainable_infrastructure/test_linear_annealing.py +0 -49
- tests_pytest/pytorch/__init__.py +0 -14
- tests_pytest/pytorch/core/__init__.py +0 -14
- tests_pytest/pytorch/core/test_data_util.py +0 -125
- tests_pytest/pytorch/gptq/__init__.py +0 -14
- tests_pytest/pytorch/gptq/test_annealing_cfg.py +0 -40
- tests_pytest/pytorch/gptq/test_gradual_act_quantization.py +0 -100
- tests_pytest/pytorch/trainable_infrastructure/__init__.py +0 -14
- tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py +0 -49
- {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/WHEEL +0 -0
{mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=sV3EGoGSi45kJn43REtpqcqQutybRbJUFwnOwC2OByU,1573
|
2
2
|
model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
@@ -8,10 +8,10 @@ model_compression_toolkit/core/__init__.py,sha256=tnDtL9KmT0vsOU27SsJ19TKDEbIH-t
|
|
8
8
|
model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
|
9
9
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=7-b7Jd5jBVaXOWg5nSqbEyzBtdaGDbCxs8aqMV6GZ6I,11287
|
10
10
|
model_compression_toolkit/core/quantization_prep_runner.py,sha256=OtL6g2rTC5mfdKrkzm47EPPW-voGGVYMYxpy2_sfu1U,6547
|
11
|
-
model_compression_toolkit/core/runner.py,sha256=
|
11
|
+
model_compression_toolkit/core/runner.py,sha256=IavCZRVG9RisEKvFDxz27WDRKrfIG03YKXKv3tcagPo,14700
|
12
12
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
13
13
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
14
|
-
model_compression_toolkit/core/common/framework_implementation.py,sha256=
|
14
|
+
model_compression_toolkit/core/common/framework_implementation.py,sha256=IkMydCj6voau7dwkYLYA_Ka_EFUKP3GKQdpYN6b1fgc,22163
|
15
15
|
model_compression_toolkit/core/common/framework_info.py,sha256=1ZMMGS9ip-kSflqkartyNRt9aQ5ub1WepuTRcTy-YSQ,6337
|
16
16
|
model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
|
17
17
|
model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
|
@@ -105,8 +105,8 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
|
|
105
105
|
model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
|
106
106
|
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=zJP2W9apUPX9RstpPWWK71wr9xJsg7j-s7lGV4_bQdc,1510
|
107
107
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
|
108
|
-
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=
|
109
|
-
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=
|
108
|
+
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=HmtyIQCQqhay-8oqU3rUHOeK6VhTtH9nuW24HigCUo0,26517
|
109
|
+
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=nBqwNhbDbWQGYbfazLPHrP_ZCCnjbL-k5q58T8yIAcc,3917
|
110
110
|
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=eyosbVdnCwed7oMQ19tqnh0VoyGZ_UAuD_UnNoXyBpo,2210
|
111
111
|
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
|
112
112
|
model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
|
@@ -128,10 +128,12 @@ model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantiz
|
|
128
128
|
model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py,sha256=iEoWUPFQMcvZXHtLMe2_7L7IK25XcKiY6-d1_gArZs0,11880
|
129
129
|
model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py,sha256=wXExWHf5-0He7L4bpvFpKlx7FG4u3DAfNZiXPpOs_SQ,5521
|
130
130
|
model_compression_toolkit/core/common/statistics_correction/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
131
|
+
model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py,sha256=Aw2N7FSO7p1Kmh-tUjajV9pqrjMJQtgF5etG0WV9Le8,4440
|
131
132
|
model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py,sha256=xSWVDOODgbN0k4mjJWWtpawilOsqdm4O7Uw2hbA75EA,4669
|
132
133
|
model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py,sha256=C_nwhhitTd1pCto0nHZPn3fjIMOeDD7VIciumTR3s6k,5641
|
134
|
+
model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py,sha256=ov9-WYktWKqRquibwyARR81QVT9TfPWAoTTfnKOQSd0,9273
|
133
135
|
model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py,sha256=LaGhYES7HgIDf9Bi2KAG_mBzAWuum0J6AGmAFPC8wwo,10478
|
134
|
-
model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=
|
136
|
+
model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=E0ZA4edimJwpHh9twI5gafcoJ9fX5F1JX2QUOkUOKEw,6250
|
135
137
|
model_compression_toolkit/core/common/substitutions/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
136
138
|
model_compression_toolkit/core/common/substitutions/apply_substitutions.py,sha256=k-bifmakHIYZeZS-4T1QpZ1Et6AwAijMRgAKs7hmMKc,1390
|
137
139
|
model_compression_toolkit/core/common/substitutions/batchnorm_folding.py,sha256=wLlTT7sqUffKHwOrMG2VV5SktQkkP54l8taW1Fq0mh0,13392
|
@@ -155,7 +157,7 @@ model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiO
|
|
155
157
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
156
158
|
model_compression_toolkit/core/keras/data_util.py,sha256=JdomIJZfep0QYPtx2jlg0xJ40cd9S_I7BakaWQi0wKw,2681
|
157
159
|
model_compression_toolkit/core/keras/default_framework_info.py,sha256=PYcER89eEXjKtR0T7-2Y4f7cckqoD5OQbpHePoRkMec,5030
|
158
|
-
model_compression_toolkit/core/keras/keras_implementation.py,sha256=
|
160
|
+
model_compression_toolkit/core/keras/keras_implementation.py,sha256=Hi8seiFJdFqgYGGC003Y4879JQ7rmVZe8YiJ76T7FDE,32133
|
159
161
|
model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
|
160
162
|
model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
|
161
163
|
model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=s56UIgiPipUQRNd2sd1xW6GFfYNMBmrocRCNtvpYLbY,4977
|
@@ -214,13 +216,14 @@ model_compression_toolkit/core/keras/reader/nested_model/nodes_merger.py,sha256=
|
|
214
216
|
model_compression_toolkit/core/keras/reader/nested_model/outputs_merger.py,sha256=dUzvNVzamauDLjgyjHweWux6T2vRko3anAuPxnaGpX8,2408
|
215
217
|
model_compression_toolkit/core/keras/statistics_correction/__init__.py,sha256=9HIBmj8ROdCA-yvkpA8EcN6RHJe_2vEpLLW_gxOJtak,698
|
216
218
|
model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py,sha256=XNCtT9klMcsO1v5KA3MmCq_WgXOIT5QSzbfTOa9T-04,3060
|
219
|
+
model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py,sha256=lq6yw9r1u0ZGA95JFvzsV-HQax66qAkJBmGeKnG9OrM,3409
|
217
220
|
model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
218
221
|
model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
219
222
|
model_compression_toolkit/core/pytorch/constants.py,sha256=YwD_joIF0vK8UG2vW1NVvg36pCNWA0vHOXjAgy_XWn0,2794
|
220
223
|
model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
|
221
224
|
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=-Vls1P_8Ckm_18nnOsmQkZ71SmzHwtQLbQ383Z4Rb-U,4365
|
222
225
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
|
223
|
-
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=
|
226
|
+
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=4uzO-lXfuitlC3NHx5-k2Fjm8VHa1T7ox9c8DSxYs9M,29437
|
224
227
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
|
225
228
|
model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=xpKj99OZKT9NT0vKIl_cOe8d89d2gef1gKoNT6PFElE,4989
|
226
229
|
model_compression_toolkit/core/pytorch/utils.py,sha256=7VbgcLwtQvdEEc_AJgSOQ3U3KRKCICFPaBirN1fIQxg,3940
|
@@ -274,6 +277,7 @@ model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=7XNc7-l1MZP
|
|
274
277
|
model_compression_toolkit/core/pytorch/reader/reader.py,sha256=GEJE0QX8XJFWbYCkbRBtzttZtmmuoACLx8gw9KyAQCE,6015
|
275
278
|
model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
276
279
|
model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py,sha256=VgU24J3jf7QComHH7jonOXSkg6mO4TOch3uFkOthZvM,3261
|
280
|
+
model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py,sha256=N-9QaEaQYUsIoya9Lc0ZDoMZ0fkiT2gFpOd4zXHKP34,3096
|
277
281
|
model_compression_toolkit/data_generation/__init__.py,sha256=9xLN7VE3lnYVjoroYfJ24dxK_-kGEbMmMVeS1PPkPEY,1513
|
278
282
|
model_compression_toolkit/data_generation/common/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
279
283
|
model_compression_toolkit/data_generation/common/constants.py,sha256=21e3ZX9WVYojexG2acTgklrBk8ZO9DjJnKpP4KHZC44,1018
|
@@ -554,24 +558,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
554
558
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
|
555
559
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
556
560
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
tests_pytest/keras/gptq/test_gradual_act_quantization.py,sha256=iwKaLI7QQ8H3qj6zmwwfd2ZOwRcCr8T-v_4llSh_chM,4804
|
563
|
-
tests_pytest/keras/trainable_infrastructure/__init__.py,sha256=DvaMXJtJZHAqOm96NdfBiNQsbN2sc9bG2kkyY-mpPh8,710
|
564
|
-
tests_pytest/keras/trainable_infrastructure/test_linear_annealing.py,sha256=dZjrMHVIiEVRNDYR3a4lZaXF2ElxFx32KAXXQvDz-v8,1793
|
565
|
-
tests_pytest/pytorch/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
566
|
-
tests_pytest/pytorch/core/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
567
|
-
tests_pytest/pytorch/core/test_data_util.py,sha256=Bg3c21YVfXE1SAUlTao553gXcITTKF4CPeKtl3peBTE,5604
|
568
|
-
tests_pytest/pytorch/gptq/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
569
|
-
tests_pytest/pytorch/gptq/test_annealing_cfg.py,sha256=hGC7L6mp3N1ygcJ3OctgS_Fz2JY75q5aswolJkbHkZM,2208
|
570
|
-
tests_pytest/pytorch/gptq/test_gradual_act_quantization.py,sha256=Dg2cg1X8u9Jxm7Y6tlZIGH81EPoW_vYorcdDExdj02w,4630
|
571
|
-
tests_pytest/pytorch/trainable_infrastructure/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
572
|
-
tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py,sha256=zErt9tOu7oupjpv08cvd1Cxvdk9qvP7GMUP6EhefK0c,1814
|
573
|
-
mct_nightly-2.2.0.20241106.458.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
574
|
-
mct_nightly-2.2.0.20241106.458.dist-info/METADATA,sha256=OvQyuNyvb2ucuyM03TFlWlAicuXkgODpKoR9u4zQ8NI,20830
|
575
|
-
mct_nightly-2.2.0.20241106.458.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
576
|
-
mct_nightly-2.2.0.20241106.458.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
|
577
|
-
mct_nightly-2.2.0.20241106.458.dist-info/RECORD,,
|
561
|
+
mct_nightly-2.2.0.20241108.459.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
562
|
+
mct_nightly-2.2.0.20241108.459.dist-info/METADATA,sha256=BKcV37WvzOfa_tSqxVnw6l2_7oSMLjrOFCVEizArBSE,20830
|
563
|
+
mct_nightly-2.2.0.20241108.459.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
564
|
+
mct_nightly-2.2.0.20241108.459.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
565
|
+
mct_nightly-2.2.0.20241108.459.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.2.0.
|
30
|
+
__version__ = "2.2.0.20241108.000459"
|
@@ -64,7 +64,7 @@ class FrameworkImplementation(ABC):
|
|
64
64
|
|
65
65
|
Returns: HessianScoresCalculator to use for the hessian approximation scores computation for this request.
|
66
66
|
"""
|
67
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
67
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
68
68
|
f'framework\'s get_hessian_scores_calculator method.') # pragma: no cover
|
69
69
|
|
70
70
|
@abstractmethod
|
@@ -77,7 +77,7 @@ class FrameworkImplementation(ABC):
|
|
77
77
|
Returns:
|
78
78
|
Numpy array converted from the input tensor.
|
79
79
|
"""
|
80
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
80
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
81
81
|
f'framework\'s to_numpy method.') # pragma: no cover
|
82
82
|
|
83
83
|
@abstractmethod
|
@@ -90,7 +90,7 @@ class FrameworkImplementation(ABC):
|
|
90
90
|
Returns:
|
91
91
|
Framework's tensor converted from the input Numpy array.
|
92
92
|
"""
|
93
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
93
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
94
94
|
f'framework\'s to_tensor method.') # pragma: no cover
|
95
95
|
|
96
96
|
@abstractmethod
|
@@ -106,7 +106,7 @@ class FrameworkImplementation(ABC):
|
|
106
106
|
Returns:
|
107
107
|
Graph representing the input model.
|
108
108
|
"""
|
109
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
109
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
110
110
|
f'framework\'s model_reader method.') # pragma: no cover
|
111
111
|
|
112
112
|
@abstractmethod
|
@@ -131,7 +131,7 @@ class FrameworkImplementation(ABC):
|
|
131
131
|
Returns:
|
132
132
|
A tuple with the model and additional relevant supporting objects.
|
133
133
|
"""
|
134
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
134
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
135
135
|
f'framework\'s model_builder method.') # pragma: no cover
|
136
136
|
|
137
137
|
@abstractmethod
|
@@ -148,7 +148,7 @@ class FrameworkImplementation(ABC):
|
|
148
148
|
Returns:
|
149
149
|
The frameworks model's output.
|
150
150
|
"""
|
151
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
151
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
152
152
|
f'framework\'s run_model_inference method.') # pragma: no cover
|
153
153
|
|
154
154
|
@abstractmethod
|
@@ -167,9 +167,28 @@ class FrameworkImplementation(ABC):
|
|
167
167
|
Returns:
|
168
168
|
Graph after SNC.
|
169
169
|
"""
|
170
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
170
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
171
171
|
f'framework\'s apply_shift_negative_correction method.') # pragma: no cover
|
172
172
|
|
173
|
+
@abstractmethod
|
174
|
+
def compute_activation_bias_correction(self,
|
175
|
+
graph: Graph,
|
176
|
+
quant_config: QuantizationConfig,
|
177
|
+
fw_info: FrameworkInfo) -> Graph:
|
178
|
+
"""
|
179
|
+
Compute activation bias correction on a graph.
|
180
|
+
|
181
|
+
Args:
|
182
|
+
graph: Graph to apply activation bias correction on.
|
183
|
+
quant_config: QuantizationConfig of how the model should be quantized.
|
184
|
+
fw_info: FrameworkInfo object with information about the specific framework's model.
|
185
|
+
|
186
|
+
Returns:
|
187
|
+
Graph after activation bias correction computing.
|
188
|
+
"""
|
189
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
190
|
+
f'framework\'s compute_activation_bias_correction method.') # pragma: no cover
|
191
|
+
|
173
192
|
@abstractmethod
|
174
193
|
def get_substitutions_channel_equalization(self,
|
175
194
|
quant_config: QuantizationConfig,
|
@@ -184,7 +203,7 @@ class FrameworkImplementation(ABC):
|
|
184
203
|
Returns:
|
185
204
|
A list of the framework substitutions used after we collect statistics.
|
186
205
|
"""
|
187
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
206
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
188
207
|
f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover
|
189
208
|
|
190
209
|
@abstractmethod
|
@@ -194,7 +213,7 @@ class FrameworkImplementation(ABC):
|
|
194
213
|
Returns: A list of the framework substitutions used to prepare the graph.
|
195
214
|
|
196
215
|
"""
|
197
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
216
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
198
217
|
f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover
|
199
218
|
|
200
219
|
@abstractmethod
|
@@ -208,7 +227,7 @@ class FrameworkImplementation(ABC):
|
|
208
227
|
Returns: A list of the framework substitutions used before we collect statistics.
|
209
228
|
|
210
229
|
"""
|
211
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
230
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
212
231
|
f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover
|
213
232
|
|
214
233
|
@abstractmethod
|
@@ -216,7 +235,7 @@ class FrameworkImplementation(ABC):
|
|
216
235
|
"""
|
217
236
|
Returns: linear collapsing substitution
|
218
237
|
"""
|
219
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
238
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
220
239
|
f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover
|
221
240
|
|
222
241
|
@abstractmethod
|
@@ -224,7 +243,7 @@ class FrameworkImplementation(ABC):
|
|
224
243
|
"""
|
225
244
|
Returns: conv2d add const collapsing substitution
|
226
245
|
"""
|
227
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
246
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
228
247
|
f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover
|
229
248
|
|
230
249
|
@abstractmethod
|
@@ -239,7 +258,7 @@ class FrameworkImplementation(ABC):
|
|
239
258
|
Returns:
|
240
259
|
A list of the framework substitutions used for statistics correction.
|
241
260
|
"""
|
242
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
261
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
243
262
|
f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover
|
244
263
|
|
245
264
|
@abstractmethod
|
@@ -247,7 +266,7 @@ class FrameworkImplementation(ABC):
|
|
247
266
|
"""
|
248
267
|
Returns: A list of the framework substitutions used for residual collapsing
|
249
268
|
"""
|
250
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
269
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
251
270
|
f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover
|
252
271
|
|
253
272
|
|
@@ -263,7 +282,7 @@ class FrameworkImplementation(ABC):
|
|
263
282
|
Returns:
|
264
283
|
A list of the framework substitutions used after we collect statistics.
|
265
284
|
"""
|
266
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
285
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
267
286
|
f'framework\'s get_substitutions_post_statistics_collection method.') # pragma: no cover
|
268
287
|
|
269
288
|
@abstractmethod
|
@@ -272,7 +291,7 @@ class FrameworkImplementation(ABC):
|
|
272
291
|
Returns: A list of Keras substitutions used to build a virtual graph with composed activation-weights pairs.
|
273
292
|
"""
|
274
293
|
|
275
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
294
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
276
295
|
f'framework\'s get_substitutions_virtual_weights_activation_coupling '
|
277
296
|
f'method.') # pragma: no cover
|
278
297
|
|
@@ -288,7 +307,7 @@ class FrameworkImplementation(ABC):
|
|
288
307
|
Returns:
|
289
308
|
A list of the framework substitutions used after we apply second moment statistics.
|
290
309
|
"""
|
291
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
310
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
292
311
|
f'framework\'s get_substitutions_after_second_moment_correction '
|
293
312
|
f'method.') # pragma: no cover
|
294
313
|
|
@@ -316,7 +335,7 @@ class FrameworkImplementation(ABC):
|
|
316
335
|
A function that computes the metric.
|
317
336
|
"""
|
318
337
|
|
319
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
338
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
320
339
|
f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover
|
321
340
|
|
322
341
|
def get_node_prior_info(self, node: BaseNode,
|
@@ -334,7 +353,7 @@ class FrameworkImplementation(ABC):
|
|
334
353
|
NodePriorInfo with information about the node.
|
335
354
|
"""
|
336
355
|
|
337
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
356
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
338
357
|
f'framework\'s get_node_prior_info method.') # pragma: no cover
|
339
358
|
|
340
359
|
def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
|
@@ -345,7 +364,7 @@ class FrameworkImplementation(ABC):
|
|
345
364
|
Returns: True if the node should be considered an interest point, False otherwise.
|
346
365
|
"""
|
347
366
|
|
348
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
367
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
349
368
|
f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover
|
350
369
|
|
351
370
|
def get_mp_node_distance_fn(self, n: BaseNode,
|
@@ -364,7 +383,7 @@ class FrameworkImplementation(ABC):
|
|
364
383
|
Returns: A distance function between two tensors and a axis on which the distance is computed (if exists).
|
365
384
|
"""
|
366
385
|
|
367
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
386
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
368
387
|
f'framework\'s get_mp_node_distance_fn method.') # pragma: no cover
|
369
388
|
|
370
389
|
|
@@ -381,7 +400,7 @@ class FrameworkImplementation(ABC):
|
|
381
400
|
|
382
401
|
"""
|
383
402
|
|
384
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
403
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
385
404
|
f'framework\'s is_output_node_compatible_for_hessian_score_computation method.') # pragma: no cover
|
386
405
|
|
387
406
|
@abstractmethod
|
@@ -398,7 +417,7 @@ class FrameworkImplementation(ABC):
|
|
398
417
|
Returns: The MAC count of the operation
|
399
418
|
"""
|
400
419
|
|
401
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
420
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
402
421
|
f'framework\'s get_node_mac_operations method.') # pragma: no cover
|
403
422
|
|
404
423
|
@abstractmethod
|
@@ -419,7 +438,7 @@ class FrameworkImplementation(ABC):
|
|
419
438
|
Returns:
|
420
439
|
A Graph after second moment correction.
|
421
440
|
"""
|
422
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
441
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
423
442
|
f'framework\'s apply_second_moment_correction method.') # pragma: no cover
|
424
443
|
|
425
444
|
@abstractmethod
|
@@ -436,7 +455,7 @@ class FrameworkImplementation(ABC):
|
|
436
455
|
Returns:
|
437
456
|
The output of the model inference on the given input.
|
438
457
|
"""
|
439
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
458
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
440
459
|
f'framework\'s sensitivity_eval_inference method.') # pragma: no cover
|
441
460
|
|
442
461
|
def get_inferable_quantizers(self, node: BaseNode):
|
@@ -452,9 +471,9 @@ class FrameworkImplementation(ABC):
|
|
452
471
|
|
453
472
|
"""
|
454
473
|
|
455
|
-
raise NotImplementedError(f'{self.__class__.__name__}
|
474
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
456
475
|
f'framework\'s get_inferable_quantizers method.') # pragma: no cover
|
457
|
-
|
476
|
+
|
458
477
|
@staticmethod
|
459
478
|
def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
|
460
479
|
"""
|
@@ -95,7 +95,9 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
95
95
|
self.activation_error_method = qc.activation_error_method
|
96
96
|
self.activation_n_bits = op_cfg.activation_n_bits
|
97
97
|
self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
|
98
|
+
self.activation_bias_correction_term = None
|
98
99
|
self.enable_activation_quantization = op_cfg.enable_activation_quantization
|
100
|
+
self.quantization_preserving = op_cfg.quantization_preserving
|
99
101
|
self.signedness = op_cfg.signedness
|
100
102
|
self.activation_channel_equalization = qc.activation_channel_equalization
|
101
103
|
self.input_scaling = qc.input_scaling
|
@@ -84,6 +84,8 @@ class QuantizationConfig:
|
|
84
84
|
shift_negative_threshold_recalculation: bool = False
|
85
85
|
shift_negative_params_search: bool = False
|
86
86
|
concat_threshold_update: bool = False
|
87
|
+
activation_bias_correction: bool = False
|
88
|
+
activation_bias_correction_threshold: float = 0.0
|
87
89
|
|
88
90
|
|
89
91
|
# Default quantization configuration the library use.
|
@@ -0,0 +1,81 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from model_compression_toolkit.core import CoreConfig, QuantizationConfig
|
17
|
+
from model_compression_toolkit.core.common import BaseNode, Graph
|
18
|
+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
19
|
+
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig
|
21
|
+
|
22
|
+
|
23
|
+
def apply_activation_bias_correction_to_graph(graph: Graph,
|
24
|
+
core_config: CoreConfig,
|
25
|
+
fw_impl: FrameworkImplementation) -> Graph:
|
26
|
+
"""
|
27
|
+
Get a graph, where each node has a final activation quantization configuration (with an activation bias
|
28
|
+
correction term in it), and apply the activation bias correction for each node in the graph.
|
29
|
+
|
30
|
+
Args:
|
31
|
+
graph: Graph to apply activation bias correction to.
|
32
|
+
core_config: CoreConfig containing parameters of how the model should be quantized.
|
33
|
+
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
34
|
+
|
35
|
+
Returns:
|
36
|
+
Graph with activation bias correction apply to it's nodes.
|
37
|
+
"""
|
38
|
+
|
39
|
+
for n in graph.nodes:
|
40
|
+
# Activation bias correction is only relevant for nodes with kernel op
|
41
|
+
kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
|
42
|
+
if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \
|
43
|
+
n.final_activation_quantization_cfg.activation_bias_correction_term is not None:
|
44
|
+
# If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was
|
45
|
+
# calculated during model preparation, and is used now in the node's bias term.
|
46
|
+
_apply_activation_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
|
47
|
+
return graph
|
48
|
+
|
49
|
+
|
50
|
+
def _apply_activation_bias_correction_to_node(node: BaseNode,
|
51
|
+
fw_impl: FrameworkImplementation,
|
52
|
+
qc: QuantizationConfig):
|
53
|
+
"""
|
54
|
+
Set new bias to node using the activation bias correction term that is stored in the
|
55
|
+
final activation quantization configuration.
|
56
|
+
|
57
|
+
Args:
|
58
|
+
node: Node to set its corrected bias after activation bias correction.
|
59
|
+
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
60
|
+
qc: QuantizationConfig containing parameters of how the model should be quantized.
|
61
|
+
|
62
|
+
"""
|
63
|
+
correction = node.final_activation_quantization_cfg.activation_bias_correction_term
|
64
|
+
bias = node.get_weights_by_keys(fw_impl.constants.BIAS) # get original bias from node's weights
|
65
|
+
|
66
|
+
if bias is None:
|
67
|
+
# If the layer has no bias, we set the bias as -correction.
|
68
|
+
node.set_weights_by_keys(fw_impl.constants.BIAS, - correction)
|
69
|
+
|
70
|
+
# Mark the use_bias attribute of the node.
|
71
|
+
node.framework_attr[fw_impl.constants.USE_BIAS] = True
|
72
|
+
|
73
|
+
# Configure the quantization of the bias as disabled.
|
74
|
+
node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
|
75
|
+
WeightsAttrQuantizationConfig(
|
76
|
+
qc,
|
77
|
+
AttributeQuantizationConfig(
|
78
|
+
enable_weights_quantization=False)))
|
79
|
+
else:
|
80
|
+
# If the layer has bias, we subtract the correction from original bias
|
81
|
+
node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction)
|