mct-nightly 0.0.0__py3-none-any.whl → 1.1.0.01122021-003325__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/METADATA +3 -2
  2. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/RECORD +31 -38
  3. model_compression_toolkit/__init__.py +2 -6
  4. model_compression_toolkit/common/base_substitutions.py +1 -0
  5. model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +9 -12
  6. model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +8 -21
  7. model_compression_toolkit/common/collectors/histogram_collector.py +1 -1
  8. model_compression_toolkit/common/graph/base_graph.py +2 -4
  9. model_compression_toolkit/common/graph/graph_matchers.py +3 -1
  10. model_compression_toolkit/common/graph/graph_searches.py +3 -1
  11. model_compression_toolkit/common/mixed_precision/bit_width_setter.py +1 -2
  12. model_compression_toolkit/common/network_editors/node_filters.py +1 -0
  13. model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py +1 -1
  14. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +3 -5
  15. model_compression_toolkit/common/quantization/quantize_graph_weights.py +4 -7
  16. model_compression_toolkit/common/quantization/quantize_node.py +3 -5
  17. model_compression_toolkit/keras/__init__.py +2 -0
  18. model_compression_toolkit/keras/back2framework/model_builder.py +24 -1
  19. model_compression_toolkit/{common → keras/back2framework}/model_collector.py +9 -18
  20. model_compression_toolkit/keras/default_framework_info.py +0 -1
  21. model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +57 -10
  22. model_compression_toolkit/keras/graph_substitutions/substituter.py +171 -0
  23. model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +26 -6
  24. model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +12 -5
  25. model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +3 -4
  26. model_compression_toolkit/keras/quantization_facade.py +524 -188
  27. model_compression_toolkit/keras/reader/connectivity_handler.py +4 -1
  28. model_compression_toolkit/keras/visualization/nn_visualizer.py +1 -2
  29. model_compression_toolkit/common/framework_implementation.py +0 -239
  30. model_compression_toolkit/common/gptq/__init__.py +0 -14
  31. model_compression_toolkit/common/gptq/gptq_config.py +0 -65
  32. model_compression_toolkit/common/model_builder_mode.py +0 -34
  33. model_compression_toolkit/common/post_training_quantization.py +0 -459
  34. model_compression_toolkit/common/substitutions/__init__.py +0 -14
  35. model_compression_toolkit/common/substitutions/apply_substitutions.py +0 -40
  36. model_compression_toolkit/keras/keras_implementation.py +0 -256
  37. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/LICENSE +0 -0
  38. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/WHEEL +0 -0
  39. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 0.0.0
3
+ Version: 1.1.0.01122021-003325
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -11,6 +11,7 @@ Classifier: Operating System :: OS Independent
11
11
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
12
12
  Requires-Python: >=3.6
13
13
  Description-Content-Type: text/markdown
14
+ Requires-Dist: tensorflow (==2.5.*)
14
15
  Requires-Dist: networkx (==2.5)
15
16
  Requires-Dist: tqdm
16
17
  Requires-Dist: Pillow
@@ -18,9 +19,9 @@ Requires-Dist: numpy
18
19
  Requires-Dist: opencv-python
19
20
  Requires-Dist: scikit-image
20
21
  Requires-Dist: scikit-learn
22
+ Requires-Dist: tensorflow-model-optimization
21
23
  Requires-Dist: tensorboard
22
24
  Requires-Dist: PuLP
23
- Requires-Dist: matplotlib
24
25
 
25
26
  # Model Compression Toolkit (MCT)
26
27
  ![tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_suite.yml/badge.svg)
@@ -1,33 +1,27 @@
1
- model_compression_toolkit/__init__.py,sha256=B5RFfeWQNuMdLxrY6TZWHcX2O25oc7KfkKRvVYdGMKQ,1892
1
+ model_compression_toolkit/__init__.py,sha256=3Jc6zSKbZJ2oGnNJMKKw54QWKZdtDMGoYLIGbG8IooU,1895
2
2
  model_compression_toolkit/common/__init__.py,sha256=LtklANR21DvGL9oO4ila5u6Ft7M83-dqftICub6Vowk,1454
3
- model_compression_toolkit/common/base_substitutions.py,sha256=FHVBXTdn_sh6QpkdOpZ4IObcoeF49klPvak3T1XSHzs,1656
3
+ model_compression_toolkit/common/base_substitutions.py,sha256=eAQfGVDrzHw6qRCkdm5ENVJLlod633UWCI2qLQyyoCA,1657
4
4
  model_compression_toolkit/common/constants.py,sha256=DWcVXmyONazCWZYp6LT-0Kav0bbMxfB8sDWjMRAxqxY,1472
5
5
  model_compression_toolkit/common/data_loader.py,sha256=fqTPODRsWZAhkXExE4iP2kgtHv5kOw1UzBojWXIrq7c,4018
6
6
  model_compression_toolkit/common/defaultdict.py,sha256=C9LVHx7D0WniqOsoiENkPNdUj49kiQEabESr6FRT5OA,2132
7
- model_compression_toolkit/common/framework_implementation.py,sha256=mz3WI7wgXoo_jIs_XM-D1uCl7S-gEPml3UKDjAGri3w,9827
8
7
  model_compression_toolkit/common/framework_info.py,sha256=mpPvVS8JYa987MYqmfj32iEsLY-vhqY8UGtgoUYK3_k,7723
9
8
  model_compression_toolkit/common/logger.py,sha256=RrDE9VCndUrdbQ_4DL5VP8z10fr5BMCJvxswHIvb3Os,4674
10
- model_compression_toolkit/common/model_builder_mode.py,sha256=G06Tpf-p7vBpa_k0Z8sNdbwhN_QEX-3lf0ajafSrQKA,1488
11
- model_compression_toolkit/common/model_collector.py,sha256=c-wy-ln9o9FI7jgpS4pXEB3h81XLXwXKPrY0qjdHQuw,4665
12
- model_compression_toolkit/common/post_training_quantization.py,sha256=T6T9YmoKgXZMdRuLKVZVVviXb7SaefnQ4HNsk5AiWv4,22575
13
9
  model_compression_toolkit/common/similarity_analyzer.py,sha256=fB5oGzRY941XONCxnXRy43n4F1qxqXM8EuF31t__Izg,4047
14
10
  model_compression_toolkit/common/statistics_collector.py,sha256=cv-qrVM9JY9JLyxfC5G2EzqPAdbOBMN4Gm3Qf9tOhIY,7973
15
11
  model_compression_toolkit/common/user_info.py,sha256=OYUdFBma5TKrDLBXEjgo1Oj0RK0zpvkVmP_5ZJ0X-2c,1535
16
12
  model_compression_toolkit/common/bias_correction/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
17
- model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py,sha256=xYQoZ24IuFrm-4c415BCirYbRaDjKl9OnCW4TDUxO7k,3328
18
- model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py,sha256=UKpQZzK7U8xc5Tu2ZJIRk2b_rNB8I-lJDDIYWvcfr1A,8372
13
+ model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py,sha256=i_Zky82Ow0_dpeSqp7z_5z2Zm-Yx-tQtY8R1RRIUfDE,2870
14
+ model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py,sha256=-pud80aNesybSqxvYSUiHOvK3CEcTcBvQZtxIGXjaNg,7424
19
15
  model_compression_toolkit/common/collectors/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
20
16
  model_compression_toolkit/common/collectors/base_collector.py,sha256=s1Kne14LcFUajwftMBrusu0wTPHCUQx4TegPcFSWVss,2427
21
- model_compression_toolkit/common/collectors/histogram_collector.py,sha256=_MDJPsa98rsuN1p8ix3fhoeCg6k9lEZjBvyqTLQ9Fqw,6860
17
+ model_compression_toolkit/common/collectors/histogram_collector.py,sha256=9Ypg35t-zdWSTryA-XvubHSD_PD8ZvJoQ0IliqFm46c,6843
22
18
  model_compression_toolkit/common/collectors/mean_collector.py,sha256=x3n2dmZl2EuXh_LgEQnO7YMlL6G7s0nFltsIhXGx6W0,3835
23
19
  model_compression_toolkit/common/collectors/min_max_per_channel_collector.py,sha256=ZzJkogbqDkVj_tiS3gzHfTsEkomroeHBcJrXod0ga4k,5031
24
- model_compression_toolkit/common/gptq/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
25
- model_compression_toolkit/common/gptq/gptq_config.py,sha256=XDfJgamLEyJKyq4rmhX5gqhM30a1DiJjKqoy3N6t1pM,2787
26
20
  model_compression_toolkit/common/graph/__init__.py,sha256=3x8iejFh5xa6tUnOtu5tU7KFOKdRqADhNac89pDItWo,760
27
- model_compression_toolkit/common/graph/base_graph.py,sha256=jUDItUWIDquKTmEtVtQ-WC2gK8l7QPZKN_1LHF-tmhs,18408
21
+ model_compression_toolkit/common/graph/base_graph.py,sha256=2Bds0yeUA8HtF1Rg2moBwLRN5cG0p7wBjdIEVT3RJo8,18395
28
22
  model_compression_toolkit/common/graph/edge.py,sha256=J1ksD-NHXNA_aMqoEQEaR81URLuAht2KNwu_ARwODHE,3712
29
- model_compression_toolkit/common/graph/graph_matchers.py,sha256=8m51Ve__sQ-joGxNSyQZJImUE5zGgkp2xuSdhWIb1zY,4754
30
- model_compression_toolkit/common/graph/graph_searches.py,sha256=HYG5f1SN3cKA3Jeq9iUnubk1EJiPD6k-g6ufFTXaIKc,5093
23
+ model_compression_toolkit/common/graph/graph_matchers.py,sha256=2k5yogBwXcxIAENZTIcpZ-lhCe23ZGqnERcd3RmTWik,4860
24
+ model_compression_toolkit/common/graph/graph_searches.py,sha256=o3xcFMtKJMElFGt0wcEMC0c48Qcv0azsF-e__Ti_TKk,5199
31
25
  model_compression_toolkit/common/graph/graph_vis.py,sha256=NqWkLW75ap2n-RX_TmooVtxJ63iR3BgtfRR6f4eUWYs,3704
32
26
  model_compression_toolkit/common/graph/node.py,sha256=LUi8l1-mzrI-0_USw1Zt-1j6xkHT4tDSAPx9OXBc5pI,7527
33
27
  model_compression_toolkit/common/matchers/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
@@ -38,7 +32,7 @@ model_compression_toolkit/common/matchers/function.py,sha256=Aqe0gLQJUSAn-unHYF5
38
32
  model_compression_toolkit/common/matchers/node_matcher.py,sha256=fyKrVaC_S1qU-Fx-SJVW0YxEDv7LMHDUZQyfMgbH-xY,2746
39
33
  model_compression_toolkit/common/matchers/walk_matcher.py,sha256=wZuTitcKe4CeGizwb4P1_D5o1wd7ZfaKRA_tTng8FeQ,1112
40
34
  model_compression_toolkit/common/mixed_precision/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
41
- model_compression_toolkit/common/mixed_precision/bit_width_setter.py,sha256=DpgZatq-gEhUHfuN4xw3iagiMrNih1poWGd3uVDVW8g,6431
35
+ model_compression_toolkit/common/mixed_precision/bit_width_setter.py,sha256=b47YTTlz3emzcnFLTYUXyO9nG1AOHe1didsx6KqK8EY,6348
42
36
  model_compression_toolkit/common/mixed_precision/distance_weighting.py,sha256=3a42gnctiGI5sLDTQj0WjSaaWp8AImM3nMQHNVpyies,2426
43
37
  model_compression_toolkit/common/mixed_precision/kpi.py,sha256=eUYXmR7LJC2SG77XI2yy29UkS_2uimxcS03EzXZJox4,1279
44
38
  model_compression_toolkit/common/mixed_precision/mixed_precision_quantization_config.py,sha256=fkTsB_-mbwMWfxXJxpStUBGU8Os3_vLI267oeWKS58A,3423
@@ -49,26 +43,26 @@ model_compression_toolkit/common/mixed_precision/search_methods/linear_programmi
49
43
  model_compression_toolkit/common/network_editors/__init__.py,sha256=dJWrbHs2NeGYW7Kzy7m7waIw8Ounjd1tiiL0a4nRAos,1235
50
44
  model_compression_toolkit/common/network_editors/actions.py,sha256=czJlYQnC6r38XJcpipvcCy2DQNd01QRwkIh_d3BLwSg,13118
51
45
  model_compression_toolkit/common/network_editors/edit_network.py,sha256=KsfILh0y_n5kqkaAJDeuHtJ5Pgf4OW6oWjl13ySV0QQ,1723
52
- model_compression_toolkit/common/network_editors/node_filters.py,sha256=W10GgyW8JicTgc10VjJtbFTuFVHuoOb8HDwWzxl4qdg,3151
46
+ model_compression_toolkit/common/network_editors/node_filters.py,sha256=T3jMUoajr32BxS28NctEzMm5kBxD9nGSFZCaimAhpaI,3152
53
47
  model_compression_toolkit/common/quantization/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
54
48
  model_compression_toolkit/common/quantization/node_quantization_config.py,sha256=e81Yoqtp6e8saCDZFd2bYJMRtdxQ7aPCorwhYQSws6g,9359
55
49
  model_compression_toolkit/common/quantization/quantization_analyzer.py,sha256=xcZ2BPvUZ5ATNStavKjWOcJnvx8OQ7yBG0O_GgcsvuE,2972
56
50
  model_compression_toolkit/common/quantization/quantization_config.py,sha256=N2JZFO2AFlAqv9KbKYtIZlfKi2Dsb5RjNCNR8O76DPs,8503
57
51
  model_compression_toolkit/common/quantization/quantization_params_fn_selection.py,sha256=WSarcRfTcG2t62JCXIIucbe86oM_856vCb_1ygn-FeM,5448
58
- model_compression_toolkit/common/quantization/quantize_graph_weights.py,sha256=VopC_NxKd1rWjXhSTigZdnyldEocfsPAyeX-3WtEzkM,2993
59
- model_compression_toolkit/common/quantization/quantize_node.py,sha256=MatcsOaUX0ogXzD8WH5v7zUc2ITsHBTpBGw3diYT88w,3593
52
+ model_compression_toolkit/common/quantization/quantize_graph_weights.py,sha256=7f9OUmG1BqPT2BmySNh49e3_NjW-MmG34Yun7mkpTew,2698
53
+ model_compression_toolkit/common/quantization/quantize_node.py,sha256=_g4QVfqwR2MjomqZT6N_8lTEAmYu3RKv1-JdcnJD9Mc,3387
60
54
  model_compression_toolkit/common/quantization/set_node_quantization_config.py,sha256=MNn4o26souwd1rkdxWbXF97TA_12knNEuRiVaj4vlCw,7729
61
55
  model_compression_toolkit/common/quantization/quantization_params_generation/__init__.py,sha256=8wFyT6t-uX0MQ-2QCa4qGZYmquhWroEdBrsren6wrlI,2045
62
56
  model_compression_toolkit/common/quantization/quantization_params_generation/kl_selection.py,sha256=Rk7yKNZDXePffzcOxlqK1er8mXkw0k_Qn3TY-Ve18eE,15158
63
57
  model_compression_toolkit/common/quantization/quantization_params_generation/kmeans_params.py,sha256=e8E7Htdc00mqATtHpOJUATadDY-LuHSUfJBh1-FLeKw,2648
64
- model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py,sha256=9TIgdYVi0WPCodpk2lJ1BsEjoVTWQ9fINoXGtOkB0bI,6643
58
+ model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py,sha256=8ShckR-tkrWTBiJEFhogHhTsShRF5hTTlHHV7OVJWAA,6642
65
59
  model_compression_toolkit/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=YUkjhYl4rHDtPIOUfMS7gxKiMFWIcIYhR1_g9Rdk6WQ,3559
66
60
  model_compression_toolkit/common/quantization/quantization_params_generation/mae_selection.py,sha256=ipfyH29zpSK78C_ktEOdch5O2-fBDHYOd3b4IU4EFfk,5343
67
61
  model_compression_toolkit/common/quantization/quantization_params_generation/mse_selection.py,sha256=R2pBjOIeozczdVi-nqfQvy0Auj0QLP0XsgaCIH6DdMk,5420
68
62
  model_compression_toolkit/common/quantization/quantization_params_generation/no_clipping.py,sha256=0TTU_rptXNjhldoH_K1-CE8eWEGq-EoSJPLlRM8dxUQ,6466
69
63
  model_compression_toolkit/common/quantization/quantization_params_generation/outlier_filter.py,sha256=SeouwEGp5oQ9vZu6OxHKpw3Hu01mqYmCoS-xneLN4kk,1773
70
64
  model_compression_toolkit/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=4pHn5oX2vfzfiLv2g7YX925wnSnF0qd5o-q0VcLtbSI,3066
71
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py,sha256=OFrboVTE1D1taaP9a1tgY5P2FCvQkJaB4EvHoPZF4WE,4941
65
+ model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py,sha256=OSiSoIsvVsQ0hZKA6TF_P6NTyXgWeODBRX39ef4taio,4733
72
66
  model_compression_toolkit/common/quantization/quantization_params_generation/qparams_search.py,sha256=RP1ylMBISrCAxRVZfPoD6RC5msb-Y6bvrJN9lC8xKGk,8565
73
67
  model_compression_toolkit/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=G_IqFKpn4yVT6Q-uKbtJaF0N0QUxKvHeKdcNEGKlxwc,4943
74
68
  model_compression_toolkit/common/quantization/quantizers/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
@@ -76,37 +70,36 @@ model_compression_toolkit/common/quantization/quantizers/kmeans_quantizer.py,sha
76
70
  model_compression_toolkit/common/quantization/quantizers/lut_kmeans_quantizer.py,sha256=IhyoB8dWfwuXo4Ow-8iWWCln8lUZh3gX6cGHIZpf3CQ,2797
77
71
  model_compression_toolkit/common/quantization/quantizers/power_of_two_quantizer.py,sha256=NI2UG_cn76svg1j2sGNVWOWuRcqcrVGh2Dge8CMpKB0,2108
78
72
  model_compression_toolkit/common/quantization/quantizers/quantizers_helpers.py,sha256=muetNju10st84LnjCTLueSsNZxsXtatxGlV77crb6ms,6014
79
- model_compression_toolkit/common/substitutions/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
80
- model_compression_toolkit/common/substitutions/apply_substitutions.py,sha256=hfRgeqS873jMnF9CQJCUVMlElYX1tddGSFo1HfEOMuM,1459
81
73
  model_compression_toolkit/common/visualization/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
82
74
  model_compression_toolkit/common/visualization/tensorboard_writer.py,sha256=_aqRGqgZ7XcWK5gWl8aYFoMMk79K2Ex3c9aRvrqFAMI,16899
83
- model_compression_toolkit/keras/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
75
+ model_compression_toolkit/keras/__init__.py,sha256=J44ch9RssgFKwQA1XfWDFfA-UiusQmW4jMBNc69QNgQ,754
84
76
  model_compression_toolkit/keras/constants.py,sha256=NbsC3VSuZ4Tl28o7FR5PFIJwG5y3iGhHbsqUM90gaFs,1418
85
- model_compression_toolkit/keras/default_framework_info.py,sha256=_wZv6luoQOk4eaHGRHXasGVOKqZx5nGVwj3VCi4dPec,5188
86
- model_compression_toolkit/keras/keras_implementation.py,sha256=FoGhblVhLN7aQtEp6VNJn1qBc56f9avlEYWOFlHAxko,10763
87
- model_compression_toolkit/keras/quantization_facade.py,sha256=hYzQaVrhFcR0x4uBhDrb11_PN4OssRCqDY9Utgoexmg,13440
77
+ model_compression_toolkit/keras/default_framework_info.py,sha256=-gip0vu3wQXoABILprPlAwwexqFhDSBKfSk-_c-h0WM,5187
78
+ model_compression_toolkit/keras/quantization_facade.py,sha256=cc6Zj_J57blCWDM_M6Yhqj1uhR1SYiInArRYZYkml00,28485
88
79
  model_compression_toolkit/keras/tensor_marking.py,sha256=lfVqpVW3Pfb1BS7Q2d0yra3dFfKBsqGzy9CjkCfwnZA,5067
89
80
  model_compression_toolkit/keras/back2framework/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
90
81
  model_compression_toolkit/keras/back2framework/instance_builder.py,sha256=PO_UaPVEWlKqfw8dexOdEeesQF_SUIACT_kGSLoDDHI,3344
91
- model_compression_toolkit/keras/back2framework/model_builder.py,sha256=b4qnP2hDrLl6DU5gGHzgej7Ol2Wxp6ZdcCd0-Qos3jg,14531
82
+ model_compression_toolkit/keras/back2framework/model_builder.py,sha256=-NZbDh9iYqFWXrE4z7U59i1YV9cLfldeD78TvgdouqE,15292
83
+ model_compression_toolkit/keras/back2framework/model_collector.py,sha256=i1_uvmeQdzOeIc0eaMCq7knYCTvqeZKJNZF9hcV7L1s,4171
92
84
  model_compression_toolkit/keras/gradient_ptq/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
93
85
  model_compression_toolkit/keras/gradient_ptq/gptq_loss.py,sha256=GGog0LDTb2_XptBjrlNjn2CJUlh6n7T32GNM1bMi5yo,1841
94
86
  model_compression_toolkit/keras/gradient_ptq/graph_info.py,sha256=3538-D-4u1K8nXQcOeDyC-8dcP1tEk7pfv2cbRXwvzE,3188
95
87
  model_compression_toolkit/keras/gradient_ptq/graph_update.py,sha256=vdbO2ZyXR4tkMf47CCW67bxuxFp4uTrlVYwK_BHN4Pc,3405
96
- model_compression_toolkit/keras/gradient_ptq/training_wrapper.py,sha256=38n9GKmig5-Ll8EPotJtVGY9h-6Tt9-onB6aq-67LNg,5991
88
+ model_compression_toolkit/keras/gradient_ptq/training_wrapper.py,sha256=W3wZnEskqR8vv8XQT21QAsiAvQF5IIoytACNxCuK-TE,7942
97
89
  model_compression_toolkit/keras/graph_substitutions/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
90
+ model_compression_toolkit/keras/graph_substitutions/substituter.py,sha256=mzuIqICMvXJTfpMo_hHZ9xOLOl-yZVtoyrakYUzaQGY,7776
98
91
  model_compression_toolkit/keras/graph_substitutions/substitutions/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
99
92
  model_compression_toolkit/keras/graph_substitutions/substitutions/activation_decomposition.py,sha256=S-Y3FBIk_IrigIsuJCENa23ZkRSq9PoTDvGe-_8snyY,3886
100
93
  model_compression_toolkit/keras/graph_substitutions/substitutions/batchnorm_folding.py,sha256=U2OjwCZ2sfS_OR4SLVjtx5w79Fg6fte5pLyLU_ec_2Q,5190
101
- model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py,sha256=KSR73RspupgO_ETr7BQWvvjjhRY86VEGtqDSN9ZeEYw,5184
94
+ model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py,sha256=7pcFm_fxf4d7Hg9NLu37a91OPTjqrSreacw6DyO1cE8,6546
102
95
  model_compression_toolkit/keras/graph_substitutions/substitutions/mark_activation.py,sha256=uCZfrIWXw6OyihHV5UlGVy9S6HVnO1HHhZ6V5wmEFbM,3038
103
- model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py,sha256=ltBqOQKwRn3PNIKiFS62fmXi8aa3HA-WLYypWsIHXLs,5167
96
+ model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py,sha256=jyOOhESbfmePHGhS9MnE2pygxBK24PUAEdRychim__s,5636
104
97
  model_compression_toolkit/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py,sha256=8wOjEGvTG2nAR5Zq2TSny_eeDep_8sYbDqyOT_xGtbY,2237
105
98
  model_compression_toolkit/keras/graph_substitutions/substitutions/scale_equalization.py,sha256=nRMmqGMF-MJlp3dd6hGSSldGWq6P-wCvGyvSRClf_JI,15460
106
99
  model_compression_toolkit/keras/graph_substitutions/substitutions/separableconv_decomposition.py,sha256=rL3dqnRXN7vj8PkBcdP73fwVV__BajBTRLMF8ilViBc,7609
107
100
  model_compression_toolkit/keras/graph_substitutions/substitutions/shift_negative_activation.py,sha256=iBCF1K2lHCnVlmjaeE6GfbVRHZuH3jmQcGce9DOoYkQ,23234
108
101
  model_compression_toolkit/keras/mixed_precision/__init__.py,sha256=vXN-Q5V_3byIScc1j2ePqHBwR60knX1Vy6-Oh-Ke5sk,698
109
- model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py,sha256=aBoWVKmfqRcCpJKLGKASndZe7aJS1X14PLBkXLLTnAg,11477
102
+ model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py,sha256=1SnCo59VmcJVwb9gJD5poJrObytcltznGjf9sWniqZw,11399
110
103
  model_compression_toolkit/keras/quantizer/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
111
104
  model_compression_toolkit/keras/quantizer/base_quantizer.py,sha256=WoD4UrgXcCjZSXPnOX9slxJLrYDYN7go6nKs-bWjfuo,1735
112
105
  model_compression_toolkit/keras/quantizer/fake_quant_builder.py,sha256=jtGyNrTGFt6PFoPG9eQA56eIKUDLsulDAsde7o0mx2A,3342
@@ -125,7 +118,7 @@ model_compression_toolkit/keras/quantizer/mixed_precision/selective_quantizer.py
125
118
  model_compression_toolkit/keras/quantizer/mixed_precision/selective_weights_quantize_config.py,sha256=zGiaXcUmbiNRIb2Q9E6BocyCL4cyQjuDpfguC9sFVf8,7521
126
119
  model_compression_toolkit/keras/reader/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
127
120
  model_compression_toolkit/keras/reader/common.py,sha256=KCsvLFOeFltb9NBqQ2nU0MHYVPDBCYDBKF3oeoaangg,2417
128
- model_compression_toolkit/keras/reader/connectivity_handler.py,sha256=Pj4nmZl0C2K7iIbaC4PYS5Uvd82flvA9ctbI6DYYB7E,11115
121
+ model_compression_toolkit/keras/reader/connectivity_handler.py,sha256=-AYoapd40V25yt1h9MqDzXUDRVUOYho_h6F-U5zDXoY,11139
129
122
  model_compression_toolkit/keras/reader/node_builder.py,sha256=N9KSkY7AaZeq0-Zxi5MhKdCsTW2qmFp1yjqk5aPVPpA,3420
130
123
  model_compression_toolkit/keras/reader/reader.py,sha256=0i9GT3rTWDDK5fKbtt6PTqFpM59obk8mK0SDXVbycn4,7939
131
124
  model_compression_toolkit/keras/reader/nested_model/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
@@ -134,9 +127,9 @@ model_compression_toolkit/keras/reader/nested_model/nested_model_handler.py,sha2
134
127
  model_compression_toolkit/keras/reader/nested_model/nodes_merger.py,sha256=iI63gMLuE4_C-8aKSOLT3xRowViI3Je8aN8r5A37vy0,2167
135
128
  model_compression_toolkit/keras/reader/nested_model/outputs_merger.py,sha256=bTzVEtXk-7zrtGN7_NZSAKSuCeBWTXd_sPe7RFYgHUQ,2381
136
129
  model_compression_toolkit/keras/visualization/__init__.py,sha256=fpP2DhMfJXbHLoq20ziy48ymPyevl5N7P2U1fzlYRSY,699
137
- model_compression_toolkit/keras/visualization/nn_visualizer.py,sha256=ya6r3l_O5i6RJCYMean8uoKiv8BXOMK9298nI-yBCe4,4789
138
- mct_nightly-0.0.0.dist-info/LICENSE,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
139
- mct_nightly-0.0.0.dist-info/METADATA,sha256=W_28CTJ-lAtI2LQUNPasjNH9w0SkudUidwmZzmTYbhw,5879
140
- mct_nightly-0.0.0.dist-info/WHEEL,sha256=ewwEueio1C2XeHTvT17n8dZUJgOvyCWCt0WVNLClP9o,92
141
- mct_nightly-0.0.0.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
142
- mct_nightly-0.0.0.dist-info/RECORD,,
130
+ model_compression_toolkit/keras/visualization/nn_visualizer.py,sha256=AgrgGRHmV8M_NtFH5JzqmjAfBW2Q6PElHWwAq9WRSRU,4726
131
+ mct_nightly-1.1.0.1122021.post3325.dist-info/LICENSE,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
132
+ mct_nightly-1.1.0.1122021.post3325.dist-info/METADATA,sha256=_IHYbd9ovG8mCilk8viHqm1sTN9_kYRduu9amEN16XY,5950
133
+ mct_nightly-1.1.0.1122021.post3325.dist-info/WHEEL,sha256=ewwEueio1C2XeHTvT17n8dZUJgOvyCWCt0WVNLClP9o,92
134
+ mct_nightly-1.1.0.1122021.post3325.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
135
+ mct_nightly-1.1.0.1122021.post3325.dist-info/RECORD,,
@@ -13,8 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
-
17
- from model_compression_toolkit.common.gptq.gptq_config import GradientPTQConfig
16
+ from model_compression_toolkit.keras.quantization_facade import keras_post_training_quantization, keras_post_training_quantization_mixed_precision
17
+ from model_compression_toolkit.keras.gradient_ptq.training_wrapper import GradientPTQConfig
18
18
  from model_compression_toolkit.common.quantization import quantization_config
19
19
  from model_compression_toolkit.common.mixed_precision import mixed_precision_quantization_config
20
20
  from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig, \
@@ -27,8 +27,4 @@ from model_compression_toolkit.common.framework_info import FrameworkInfo
27
27
  from model_compression_toolkit.common.defaultdict import DefaultDict
28
28
  from model_compression_toolkit.common import network_editors as network_editor
29
29
 
30
- from model_compression_toolkit.keras.quantization_facade import keras_post_training_quantization, \
31
- keras_post_training_quantization_mixed_precision
32
-
33
-
34
30
  __version__ = "1.1.0"
@@ -20,6 +20,7 @@ from typing import Any
20
20
  from model_compression_toolkit.common.graph.base_graph import Graph
21
21
  from model_compression_toolkit.common.matchers.base_matcher import BaseMatcher
22
22
 
23
+
23
24
  class BaseSubstitution(ABC):
24
25
  """
25
26
  Base class for all substitution classes.
@@ -14,13 +14,13 @@
14
14
  # ==============================================================================
15
15
  import copy
16
16
 
17
- from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
18
17
  from model_compression_toolkit.common.framework_info import FrameworkInfo
19
18
  from model_compression_toolkit.common import Graph, Node
19
+ from model_compression_toolkit.keras.constants import BIAS, USE_BIAS
20
+
20
21
 
21
22
  def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
22
- fw_info: FrameworkInfo,
23
- fw_impl: FrameworkImplementation) -> Graph:
23
+ fw_info: FrameworkInfo) -> Graph:
24
24
  """
25
25
  Get a graph, where each node has a final weights quantization configuration (with a bias
26
26
  correction term in it), and apply the bias correction for each node in the graph.
@@ -28,7 +28,6 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
28
28
  Args:
29
29
  graph_to_apply_bias_correction: Graph to apply bias correction to.
30
30
  fw_info: Framework information (e.g., operators whose weights should be quantized).
31
- fw_impl: FrameworkImplementation object with a specific framework methods implementation.
32
31
 
33
32
  Returns:
34
33
  Graph with bias correction applied to its nodes.
@@ -40,28 +39,26 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
40
39
  # If a kernel was quantized and weights bias correction is enabled in n.quantization_cfg,
41
40
  # a bias correction term was calculated during model preparation, and is used now in the node's bias term.
42
41
  if n.final_weights_quantization_cfg.weights_bias_correction:
43
- _apply_bias_correction_to_node(n, fw_impl)
42
+ _apply_bias_correction_to_node(n)
44
43
  return graph
45
44
 
46
45
 
47
- def _apply_bias_correction_to_node(node:Node,
48
- fw_impl: FrameworkImplementation):
46
+ def _apply_bias_correction_to_node(node:Node):
49
47
  """
50
48
  Set new bias to node using the bias-correction term that is stored in the
51
49
  final weights quantization configuration.
52
50
 
53
51
  Args:
54
52
  node: Node to set its corrected bias after bias-correction.
55
- fw_impl: FrameworkImplementation object with a specific framework methods implementation.
56
53
 
57
54
  """
58
55
  correction = node.final_weights_quantization_cfg.bias_corrected
59
56
 
60
- bias = node.get_weights_by_keys(fw_impl.constants.BIAS) # get original bias from node's weights
57
+ bias = node.get_weights_by_keys(BIAS) # get original bias from node's weights
61
58
 
62
59
  if bias is not None: # If the layer has bias, we subtract the correction from original bias
63
- node.set_weights_by_keys(fw_impl.constants.BIAS, node.get_weights_by_keys(fw_impl.constants.BIAS) - correction)
60
+ node.set_weights_by_keys(BIAS, node.get_weights_by_keys(BIAS) - correction)
64
61
 
65
62
  else: # If the layer has no bias, we consider it as if it has one whose value is 0.
66
- node.set_weights_by_keys(fw_impl.constants.BIAS, - correction)
67
- node.framework_attr[fw_impl.constants.USE_BIAS] = True # Mark the use_bias attribute of the node.
63
+ node.set_weights_by_keys(BIAS, - correction)
64
+ node.framework_attr[USE_BIAS] = True # Mark the use_bias attribute of the node.
@@ -12,22 +12,20 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
-
16
15
  import copy
17
16
  from typing import Any
18
17
 
19
18
  import numpy as np
20
19
 
21
- from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
22
20
  from model_compression_toolkit.common.framework_info import FrameworkInfo
23
21
  from model_compression_toolkit.common import Node, Logger, Graph
24
22
  from model_compression_toolkit.common.quantization.quantize_node import get_quantized_kernel_by_weights_qc
25
23
  from model_compression_toolkit.common.statistics_collector import BaseStatsContainer
24
+ from model_compression_toolkit.keras.constants import KERNEL
26
25
 
27
26
 
28
27
  def compute_bias_correction_of_graph(graph_co_compute_bias: Graph,
29
- fw_info: FrameworkInfo,
30
- fw_impl: FrameworkImplementation) -> Graph:
28
+ fw_info: FrameworkInfo) -> Graph:
31
29
  """
32
30
  For each node in a graph, and for each candidate weights quantization configuration,
33
31
  compute the bias-correction term, and store it in the candidate weights quantization configuration.
@@ -36,7 +34,6 @@ def compute_bias_correction_of_graph(graph_co_compute_bias: Graph,
36
34
  graph_co_compute_bias: Graph with nodes to compute the bias correction for
37
35
  each node's weights quantization configuration candidates.
38
36
  fw_info: Framework info, such as lists of nodes whose kernels should be quantized.
39
- fw_impl: FrameworkImplementation object with a specific framework methods implementation.
40
37
 
41
38
  Returns:
42
39
  Graph with bias correction for each weights quantization configuration candidate
@@ -46,17 +43,13 @@ def compute_bias_correction_of_graph(graph_co_compute_bias: Graph,
46
43
  graph = copy.deepcopy(graph_co_compute_bias)
47
44
  for n in graph.nodes:
48
45
  if fw_info.in_kernel_ops(n):
49
- _compute_bias_correction_per_candidate_qc(n,
50
- fw_info,
51
- graph.get_in_stats_collector(n),
52
- fw_impl=fw_impl)
46
+ _compute_bias_correction_per_candidate_qc(n, fw_info, graph.get_in_stats_collector(n))
53
47
  return graph
54
48
 
55
49
 
56
50
  def _compute_bias_correction_per_candidate_qc(node: Node,
57
51
  fw_info: FrameworkInfo,
58
- node_in_stats_collector: BaseStatsContainer,
59
- fw_impl: FrameworkImplementation):
52
+ node_in_stats_collector: BaseStatsContainer):
60
53
  """
61
54
  For each candidate weights quantization configuration of a given node,
62
55
  compute the bias-correction term, and store it in the candidate weights quantization configuration.
@@ -65,7 +58,6 @@ def _compute_bias_correction_per_candidate_qc(node: Node,
65
58
  node: Node to compute the bias correction for its different candidates.
66
59
  fw_info: Framework info like lists of nodes their kernel should quantized.
67
60
  node_in_stats_collector: Statistics collector of the node for the mean per-channel.
68
- fw_impl: FrameworkImplementation object with a specific framework methods implementation.
69
61
 
70
62
  """
71
63
 
@@ -74,8 +66,7 @@ def _compute_bias_correction_per_candidate_qc(node: Node,
74
66
  if fw_info.in_kernel_ops(node) and weights_qc.enable_weights_quantization:
75
67
  quantized_kernel, io_channels_axes = get_quantized_kernel_by_weights_qc(fw_info,
76
68
  node,
77
- weights_qc,
78
- fw_impl=fw_impl)
69
+ weights_qc)
79
70
 
80
71
  # If a kernel was quantized and weights bias correction is enabled in n.quantization_cfg,
81
72
  # a bias correction term is being calculated and used in the node's bias term.
@@ -84,8 +75,7 @@ def _compute_bias_correction_per_candidate_qc(node: Node,
84
75
  node,
85
76
  node_in_stats_collector,
86
77
  io_channels_axes[1],
87
- quantized_kernel,
88
- fw_impl=fw_impl)
78
+ quantized_kernel)
89
79
 
90
80
  # Store the correction term to use it later,
91
81
  weights_qc.bias_corrected = bias_correction_term
@@ -135,8 +125,7 @@ def _get_bias_correction_term_of_node(input_channels_axis: int,
135
125
  n: Node,
136
126
  node_in_stats_collector: BaseStatsContainer,
137
127
  output_channels_axis: int,
138
- quantized_kernel: np.ndarray,
139
- fw_impl: FrameworkImplementation):
128
+ quantized_kernel: np.ndarray):
140
129
  """
141
130
  Get the bias correction term for a node, using a quantized kernel (which can be quantized
142
131
  using any possible bit width)
@@ -147,8 +136,6 @@ def _get_bias_correction_term_of_node(input_channels_axis: int,
147
136
  node_in_stats_collector: Input statistics collector of the node.
148
137
  output_channels_axis: Index of output channels of the kernel.
149
138
  quantized_kernel: Quantized kernel of the node.
150
- fw_impl: FrameworkImplementation object with a specific framework methods implementation.
151
-
152
139
 
153
140
  Returns:
154
141
  Bias-correction term to subtract from the current node's bias.
@@ -163,7 +150,7 @@ def _get_bias_correction_term_of_node(input_channels_axis: int,
163
150
  f'Unknown input channel axis for node named: {n.name},'
164
151
  f' please update channel mapping function')
165
152
  # Compute the bias correction term.
166
- correction = _compute_bias_correction(n.get_weights_by_keys(fw_impl.constants.KERNEL),
153
+ correction = _compute_bias_correction(n.get_weights_by_keys(KERNEL),
167
154
  quantized_kernel,
168
155
  node_in_stats_collector,
169
156
  output_channels_axis,
@@ -74,7 +74,7 @@ class HistogramCollector(BaseCollector):
74
74
  merged_histogram_min = np.min(bins_stack)
75
75
  merged_histogram_max = np.max(bins_stack)
76
76
  merged_bin_width = np.min(bins_stack[:, 1] - bins_stack[:, 0])
77
- merged_histogram_bins = np.arange(merged_histogram_min, merged_histogram_max+merged_bin_width, merged_bin_width)
77
+ merged_histogram_bins = np.arange(merged_histogram_min, merged_histogram_max, merged_bin_width)
78
78
 
79
79
  merged_histogram_counts = None
80
80
  for histogram in self.__histogram_per_iteration: # Iterate all collected histograms and merge them
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from collections import namedtuple
15
+
16
16
 
17
17
  from copy import copy, deepcopy
18
18
  from typing import List
@@ -29,11 +29,9 @@ from model_compression_toolkit.common.graph.node import Node
29
29
  from model_compression_toolkit.common.statistics_collector import BaseStatsContainer
30
30
  from model_compression_toolkit.common.statistics_collector import scale_statistics, shift_statistics
31
31
  from model_compression_toolkit.common.user_info import UserInformation
32
+ from model_compression_toolkit.keras.reader.connectivity_handler import OutTensor
32
33
  from model_compression_toolkit.common.logger import Logger
33
34
 
34
- OutTensor = namedtuple('OutTensor', 'node node_out_index')
35
-
36
-
37
35
  class Graph(nx.MultiDiGraph, GraphSearches):
38
36
  """
39
37
  Base graph representing a model to be optimized.
@@ -19,7 +19,9 @@ from typing import Any, List
19
19
  from tensorflow.keras.layers import Layer
20
20
 
21
21
  from model_compression_toolkit.common.graph.node import Node
22
- from model_compression_toolkit.common.matchers import node_matcher, walk_matcher, edge_matcher
22
+ from model_compression_toolkit.common.matchers import edge_matcher
23
+ from model_compression_toolkit.common.matchers import node_matcher
24
+ from model_compression_toolkit.common.matchers import walk_matcher
23
25
 
24
26
 
25
27
  class NodeOperationMatcher(node_matcher.BaseNodeMatcher):
@@ -17,7 +17,9 @@ from abc import ABC
17
17
  from typing import List, Any
18
18
 
19
19
  from model_compression_toolkit.common.graph.node import Node
20
- from model_compression_toolkit.common.matchers import node_matcher, base_graph_filter, edge_matcher
20
+ from model_compression_toolkit.common.matchers import base_graph_filter
21
+ from model_compression_toolkit.common.matchers import edge_matcher
22
+ from model_compression_toolkit.common.matchers import node_matcher
21
23
  from model_compression_toolkit.common.matchers.walk_matcher import WalkMatcherList
22
24
 
23
25
 
@@ -17,7 +17,6 @@ import copy
17
17
 
18
18
  from typing import Any, List
19
19
 
20
- from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
21
20
  from model_compression_toolkit.common import Graph, Node
22
21
  from model_compression_toolkit.common.framework_info import FrameworkInfo
23
22
  from model_compression_toolkit.common.logger import Logger
@@ -26,7 +25,7 @@ from model_compression_toolkit.common.mixed_precision.mixed_precision_quantizati
26
25
 
27
26
 
28
27
 
29
- def set_bit_widths(quant_config: QuantizationConfig,
28
+ def set_bit_widths(quant_config: MixedPrecisionQuantizationConfig,
30
29
  graph_to_set_bit_widths: Graph,
31
30
  fw_info: FrameworkInfo = None,
32
31
  bit_widths_config: List[int] = None) -> Graph:
@@ -16,6 +16,7 @@
16
16
  from typing import Any
17
17
  from model_compression_toolkit.common.matchers.node_matcher import BaseNodeMatcher
18
18
 
19
+
19
20
  class NodeTypeFilter(BaseNodeMatcher):
20
21
  """
21
22
  Class NodeNameFilter to check if a node is of a specific type.
@@ -20,7 +20,7 @@ from model_compression_toolkit.common.constants import MIN_THRESHOLD
20
20
  from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_search import qparams_selection_tensor_search, qparams_selection_histogram_search
21
21
  from model_compression_toolkit.common.similarity_analyzer import compute_lp_norm
22
22
 
23
- from model_compression_toolkit.common.constants import THRESHOLD
23
+ from model_compression_toolkit.keras.constants import THRESHOLD
24
24
 
25
25
 
26
26
  def lp_selection_tensor(tensor_data: np.ndarray,
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
  from typing import List
16
16
 
17
- from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
18
17
  from model_compression_toolkit.common.framework_info import FrameworkInfo
19
18
  from model_compression_toolkit.common import Graph, Node, Logger
20
19
  from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_activations_computation \
@@ -22,13 +21,13 @@ from model_compression_toolkit.common.quantization.quantization_params_generatio
22
21
  get_activations_qparams
23
22
  from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_weights_computation import \
24
23
  get_weights_qparams, get_channels_axis
24
+ from model_compression_toolkit.keras.constants import KERNEL
25
25
 
26
26
 
27
27
  def calculate_quantization_params(graph: Graph,
28
28
  fw_info: FrameworkInfo,
29
29
  nodes: List[Node] = [],
30
- specific_nodes: bool = False,
31
- fw_impl: FrameworkImplementation = None):
30
+ specific_nodes: bool = False):
32
31
  """
33
32
  For a graph, go over its nodes, compute quantization params (for both weights and activations according
34
33
  to the given framework info), and create and attach a NodeQuantizationConfig to each node (containing the
@@ -43,7 +42,6 @@ def calculate_quantization_params(graph: Graph,
43
42
  graph: Graph to compute its nodes' thresholds.
44
43
  nodes: List of nodes to compute their thresholds instead of computing it for all nodes in the graph.
45
44
  specific_nodes: Flag to compute thresholds for only specific nodes.
46
- fw_impl: FrameworkImplementation with specific framework implementations.
47
45
 
48
46
  """
49
47
 
@@ -59,7 +57,7 @@ def calculate_quantization_params(graph: Graph,
59
57
 
60
58
  for candidtae_qc in n.candidates_weights_quantization_cfg:
61
59
  output_channels_axis, _ = get_channels_axis(candidtae_qc, fw_info, n.layer_class)
62
- weights_params = get_weights_qparams(n.get_weights_by_keys(fw_impl.constants.KERNEL),
60
+ weights_params = get_weights_qparams(n.get_weights_by_keys(KERNEL),
63
61
  candidtae_qc,
64
62
  output_channels_axis)
65
63
 
@@ -16,15 +16,14 @@
16
16
  import copy
17
17
 
18
18
  from model_compression_toolkit import common
19
- from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
20
19
  from model_compression_toolkit.common.framework_info import FrameworkInfo
21
20
  from model_compression_toolkit.common.graph.base_graph import Graph
22
21
  from model_compression_toolkit.common.quantization.quantize_node import get_quantized_kernel_by_weights_qc
22
+ from model_compression_toolkit.keras.constants import KERNEL
23
23
 
24
24
 
25
25
  def quantize_graph_weights(graph_to_quantize: Graph,
26
- fw_info: FrameworkInfo,
27
- fw_impl: FrameworkImplementation) -> Graph:
26
+ fw_info: FrameworkInfo) -> Graph:
28
27
  """
29
28
  Get a graph representing a model, and quantize its nodes' weights.
30
29
  Each node is quantized according to the passed framework info and quantization configuration.
@@ -34,7 +33,6 @@ def quantize_graph_weights(graph_to_quantize: Graph,
34
33
  Args:
35
34
  graph_to_quantize: Graph to quantize its nodes.
36
35
  fw_info: Framework information needed for quantizing the graph's nodes' weights and activations.
37
- fw_impl: FrameworkImplementation with specific framework implementations.
38
36
 
39
37
  """
40
38
  graph = copy.deepcopy(graph_to_quantize)
@@ -45,14 +43,13 @@ def quantize_graph_weights(graph_to_quantize: Graph,
45
43
  if fw_info.in_kernel_ops(n) and n.final_weights_quantization_cfg.enable_weights_quantization:
46
44
  quantized_kernel, io_channels_axes = get_quantized_kernel_by_weights_qc(fw_info,
47
45
  n,
48
- n.final_weights_quantization_cfg,
49
- fw_impl=fw_impl)
46
+ n.final_weights_quantization_cfg)
50
47
 
51
48
  common.Logger.debug(
52
49
  f'Node name: {n.name} has the following quantization params: '
53
50
  f'{str(n.final_weights_quantization_cfg.weights_quantization_params)}')
54
51
 
55
52
  # Set the kernel node to be the quantized kernel.
56
- n.set_weights_by_keys(fw_impl.constants.KERNEL, quantized_kernel)
53
+ n.set_weights_by_keys(KERNEL, quantized_kernel)
57
54
 
58
55
  return graph
@@ -18,18 +18,17 @@ import copy
18
18
 
19
19
  from model_compression_toolkit import common
20
20
  from model_compression_toolkit.common import Logger
21
- from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
22
21
  from model_compression_toolkit.common.graph.node import Node
23
22
  from model_compression_toolkit.common.framework_info import FrameworkInfo
24
23
  from model_compression_toolkit.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
25
24
  from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_weights_computation import \
26
25
  get_channels_axis
26
+ from model_compression_toolkit.keras.constants import KERNEL
27
27
 
28
28
 
29
29
  def get_quantized_kernel_by_weights_qc(fw_info:FrameworkInfo,
30
30
  n:Node,
31
- weights_qc: NodeWeightsQuantizationConfig,
32
- fw_impl: FrameworkImplementation):
31
+ weights_qc: NodeWeightsQuantizationConfig):
33
32
  """
34
33
  For a node and a weights quantization configuration, compute
35
34
  the quantized kernel of the node and return it and the input/output channels indices.
@@ -38,7 +37,6 @@ def get_quantized_kernel_by_weights_qc(fw_info:FrameworkInfo,
38
37
  fw_info: A FrameworkInfo object Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
39
38
  n: Node to quantize its kernel.
40
39
  weights_qc: Weight quantization configuration to use for the quantization.
41
- fw_impl: FrameworkImplementation with specific framework implementations.
42
40
 
43
41
  Returns:
44
42
  A quantized kernel of the node using a weights quantization configuration.
@@ -55,7 +53,7 @@ def get_quantized_kernel_by_weights_qc(fw_info:FrameworkInfo,
55
53
  n.layer_class)
56
54
 
57
55
  Logger.debug(f'quantizing {n.name} with {weights_qc.weights_n_bits} bits')
58
- quantized_kernel = weights_qc.weights_quantization_fn(n.get_weights_by_keys(fw_impl.constants.KERNEL),
56
+ quantized_kernel = weights_qc.weights_quantization_fn(n.get_weights_by_keys(KERNEL),
59
57
  n_bits=weights_qc.weights_n_bits,
60
58
  signed=True,
61
59
  quantization_params=weights_qc.weights_quantization_params,
@@ -13,3 +13,5 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+
17
+ from model_compression_toolkit.keras import quantizer