mct-nightly 0.0.0__py3-none-any.whl → 1.1.0.02122021-003117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/METADATA +3 -2
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/RECORD +31 -38
- model_compression_toolkit/__init__.py +2 -6
- model_compression_toolkit/common/base_substitutions.py +1 -0
- model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +9 -12
- model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +8 -21
- model_compression_toolkit/common/collectors/histogram_collector.py +1 -1
- model_compression_toolkit/common/graph/base_graph.py +2 -4
- model_compression_toolkit/common/graph/graph_matchers.py +3 -1
- model_compression_toolkit/common/graph/graph_searches.py +3 -1
- model_compression_toolkit/common/mixed_precision/bit_width_setter.py +1 -2
- model_compression_toolkit/common/network_editors/node_filters.py +1 -0
- model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py +1 -1
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +3 -5
- model_compression_toolkit/common/quantization/quantize_graph_weights.py +4 -7
- model_compression_toolkit/common/quantization/quantize_node.py +3 -5
- model_compression_toolkit/keras/__init__.py +2 -0
- model_compression_toolkit/keras/back2framework/model_builder.py +24 -1
- model_compression_toolkit/{common → keras/back2framework}/model_collector.py +9 -18
- model_compression_toolkit/keras/default_framework_info.py +0 -1
- model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +57 -10
- model_compression_toolkit/keras/graph_substitutions/substituter.py +171 -0
- model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +26 -6
- model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +12 -5
- model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +3 -4
- model_compression_toolkit/keras/quantization_facade.py +524 -188
- model_compression_toolkit/keras/reader/connectivity_handler.py +4 -1
- model_compression_toolkit/keras/visualization/nn_visualizer.py +1 -2
- model_compression_toolkit/common/framework_implementation.py +0 -239
- model_compression_toolkit/common/gptq/__init__.py +0 -14
- model_compression_toolkit/common/gptq/gptq_config.py +0 -65
- model_compression_toolkit/common/model_builder_mode.py +0 -34
- model_compression_toolkit/common/post_training_quantization.py +0 -459
- model_compression_toolkit/common/substitutions/__init__.py +0 -14
- model_compression_toolkit/common/substitutions/apply_substitutions.py +0 -40
- model_compression_toolkit/keras/keras_implementation.py +0 -256
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/LICENSE +0 -0
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/WHEEL +0 -0
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: mct-nightly
|
|
3
|
-
Version:
|
|
3
|
+
Version: 1.1.0.02122021-003117
|
|
4
4
|
Summary: A Model Compression Toolkit for neural networks
|
|
5
5
|
Home-page: UNKNOWN
|
|
6
6
|
License: UNKNOWN
|
|
@@ -11,6 +11,7 @@ Classifier: Operating System :: OS Independent
|
|
|
11
11
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
12
12
|
Requires-Python: >=3.6
|
|
13
13
|
Description-Content-Type: text/markdown
|
|
14
|
+
Requires-Dist: tensorflow (==2.5.*)
|
|
14
15
|
Requires-Dist: networkx (==2.5)
|
|
15
16
|
Requires-Dist: tqdm
|
|
16
17
|
Requires-Dist: Pillow
|
|
@@ -18,9 +19,9 @@ Requires-Dist: numpy
|
|
|
18
19
|
Requires-Dist: opencv-python
|
|
19
20
|
Requires-Dist: scikit-image
|
|
20
21
|
Requires-Dist: scikit-learn
|
|
22
|
+
Requires-Dist: tensorflow-model-optimization
|
|
21
23
|
Requires-Dist: tensorboard
|
|
22
24
|
Requires-Dist: PuLP
|
|
23
|
-
Requires-Dist: matplotlib
|
|
24
25
|
|
|
25
26
|
# Model Compression Toolkit (MCT)
|
|
26
27
|

|
|
@@ -1,33 +1,27 @@
|
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
|
1
|
+
model_compression_toolkit/__init__.py,sha256=3Jc6zSKbZJ2oGnNJMKKw54QWKZdtDMGoYLIGbG8IooU,1895
|
|
2
2
|
model_compression_toolkit/common/__init__.py,sha256=LtklANR21DvGL9oO4ila5u6Ft7M83-dqftICub6Vowk,1454
|
|
3
|
-
model_compression_toolkit/common/base_substitutions.py,sha256=
|
|
3
|
+
model_compression_toolkit/common/base_substitutions.py,sha256=eAQfGVDrzHw6qRCkdm5ENVJLlod633UWCI2qLQyyoCA,1657
|
|
4
4
|
model_compression_toolkit/common/constants.py,sha256=DWcVXmyONazCWZYp6LT-0Kav0bbMxfB8sDWjMRAxqxY,1472
|
|
5
5
|
model_compression_toolkit/common/data_loader.py,sha256=fqTPODRsWZAhkXExE4iP2kgtHv5kOw1UzBojWXIrq7c,4018
|
|
6
6
|
model_compression_toolkit/common/defaultdict.py,sha256=C9LVHx7D0WniqOsoiENkPNdUj49kiQEabESr6FRT5OA,2132
|
|
7
|
-
model_compression_toolkit/common/framework_implementation.py,sha256=mz3WI7wgXoo_jIs_XM-D1uCl7S-gEPml3UKDjAGri3w,9827
|
|
8
7
|
model_compression_toolkit/common/framework_info.py,sha256=mpPvVS8JYa987MYqmfj32iEsLY-vhqY8UGtgoUYK3_k,7723
|
|
9
8
|
model_compression_toolkit/common/logger.py,sha256=RrDE9VCndUrdbQ_4DL5VP8z10fr5BMCJvxswHIvb3Os,4674
|
|
10
|
-
model_compression_toolkit/common/model_builder_mode.py,sha256=G06Tpf-p7vBpa_k0Z8sNdbwhN_QEX-3lf0ajafSrQKA,1488
|
|
11
|
-
model_compression_toolkit/common/model_collector.py,sha256=c-wy-ln9o9FI7jgpS4pXEB3h81XLXwXKPrY0qjdHQuw,4665
|
|
12
|
-
model_compression_toolkit/common/post_training_quantization.py,sha256=T6T9YmoKgXZMdRuLKVZVVviXb7SaefnQ4HNsk5AiWv4,22575
|
|
13
9
|
model_compression_toolkit/common/similarity_analyzer.py,sha256=fB5oGzRY941XONCxnXRy43n4F1qxqXM8EuF31t__Izg,4047
|
|
14
10
|
model_compression_toolkit/common/statistics_collector.py,sha256=cv-qrVM9JY9JLyxfC5G2EzqPAdbOBMN4Gm3Qf9tOhIY,7973
|
|
15
11
|
model_compression_toolkit/common/user_info.py,sha256=OYUdFBma5TKrDLBXEjgo1Oj0RK0zpvkVmP_5ZJ0X-2c,1535
|
|
16
12
|
model_compression_toolkit/common/bias_correction/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
|
|
17
|
-
model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py,sha256=
|
|
18
|
-
model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py,sha256
|
|
13
|
+
model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py,sha256=i_Zky82Ow0_dpeSqp7z_5z2Zm-Yx-tQtY8R1RRIUfDE,2870
|
|
14
|
+
model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py,sha256=-pud80aNesybSqxvYSUiHOvK3CEcTcBvQZtxIGXjaNg,7424
|
|
19
15
|
model_compression_toolkit/common/collectors/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
|
|
20
16
|
model_compression_toolkit/common/collectors/base_collector.py,sha256=s1Kne14LcFUajwftMBrusu0wTPHCUQx4TegPcFSWVss,2427
|
|
21
|
-
model_compression_toolkit/common/collectors/histogram_collector.py,sha256=
|
|
17
|
+
model_compression_toolkit/common/collectors/histogram_collector.py,sha256=9Ypg35t-zdWSTryA-XvubHSD_PD8ZvJoQ0IliqFm46c,6843
|
|
22
18
|
model_compression_toolkit/common/collectors/mean_collector.py,sha256=x3n2dmZl2EuXh_LgEQnO7YMlL6G7s0nFltsIhXGx6W0,3835
|
|
23
19
|
model_compression_toolkit/common/collectors/min_max_per_channel_collector.py,sha256=ZzJkogbqDkVj_tiS3gzHfTsEkomroeHBcJrXod0ga4k,5031
|
|
24
|
-
model_compression_toolkit/common/gptq/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
|
|
25
|
-
model_compression_toolkit/common/gptq/gptq_config.py,sha256=XDfJgamLEyJKyq4rmhX5gqhM30a1DiJjKqoy3N6t1pM,2787
|
|
26
20
|
model_compression_toolkit/common/graph/__init__.py,sha256=3x8iejFh5xa6tUnOtu5tU7KFOKdRqADhNac89pDItWo,760
|
|
27
|
-
model_compression_toolkit/common/graph/base_graph.py,sha256=
|
|
21
|
+
model_compression_toolkit/common/graph/base_graph.py,sha256=2Bds0yeUA8HtF1Rg2moBwLRN5cG0p7wBjdIEVT3RJo8,18395
|
|
28
22
|
model_compression_toolkit/common/graph/edge.py,sha256=J1ksD-NHXNA_aMqoEQEaR81URLuAht2KNwu_ARwODHE,3712
|
|
29
|
-
model_compression_toolkit/common/graph/graph_matchers.py,sha256=
|
|
30
|
-
model_compression_toolkit/common/graph/graph_searches.py,sha256=
|
|
23
|
+
model_compression_toolkit/common/graph/graph_matchers.py,sha256=2k5yogBwXcxIAENZTIcpZ-lhCe23ZGqnERcd3RmTWik,4860
|
|
24
|
+
model_compression_toolkit/common/graph/graph_searches.py,sha256=o3xcFMtKJMElFGt0wcEMC0c48Qcv0azsF-e__Ti_TKk,5199
|
|
31
25
|
model_compression_toolkit/common/graph/graph_vis.py,sha256=NqWkLW75ap2n-RX_TmooVtxJ63iR3BgtfRR6f4eUWYs,3704
|
|
32
26
|
model_compression_toolkit/common/graph/node.py,sha256=LUi8l1-mzrI-0_USw1Zt-1j6xkHT4tDSAPx9OXBc5pI,7527
|
|
33
27
|
model_compression_toolkit/common/matchers/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
|
|
@@ -38,7 +32,7 @@ model_compression_toolkit/common/matchers/function.py,sha256=Aqe0gLQJUSAn-unHYF5
|
|
|
38
32
|
model_compression_toolkit/common/matchers/node_matcher.py,sha256=fyKrVaC_S1qU-Fx-SJVW0YxEDv7LMHDUZQyfMgbH-xY,2746
|
|
39
33
|
model_compression_toolkit/common/matchers/walk_matcher.py,sha256=wZuTitcKe4CeGizwb4P1_D5o1wd7ZfaKRA_tTng8FeQ,1112
|
|
40
34
|
model_compression_toolkit/common/mixed_precision/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
|
|
41
|
-
model_compression_toolkit/common/mixed_precision/bit_width_setter.py,sha256=
|
|
35
|
+
model_compression_toolkit/common/mixed_precision/bit_width_setter.py,sha256=b47YTTlz3emzcnFLTYUXyO9nG1AOHe1didsx6KqK8EY,6348
|
|
42
36
|
model_compression_toolkit/common/mixed_precision/distance_weighting.py,sha256=3a42gnctiGI5sLDTQj0WjSaaWp8AImM3nMQHNVpyies,2426
|
|
43
37
|
model_compression_toolkit/common/mixed_precision/kpi.py,sha256=eUYXmR7LJC2SG77XI2yy29UkS_2uimxcS03EzXZJox4,1279
|
|
44
38
|
model_compression_toolkit/common/mixed_precision/mixed_precision_quantization_config.py,sha256=fkTsB_-mbwMWfxXJxpStUBGU8Os3_vLI267oeWKS58A,3423
|
|
@@ -49,26 +43,26 @@ model_compression_toolkit/common/mixed_precision/search_methods/linear_programmi
|
|
|
49
43
|
model_compression_toolkit/common/network_editors/__init__.py,sha256=dJWrbHs2NeGYW7Kzy7m7waIw8Ounjd1tiiL0a4nRAos,1235
|
|
50
44
|
model_compression_toolkit/common/network_editors/actions.py,sha256=czJlYQnC6r38XJcpipvcCy2DQNd01QRwkIh_d3BLwSg,13118
|
|
51
45
|
model_compression_toolkit/common/network_editors/edit_network.py,sha256=KsfILh0y_n5kqkaAJDeuHtJ5Pgf4OW6oWjl13ySV0QQ,1723
|
|
52
|
-
model_compression_toolkit/common/network_editors/node_filters.py,sha256=
|
|
46
|
+
model_compression_toolkit/common/network_editors/node_filters.py,sha256=T3jMUoajr32BxS28NctEzMm5kBxD9nGSFZCaimAhpaI,3152
|
|
53
47
|
model_compression_toolkit/common/quantization/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
|
|
54
48
|
model_compression_toolkit/common/quantization/node_quantization_config.py,sha256=e81Yoqtp6e8saCDZFd2bYJMRtdxQ7aPCorwhYQSws6g,9359
|
|
55
49
|
model_compression_toolkit/common/quantization/quantization_analyzer.py,sha256=xcZ2BPvUZ5ATNStavKjWOcJnvx8OQ7yBG0O_GgcsvuE,2972
|
|
56
50
|
model_compression_toolkit/common/quantization/quantization_config.py,sha256=N2JZFO2AFlAqv9KbKYtIZlfKi2Dsb5RjNCNR8O76DPs,8503
|
|
57
51
|
model_compression_toolkit/common/quantization/quantization_params_fn_selection.py,sha256=WSarcRfTcG2t62JCXIIucbe86oM_856vCb_1ygn-FeM,5448
|
|
58
|
-
model_compression_toolkit/common/quantization/quantize_graph_weights.py,sha256=
|
|
59
|
-
model_compression_toolkit/common/quantization/quantize_node.py,sha256=
|
|
52
|
+
model_compression_toolkit/common/quantization/quantize_graph_weights.py,sha256=7f9OUmG1BqPT2BmySNh49e3_NjW-MmG34Yun7mkpTew,2698
|
|
53
|
+
model_compression_toolkit/common/quantization/quantize_node.py,sha256=_g4QVfqwR2MjomqZT6N_8lTEAmYu3RKv1-JdcnJD9Mc,3387
|
|
60
54
|
model_compression_toolkit/common/quantization/set_node_quantization_config.py,sha256=MNn4o26souwd1rkdxWbXF97TA_12knNEuRiVaj4vlCw,7729
|
|
61
55
|
model_compression_toolkit/common/quantization/quantization_params_generation/__init__.py,sha256=8wFyT6t-uX0MQ-2QCa4qGZYmquhWroEdBrsren6wrlI,2045
|
|
62
56
|
model_compression_toolkit/common/quantization/quantization_params_generation/kl_selection.py,sha256=Rk7yKNZDXePffzcOxlqK1er8mXkw0k_Qn3TY-Ve18eE,15158
|
|
63
57
|
model_compression_toolkit/common/quantization/quantization_params_generation/kmeans_params.py,sha256=e8E7Htdc00mqATtHpOJUATadDY-LuHSUfJBh1-FLeKw,2648
|
|
64
|
-
model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py,sha256=
|
|
58
|
+
model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py,sha256=8ShckR-tkrWTBiJEFhogHhTsShRF5hTTlHHV7OVJWAA,6642
|
|
65
59
|
model_compression_toolkit/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=YUkjhYl4rHDtPIOUfMS7gxKiMFWIcIYhR1_g9Rdk6WQ,3559
|
|
66
60
|
model_compression_toolkit/common/quantization/quantization_params_generation/mae_selection.py,sha256=ipfyH29zpSK78C_ktEOdch5O2-fBDHYOd3b4IU4EFfk,5343
|
|
67
61
|
model_compression_toolkit/common/quantization/quantization_params_generation/mse_selection.py,sha256=R2pBjOIeozczdVi-nqfQvy0Auj0QLP0XsgaCIH6DdMk,5420
|
|
68
62
|
model_compression_toolkit/common/quantization/quantization_params_generation/no_clipping.py,sha256=0TTU_rptXNjhldoH_K1-CE8eWEGq-EoSJPLlRM8dxUQ,6466
|
|
69
63
|
model_compression_toolkit/common/quantization/quantization_params_generation/outlier_filter.py,sha256=SeouwEGp5oQ9vZu6OxHKpw3Hu01mqYmCoS-xneLN4kk,1773
|
|
70
64
|
model_compression_toolkit/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=4pHn5oX2vfzfiLv2g7YX925wnSnF0qd5o-q0VcLtbSI,3066
|
|
71
|
-
model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py,sha256=
|
|
65
|
+
model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py,sha256=OSiSoIsvVsQ0hZKA6TF_P6NTyXgWeODBRX39ef4taio,4733
|
|
72
66
|
model_compression_toolkit/common/quantization/quantization_params_generation/qparams_search.py,sha256=RP1ylMBISrCAxRVZfPoD6RC5msb-Y6bvrJN9lC8xKGk,8565
|
|
73
67
|
model_compression_toolkit/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=G_IqFKpn4yVT6Q-uKbtJaF0N0QUxKvHeKdcNEGKlxwc,4943
|
|
74
68
|
model_compression_toolkit/common/quantization/quantizers/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
|
|
@@ -76,37 +70,36 @@ model_compression_toolkit/common/quantization/quantizers/kmeans_quantizer.py,sha
|
|
|
76
70
|
model_compression_toolkit/common/quantization/quantizers/lut_kmeans_quantizer.py,sha256=IhyoB8dWfwuXo4Ow-8iWWCln8lUZh3gX6cGHIZpf3CQ,2797
|
|
77
71
|
model_compression_toolkit/common/quantization/quantizers/power_of_two_quantizer.py,sha256=NI2UG_cn76svg1j2sGNVWOWuRcqcrVGh2Dge8CMpKB0,2108
|
|
78
72
|
model_compression_toolkit/common/quantization/quantizers/quantizers_helpers.py,sha256=muetNju10st84LnjCTLueSsNZxsXtatxGlV77crb6ms,6014
|
|
79
|
-
model_compression_toolkit/common/substitutions/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
|
|
80
|
-
model_compression_toolkit/common/substitutions/apply_substitutions.py,sha256=hfRgeqS873jMnF9CQJCUVMlElYX1tddGSFo1HfEOMuM,1459
|
|
81
73
|
model_compression_toolkit/common/visualization/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
|
|
82
74
|
model_compression_toolkit/common/visualization/tensorboard_writer.py,sha256=_aqRGqgZ7XcWK5gWl8aYFoMMk79K2Ex3c9aRvrqFAMI,16899
|
|
83
|
-
model_compression_toolkit/keras/__init__.py,sha256=
|
|
75
|
+
model_compression_toolkit/keras/__init__.py,sha256=J44ch9RssgFKwQA1XfWDFfA-UiusQmW4jMBNc69QNgQ,754
|
|
84
76
|
model_compression_toolkit/keras/constants.py,sha256=NbsC3VSuZ4Tl28o7FR5PFIJwG5y3iGhHbsqUM90gaFs,1418
|
|
85
|
-
model_compression_toolkit/keras/default_framework_info.py,sha256
|
|
86
|
-
model_compression_toolkit/keras/
|
|
87
|
-
model_compression_toolkit/keras/quantization_facade.py,sha256=hYzQaVrhFcR0x4uBhDrb11_PN4OssRCqDY9Utgoexmg,13440
|
|
77
|
+
model_compression_toolkit/keras/default_framework_info.py,sha256=-gip0vu3wQXoABILprPlAwwexqFhDSBKfSk-_c-h0WM,5187
|
|
78
|
+
model_compression_toolkit/keras/quantization_facade.py,sha256=cc6Zj_J57blCWDM_M6Yhqj1uhR1SYiInArRYZYkml00,28485
|
|
88
79
|
model_compression_toolkit/keras/tensor_marking.py,sha256=lfVqpVW3Pfb1BS7Q2d0yra3dFfKBsqGzy9CjkCfwnZA,5067
|
|
89
80
|
model_compression_toolkit/keras/back2framework/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
|
|
90
81
|
model_compression_toolkit/keras/back2framework/instance_builder.py,sha256=PO_UaPVEWlKqfw8dexOdEeesQF_SUIACT_kGSLoDDHI,3344
|
|
91
|
-
model_compression_toolkit/keras/back2framework/model_builder.py,sha256
|
|
82
|
+
model_compression_toolkit/keras/back2framework/model_builder.py,sha256=-NZbDh9iYqFWXrE4z7U59i1YV9cLfldeD78TvgdouqE,15292
|
|
83
|
+
model_compression_toolkit/keras/back2framework/model_collector.py,sha256=i1_uvmeQdzOeIc0eaMCq7knYCTvqeZKJNZF9hcV7L1s,4171
|
|
92
84
|
model_compression_toolkit/keras/gradient_ptq/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
|
|
93
85
|
model_compression_toolkit/keras/gradient_ptq/gptq_loss.py,sha256=GGog0LDTb2_XptBjrlNjn2CJUlh6n7T32GNM1bMi5yo,1841
|
|
94
86
|
model_compression_toolkit/keras/gradient_ptq/graph_info.py,sha256=3538-D-4u1K8nXQcOeDyC-8dcP1tEk7pfv2cbRXwvzE,3188
|
|
95
87
|
model_compression_toolkit/keras/gradient_ptq/graph_update.py,sha256=vdbO2ZyXR4tkMf47CCW67bxuxFp4uTrlVYwK_BHN4Pc,3405
|
|
96
|
-
model_compression_toolkit/keras/gradient_ptq/training_wrapper.py,sha256=
|
|
88
|
+
model_compression_toolkit/keras/gradient_ptq/training_wrapper.py,sha256=W3wZnEskqR8vv8XQT21QAsiAvQF5IIoytACNxCuK-TE,7942
|
|
97
89
|
model_compression_toolkit/keras/graph_substitutions/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
|
|
90
|
+
model_compression_toolkit/keras/graph_substitutions/substituter.py,sha256=mzuIqICMvXJTfpMo_hHZ9xOLOl-yZVtoyrakYUzaQGY,7776
|
|
98
91
|
model_compression_toolkit/keras/graph_substitutions/substitutions/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
|
|
99
92
|
model_compression_toolkit/keras/graph_substitutions/substitutions/activation_decomposition.py,sha256=S-Y3FBIk_IrigIsuJCENa23ZkRSq9PoTDvGe-_8snyY,3886
|
|
100
93
|
model_compression_toolkit/keras/graph_substitutions/substitutions/batchnorm_folding.py,sha256=U2OjwCZ2sfS_OR4SLVjtx5w79Fg6fte5pLyLU_ec_2Q,5190
|
|
101
|
-
model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py,sha256=
|
|
94
|
+
model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py,sha256=7pcFm_fxf4d7Hg9NLu37a91OPTjqrSreacw6DyO1cE8,6546
|
|
102
95
|
model_compression_toolkit/keras/graph_substitutions/substitutions/mark_activation.py,sha256=uCZfrIWXw6OyihHV5UlGVy9S6HVnO1HHhZ6V5wmEFbM,3038
|
|
103
|
-
model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py,sha256=
|
|
96
|
+
model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py,sha256=jyOOhESbfmePHGhS9MnE2pygxBK24PUAEdRychim__s,5636
|
|
104
97
|
model_compression_toolkit/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py,sha256=8wOjEGvTG2nAR5Zq2TSny_eeDep_8sYbDqyOT_xGtbY,2237
|
|
105
98
|
model_compression_toolkit/keras/graph_substitutions/substitutions/scale_equalization.py,sha256=nRMmqGMF-MJlp3dd6hGSSldGWq6P-wCvGyvSRClf_JI,15460
|
|
106
99
|
model_compression_toolkit/keras/graph_substitutions/substitutions/separableconv_decomposition.py,sha256=rL3dqnRXN7vj8PkBcdP73fwVV__BajBTRLMF8ilViBc,7609
|
|
107
100
|
model_compression_toolkit/keras/graph_substitutions/substitutions/shift_negative_activation.py,sha256=iBCF1K2lHCnVlmjaeE6GfbVRHZuH3jmQcGce9DOoYkQ,23234
|
|
108
101
|
model_compression_toolkit/keras/mixed_precision/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
|
|
109
|
-
model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py,sha256=
|
|
102
|
+
model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py,sha256=1SnCo59VmcJVwb9gJD5poJrObytcltznGjf9sWniqZw,11399
|
|
110
103
|
model_compression_toolkit/keras/quantizer/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
|
|
111
104
|
model_compression_toolkit/keras/quantizer/base_quantizer.py,sha256=WoD4UrgXcCjZSXPnOX9slxJLrYDYN7go6nKs-bWjfuo,1735
|
|
112
105
|
model_compression_toolkit/keras/quantizer/fake_quant_builder.py,sha256=jtGyNrTGFt6PFoPG9eQA56eIKUDLsulDAsde7o0mx2A,3342
|
|
@@ -125,7 +118,7 @@ model_compression_toolkit/keras/quantizer/mixed_precision/selective_quantizer.py
|
|
|
125
118
|
model_compression_toolkit/keras/quantizer/mixed_precision/selective_weights_quantize_config.py,sha256=zGiaXcUmbiNRIb2Q9E6BocyCL4cyQjuDpfguC9sFVf8,7521
|
|
126
119
|
model_compression_toolkit/keras/reader/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
|
|
127
120
|
model_compression_toolkit/keras/reader/common.py,sha256=KCsvLFOeFltb9NBqQ2nU0MHYVPDBCYDBKF3oeoaangg,2417
|
|
128
|
-
model_compression_toolkit/keras/reader/connectivity_handler.py,sha256
|
|
121
|
+
model_compression_toolkit/keras/reader/connectivity_handler.py,sha256=-AYoapd40V25yt1h9MqDzXUDRVUOYho_h6F-U5zDXoY,11139
|
|
129
122
|
model_compression_toolkit/keras/reader/node_builder.py,sha256=N9KSkY7AaZeq0-Zxi5MhKdCsTW2qmFp1yjqk5aPVPpA,3420
|
|
130
123
|
model_compression_toolkit/keras/reader/reader.py,sha256=0i9GT3rTWDDK5fKbtt6PTqFpM59obk8mK0SDXVbycn4,7939
|
|
131
124
|
model_compression_toolkit/keras/reader/nested_model/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
|
|
@@ -134,9 +127,9 @@ model_compression_toolkit/keras/reader/nested_model/nested_model_handler.py,sha2
|
|
|
134
127
|
model_compression_toolkit/keras/reader/nested_model/nodes_merger.py,sha256=iI63gMLuE4_C-8aKSOLT3xRowViI3Je8aN8r5A37vy0,2167
|
|
135
128
|
model_compression_toolkit/keras/reader/nested_model/outputs_merger.py,sha256=bTzVEtXk-7zrtGN7_NZSAKSuCeBWTXd_sPe7RFYgHUQ,2381
|
|
136
129
|
model_compression_toolkit/keras/visualization/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
|
|
137
|
-
model_compression_toolkit/keras/visualization/nn_visualizer.py,sha256=
|
|
138
|
-
mct_nightly-
|
|
139
|
-
mct_nightly-
|
|
140
|
-
mct_nightly-
|
|
141
|
-
mct_nightly-
|
|
142
|
-
mct_nightly-
|
|
130
|
+
model_compression_toolkit/keras/visualization/nn_visualizer.py,sha256=AgrgGRHmV8M_NtFH5JzqmjAfBW2Q6PElHWwAq9WRSRU,4726
|
|
131
|
+
mct_nightly-1.1.0.2122021.post3117.dist-info/LICENSE,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
|
132
|
+
mct_nightly-1.1.0.2122021.post3117.dist-info/METADATA,sha256=D53VNuzBWaVqfm1h-O0Na5x7_kiAsa3x-6bBfX8FVEw,5950
|
|
133
|
+
mct_nightly-1.1.0.2122021.post3117.dist-info/WHEEL,sha256=ewwEueio1C2XeHTvT17n8dZUJgOvyCWCt0WVNLClP9o,92
|
|
134
|
+
mct_nightly-1.1.0.2122021.post3117.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
|
135
|
+
mct_nightly-1.1.0.2122021.post3117.dist-info/RECORD,,
|
|
@@ -13,8 +13,8 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
|
|
17
|
-
from model_compression_toolkit.
|
|
16
|
+
from model_compression_toolkit.keras.quantization_facade import keras_post_training_quantization, keras_post_training_quantization_mixed_precision
|
|
17
|
+
from model_compression_toolkit.keras.gradient_ptq.training_wrapper import GradientPTQConfig
|
|
18
18
|
from model_compression_toolkit.common.quantization import quantization_config
|
|
19
19
|
from model_compression_toolkit.common.mixed_precision import mixed_precision_quantization_config
|
|
20
20
|
from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig, \
|
|
@@ -27,8 +27,4 @@ from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
|
27
27
|
from model_compression_toolkit.common.defaultdict import DefaultDict
|
|
28
28
|
from model_compression_toolkit.common import network_editors as network_editor
|
|
29
29
|
|
|
30
|
-
from model_compression_toolkit.keras.quantization_facade import keras_post_training_quantization, \
|
|
31
|
-
keras_post_training_quantization_mixed_precision
|
|
32
|
-
|
|
33
|
-
|
|
34
30
|
__version__ = "1.1.0"
|
|
@@ -20,6 +20,7 @@ from typing import Any
|
|
|
20
20
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
21
21
|
from model_compression_toolkit.common.matchers.base_matcher import BaseMatcher
|
|
22
22
|
|
|
23
|
+
|
|
23
24
|
class BaseSubstitution(ABC):
|
|
24
25
|
"""
|
|
25
26
|
Base class for all substitution classes.
|
|
@@ -14,13 +14,13 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
import copy
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
|
|
18
17
|
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
19
18
|
from model_compression_toolkit.common import Graph, Node
|
|
19
|
+
from model_compression_toolkit.keras.constants import BIAS, USE_BIAS
|
|
20
|
+
|
|
20
21
|
|
|
21
22
|
def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
|
|
22
|
-
fw_info: FrameworkInfo
|
|
23
|
-
fw_impl: FrameworkImplementation) -> Graph:
|
|
23
|
+
fw_info: FrameworkInfo) -> Graph:
|
|
24
24
|
"""
|
|
25
25
|
Get a graph, where each node has a final weights quantization configuration (with a bias
|
|
26
26
|
correction term in it), and apply the bias correction for each node in the graph.
|
|
@@ -28,7 +28,6 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
|
|
|
28
28
|
Args:
|
|
29
29
|
graph_to_apply_bias_correction: Graph to apply bias correction to.
|
|
30
30
|
fw_info: Framework information (e.g, operators to quantize their weights).
|
|
31
|
-
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
|
32
31
|
|
|
33
32
|
Returns:
|
|
34
33
|
Graph with bias correction apply to its' nodes.
|
|
@@ -40,28 +39,26 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
|
|
|
40
39
|
# If a kernel was quantized and weights bias correction is enabled in n.quantization_cfg,
|
|
41
40
|
# a bias correction term was calculated during model preparation, and is used now in the node's bias term.
|
|
42
41
|
if n.final_weights_quantization_cfg.weights_bias_correction:
|
|
43
|
-
_apply_bias_correction_to_node(n
|
|
42
|
+
_apply_bias_correction_to_node(n)
|
|
44
43
|
return graph
|
|
45
44
|
|
|
46
45
|
|
|
47
|
-
def _apply_bias_correction_to_node(node:Node
|
|
48
|
-
fw_impl: FrameworkImplementation):
|
|
46
|
+
def _apply_bias_correction_to_node(node:Node):
|
|
49
47
|
"""
|
|
50
48
|
Set new bias to node using the bias-correction term that is stored in the
|
|
51
49
|
final weights quantization configuration.
|
|
52
50
|
|
|
53
51
|
Args:
|
|
54
52
|
node: Node to set its corrected bias after bias-correction.
|
|
55
|
-
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
|
56
53
|
|
|
57
54
|
"""
|
|
58
55
|
correction = node.final_weights_quantization_cfg.bias_corrected
|
|
59
56
|
|
|
60
|
-
bias = node.get_weights_by_keys(
|
|
57
|
+
bias = node.get_weights_by_keys(BIAS) # get original bias from node's weights
|
|
61
58
|
|
|
62
59
|
if bias is not None: # It the layer has bias, we subtract the correction from original bias
|
|
63
|
-
node.set_weights_by_keys(
|
|
60
|
+
node.set_weights_by_keys(BIAS, node.get_weights_by_keys(BIAS) - correction)
|
|
64
61
|
|
|
65
62
|
else: # It the layer has no bias, we consider it as if it has and its value is 0.
|
|
66
|
-
node.set_weights_by_keys(
|
|
67
|
-
node.framework_attr[
|
|
63
|
+
node.set_weights_by_keys(BIAS, - correction)
|
|
64
|
+
node.framework_attr[USE_BIAS] = True # Mark the use_bias attribute of the node.
|
|
@@ -12,22 +12,20 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
|
|
16
15
|
import copy
|
|
17
16
|
from typing import Any
|
|
18
17
|
|
|
19
18
|
import numpy as np
|
|
20
19
|
|
|
21
|
-
from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
|
|
22
20
|
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
23
21
|
from model_compression_toolkit.common import Node, Logger, Graph
|
|
24
22
|
from model_compression_toolkit.common.quantization.quantize_node import get_quantized_kernel_by_weights_qc
|
|
25
23
|
from model_compression_toolkit.common.statistics_collector import BaseStatsContainer
|
|
24
|
+
from model_compression_toolkit.keras.constants import KERNEL
|
|
26
25
|
|
|
27
26
|
|
|
28
27
|
def compute_bias_correction_of_graph(graph_co_compute_bias: Graph,
|
|
29
|
-
fw_info: FrameworkInfo
|
|
30
|
-
fw_impl: FrameworkImplementation) -> Graph:
|
|
28
|
+
fw_info: FrameworkInfo) -> Graph:
|
|
31
29
|
"""
|
|
32
30
|
For each node in a graph, and for each candidate weights quantization configuration,
|
|
33
31
|
compute the bias-correction term, and store it in the candidate weights quantization configuration.
|
|
@@ -36,7 +34,6 @@ def compute_bias_correction_of_graph(graph_co_compute_bias: Graph,
|
|
|
36
34
|
graph_co_compute_bias: Graph with nodes to compute the bias correction for
|
|
37
35
|
each node's weights quantization configuration candidates.
|
|
38
36
|
fw_info: Framework info like lists of nodes their kernel should quantized.
|
|
39
|
-
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
|
40
37
|
|
|
41
38
|
Returns:
|
|
42
39
|
Graph with bias correction for each weights quantization configuration candidate
|
|
@@ -46,17 +43,13 @@ def compute_bias_correction_of_graph(graph_co_compute_bias: Graph,
|
|
|
46
43
|
graph = copy.deepcopy(graph_co_compute_bias)
|
|
47
44
|
for n in graph.nodes:
|
|
48
45
|
if fw_info.in_kernel_ops(n):
|
|
49
|
-
_compute_bias_correction_per_candidate_qc(n,
|
|
50
|
-
fw_info,
|
|
51
|
-
graph.get_in_stats_collector(n),
|
|
52
|
-
fw_impl=fw_impl)
|
|
46
|
+
_compute_bias_correction_per_candidate_qc(n, fw_info, graph.get_in_stats_collector(n))
|
|
53
47
|
return graph
|
|
54
48
|
|
|
55
49
|
|
|
56
50
|
def _compute_bias_correction_per_candidate_qc(node: Node,
|
|
57
51
|
fw_info: FrameworkInfo,
|
|
58
|
-
node_in_stats_collector: BaseStatsContainer
|
|
59
|
-
fw_impl: FrameworkImplementation):
|
|
52
|
+
node_in_stats_collector: BaseStatsContainer):
|
|
60
53
|
"""
|
|
61
54
|
For each candidate weights quantization configuration of a given node,
|
|
62
55
|
compute the bias-correction term, and store it in the candidate weights quantization configuration.
|
|
@@ -65,7 +58,6 @@ def _compute_bias_correction_per_candidate_qc(node: Node,
|
|
|
65
58
|
node: Node to compute the bias correction for its different candidates.
|
|
66
59
|
fw_info: Framework info like lists of nodes their kernel should quantized.
|
|
67
60
|
node_in_stats_collector: Statistics collector of the node for the mean per-channel.
|
|
68
|
-
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
|
69
61
|
|
|
70
62
|
"""
|
|
71
63
|
|
|
@@ -74,8 +66,7 @@ def _compute_bias_correction_per_candidate_qc(node: Node,
|
|
|
74
66
|
if fw_info.in_kernel_ops(node) and weights_qc.enable_weights_quantization:
|
|
75
67
|
quantized_kernel, io_channels_axes = get_quantized_kernel_by_weights_qc(fw_info,
|
|
76
68
|
node,
|
|
77
|
-
weights_qc
|
|
78
|
-
fw_impl=fw_impl)
|
|
69
|
+
weights_qc)
|
|
79
70
|
|
|
80
71
|
# If a kernel was quantized and weights bias correction is enabled in n.quantization_cfg,
|
|
81
72
|
# a bias correction term is being calculated and used in the node's bias term.
|
|
@@ -84,8 +75,7 @@ def _compute_bias_correction_per_candidate_qc(node: Node,
|
|
|
84
75
|
node,
|
|
85
76
|
node_in_stats_collector,
|
|
86
77
|
io_channels_axes[1],
|
|
87
|
-
quantized_kernel
|
|
88
|
-
fw_impl=fw_impl)
|
|
78
|
+
quantized_kernel)
|
|
89
79
|
|
|
90
80
|
# Store the correction term to use it later,
|
|
91
81
|
weights_qc.bias_corrected = bias_correction_term
|
|
@@ -135,8 +125,7 @@ def _get_bias_correction_term_of_node(input_channels_axis: int,
|
|
|
135
125
|
n: Node,
|
|
136
126
|
node_in_stats_collector: BaseStatsContainer,
|
|
137
127
|
output_channels_axis: int,
|
|
138
|
-
quantized_kernel: np.ndarray
|
|
139
|
-
fw_impl: FrameworkImplementation):
|
|
128
|
+
quantized_kernel: np.ndarray):
|
|
140
129
|
"""
|
|
141
130
|
Get the bias correction term for a node, using a quantized kernel (which can be quantized
|
|
142
131
|
using any possible bit width)
|
|
@@ -147,8 +136,6 @@ def _get_bias_correction_term_of_node(input_channels_axis: int,
|
|
|
147
136
|
node_in_stats_collector: Input statistics collector of the node.
|
|
148
137
|
output_channels_axis: Index of output channels of the kernel.
|
|
149
138
|
quantized_kernel: Quantized kernel of the node.
|
|
150
|
-
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
|
151
|
-
|
|
152
139
|
|
|
153
140
|
Returns:
|
|
154
141
|
Bias-correction term to subtract from the current node's bias.
|
|
@@ -163,7 +150,7 @@ def _get_bias_correction_term_of_node(input_channels_axis: int,
|
|
|
163
150
|
f'Unknown input channel axis for node named: {n.name},'
|
|
164
151
|
f' please update channel mapping function')
|
|
165
152
|
# Compute the bias correction term.
|
|
166
|
-
correction = _compute_bias_correction(n.get_weights_by_keys(
|
|
153
|
+
correction = _compute_bias_correction(n.get_weights_by_keys(KERNEL),
|
|
167
154
|
quantized_kernel,
|
|
168
155
|
node_in_stats_collector,
|
|
169
156
|
output_channels_axis,
|
|
@@ -74,7 +74,7 @@ class HistogramCollector(BaseCollector):
|
|
|
74
74
|
merged_histogram_min = np.min(bins_stack)
|
|
75
75
|
merged_histogram_max = np.max(bins_stack)
|
|
76
76
|
merged_bin_width = np.min(bins_stack[:, 1] - bins_stack[:, 0])
|
|
77
|
-
merged_histogram_bins = np.arange(merged_histogram_min, merged_histogram_max
|
|
77
|
+
merged_histogram_bins = np.arange(merged_histogram_min, merged_histogram_max, merged_bin_width)
|
|
78
78
|
|
|
79
79
|
merged_histogram_counts = None
|
|
80
80
|
for histogram in self.__histogram_per_iteration: # Iterate all collected histograms and merge them
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
|
|
15
|
+
|
|
16
16
|
|
|
17
17
|
from copy import copy, deepcopy
|
|
18
18
|
from typing import List
|
|
@@ -29,11 +29,9 @@ from model_compression_toolkit.common.graph.node import Node
|
|
|
29
29
|
from model_compression_toolkit.common.statistics_collector import BaseStatsContainer
|
|
30
30
|
from model_compression_toolkit.common.statistics_collector import scale_statistics, shift_statistics
|
|
31
31
|
from model_compression_toolkit.common.user_info import UserInformation
|
|
32
|
+
from model_compression_toolkit.keras.reader.connectivity_handler import OutTensor
|
|
32
33
|
from model_compression_toolkit.common.logger import Logger
|
|
33
34
|
|
|
34
|
-
OutTensor = namedtuple('OutTensor', 'node node_out_index')
|
|
35
|
-
|
|
36
|
-
|
|
37
35
|
class Graph(nx.MultiDiGraph, GraphSearches):
|
|
38
36
|
"""
|
|
39
37
|
Base graph representing a model to be optimized.
|
|
@@ -19,7 +19,9 @@ from typing import Any, List
|
|
|
19
19
|
from tensorflow.keras.layers import Layer
|
|
20
20
|
|
|
21
21
|
from model_compression_toolkit.common.graph.node import Node
|
|
22
|
-
from model_compression_toolkit.common.matchers import
|
|
22
|
+
from model_compression_toolkit.common.matchers import edge_matcher
|
|
23
|
+
from model_compression_toolkit.common.matchers import node_matcher
|
|
24
|
+
from model_compression_toolkit.common.matchers import walk_matcher
|
|
23
25
|
|
|
24
26
|
|
|
25
27
|
class NodeOperationMatcher(node_matcher.BaseNodeMatcher):
|
|
@@ -17,7 +17,9 @@ from abc import ABC
|
|
|
17
17
|
from typing import List, Any
|
|
18
18
|
|
|
19
19
|
from model_compression_toolkit.common.graph.node import Node
|
|
20
|
-
from model_compression_toolkit.common.matchers import
|
|
20
|
+
from model_compression_toolkit.common.matchers import base_graph_filter
|
|
21
|
+
from model_compression_toolkit.common.matchers import edge_matcher
|
|
22
|
+
from model_compression_toolkit.common.matchers import node_matcher
|
|
21
23
|
from model_compression_toolkit.common.matchers.walk_matcher import WalkMatcherList
|
|
22
24
|
|
|
23
25
|
|
|
@@ -17,7 +17,6 @@ import copy
|
|
|
17
17
|
|
|
18
18
|
from typing import Any, List
|
|
19
19
|
|
|
20
|
-
from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
|
|
21
20
|
from model_compression_toolkit.common import Graph, Node
|
|
22
21
|
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
23
22
|
from model_compression_toolkit.common.logger import Logger
|
|
@@ -26,7 +25,7 @@ from model_compression_toolkit.common.mixed_precision.mixed_precision_quantizati
|
|
|
26
25
|
|
|
27
26
|
|
|
28
27
|
|
|
29
|
-
def set_bit_widths(quant_config:
|
|
28
|
+
def set_bit_widths(quant_config: MixedPrecisionQuantizationConfig,
|
|
30
29
|
graph_to_set_bit_widths: Graph,
|
|
31
30
|
fw_info: FrameworkInfo = None,
|
|
32
31
|
bit_widths_config: List[int] = None) -> Graph:
|
model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py
CHANGED
|
@@ -20,7 +20,7 @@ from model_compression_toolkit.common.constants import MIN_THRESHOLD
|
|
|
20
20
|
from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_search import qparams_selection_tensor_search, qparams_selection_histogram_search
|
|
21
21
|
from model_compression_toolkit.common.similarity_analyzer import compute_lp_norm
|
|
22
22
|
|
|
23
|
-
from model_compression_toolkit.
|
|
23
|
+
from model_compression_toolkit.keras.constants import THRESHOLD
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
def lp_selection_tensor(tensor_data: np.ndarray,
|
model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py
CHANGED
|
@@ -14,7 +14,6 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import List
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
|
|
18
17
|
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
19
18
|
from model_compression_toolkit.common import Graph, Node, Logger
|
|
20
19
|
from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_activations_computation \
|
|
@@ -22,13 +21,13 @@ from model_compression_toolkit.common.quantization.quantization_params_generatio
|
|
|
22
21
|
get_activations_qparams
|
|
23
22
|
from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_weights_computation import \
|
|
24
23
|
get_weights_qparams, get_channels_axis
|
|
24
|
+
from model_compression_toolkit.keras.constants import KERNEL
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
def calculate_quantization_params(graph: Graph,
|
|
28
28
|
fw_info: FrameworkInfo,
|
|
29
29
|
nodes: List[Node] = [],
|
|
30
|
-
specific_nodes: bool = False
|
|
31
|
-
fw_impl: FrameworkImplementation = None):
|
|
30
|
+
specific_nodes: bool = False):
|
|
32
31
|
"""
|
|
33
32
|
For a graph, go over its nodes, compute quantization params (for both weights and activations according
|
|
34
33
|
to the given framework info), and create and attach a NodeQuantizationConfig to each node (containing the
|
|
@@ -43,7 +42,6 @@ def calculate_quantization_params(graph: Graph,
|
|
|
43
42
|
graph: Graph to compute its nodes' thresholds.
|
|
44
43
|
nodes: List of nodes to compute their thresholds instead of computing it for all nodes in the graph.
|
|
45
44
|
specific_nodes: Flag to compute thresholds for only specific nodes.
|
|
46
|
-
fw_impl: FrameworkImplementation with specific framework implementations.
|
|
47
45
|
|
|
48
46
|
"""
|
|
49
47
|
|
|
@@ -59,7 +57,7 @@ def calculate_quantization_params(graph: Graph,
|
|
|
59
57
|
|
|
60
58
|
for candidtae_qc in n.candidates_weights_quantization_cfg:
|
|
61
59
|
output_channels_axis, _ = get_channels_axis(candidtae_qc, fw_info, n.layer_class)
|
|
62
|
-
weights_params = get_weights_qparams(n.get_weights_by_keys(
|
|
60
|
+
weights_params = get_weights_qparams(n.get_weights_by_keys(KERNEL),
|
|
63
61
|
candidtae_qc,
|
|
64
62
|
output_channels_axis)
|
|
65
63
|
|
|
@@ -16,15 +16,14 @@
|
|
|
16
16
|
import copy
|
|
17
17
|
|
|
18
18
|
from model_compression_toolkit import common
|
|
19
|
-
from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
|
|
20
19
|
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
21
20
|
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
22
21
|
from model_compression_toolkit.common.quantization.quantize_node import get_quantized_kernel_by_weights_qc
|
|
22
|
+
from model_compression_toolkit.keras.constants import KERNEL
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
def quantize_graph_weights(graph_to_quantize: Graph,
|
|
26
|
-
fw_info: FrameworkInfo
|
|
27
|
-
fw_impl: FrameworkImplementation) -> Graph:
|
|
26
|
+
fw_info: FrameworkInfo) -> Graph:
|
|
28
27
|
"""
|
|
29
28
|
Get a graph representing a model, and quantize its nodes' weights.
|
|
30
29
|
Each node is quantized according to the passed framework info and quantization configuration.
|
|
@@ -34,7 +33,6 @@ def quantize_graph_weights(graph_to_quantize: Graph,
|
|
|
34
33
|
Args:
|
|
35
34
|
graph_to_quantize: Graph to quantize its nodes.
|
|
36
35
|
fw_info: Framework information needed for quantizing the graph's nodes' weights and activations.
|
|
37
|
-
fw_impl: FrameworkImplementation with specific framework implementations.
|
|
38
36
|
|
|
39
37
|
"""
|
|
40
38
|
graph = copy.deepcopy(graph_to_quantize)
|
|
@@ -45,14 +43,13 @@ def quantize_graph_weights(graph_to_quantize: Graph,
|
|
|
45
43
|
if fw_info.in_kernel_ops(n) and n.final_weights_quantization_cfg.enable_weights_quantization:
|
|
46
44
|
quantized_kernel, io_channels_axes = get_quantized_kernel_by_weights_qc(fw_info,
|
|
47
45
|
n,
|
|
48
|
-
n.final_weights_quantization_cfg
|
|
49
|
-
fw_impl=fw_impl)
|
|
46
|
+
n.final_weights_quantization_cfg)
|
|
50
47
|
|
|
51
48
|
common.Logger.debug(
|
|
52
49
|
f'Node name: {n.name} has the following quantization params: '
|
|
53
50
|
f'{str(n.final_weights_quantization_cfg.weights_quantization_params)}')
|
|
54
51
|
|
|
55
52
|
# Set the kernel node to be the quantized kernel.
|
|
56
|
-
n.set_weights_by_keys(
|
|
53
|
+
n.set_weights_by_keys(KERNEL, quantized_kernel)
|
|
57
54
|
|
|
58
55
|
return graph
|
|
@@ -18,18 +18,17 @@ import copy
|
|
|
18
18
|
|
|
19
19
|
from model_compression_toolkit import common
|
|
20
20
|
from model_compression_toolkit.common import Logger
|
|
21
|
-
from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
|
|
22
21
|
from model_compression_toolkit.common.graph.node import Node
|
|
23
22
|
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
24
23
|
from model_compression_toolkit.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
|
|
25
24
|
from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_weights_computation import \
|
|
26
25
|
get_channels_axis
|
|
26
|
+
from model_compression_toolkit.keras.constants import KERNEL
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
def get_quantized_kernel_by_weights_qc(fw_info:FrameworkInfo,
|
|
30
30
|
n:Node,
|
|
31
|
-
weights_qc: NodeWeightsQuantizationConfig
|
|
32
|
-
fw_impl: FrameworkImplementation):
|
|
31
|
+
weights_qc: NodeWeightsQuantizationConfig):
|
|
33
32
|
"""
|
|
34
33
|
For a node and a weights quantization configuration, compute
|
|
35
34
|
the quantized kernel of the node and return it and the input/output channels indices.
|
|
@@ -38,7 +37,6 @@ def get_quantized_kernel_by_weights_qc(fw_info:FrameworkInfo,
|
|
|
38
37
|
fw_info: A FrameworkInfo object Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
|
|
39
38
|
n: Node to quantize its kernel.
|
|
40
39
|
weights_qc: Weight quantization configuration to use for the quantization.
|
|
41
|
-
fw_impl: FrameworkImplementation with specific framework implementations.
|
|
42
40
|
|
|
43
41
|
Returns:
|
|
44
42
|
A quantized kernel of the node using a weights quantization configuration.
|
|
@@ -55,7 +53,7 @@ def get_quantized_kernel_by_weights_qc(fw_info:FrameworkInfo,
|
|
|
55
53
|
n.layer_class)
|
|
56
54
|
|
|
57
55
|
Logger.debug(f'quantizing {n.name} with {weights_qc.weights_n_bits} bits')
|
|
58
|
-
quantized_kernel = weights_qc.weights_quantization_fn(n.get_weights_by_keys(
|
|
56
|
+
quantized_kernel = weights_qc.weights_quantization_fn(n.get_weights_by_keys(KERNEL),
|
|
59
57
|
n_bits=weights_qc.weights_n_bits,
|
|
60
58
|
signed=True,
|
|
61
59
|
quantization_params=weights_qc.weights_quantization_params,
|