mct-nightly 0.0.0__py3-none-any.whl → 1.1.0.01122021-003325__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/METADATA +3 -2
  2. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/RECORD +31 -38
  3. model_compression_toolkit/__init__.py +2 -6
  4. model_compression_toolkit/common/base_substitutions.py +1 -0
  5. model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +9 -12
  6. model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +8 -21
  7. model_compression_toolkit/common/collectors/histogram_collector.py +1 -1
  8. model_compression_toolkit/common/graph/base_graph.py +2 -4
  9. model_compression_toolkit/common/graph/graph_matchers.py +3 -1
  10. model_compression_toolkit/common/graph/graph_searches.py +3 -1
  11. model_compression_toolkit/common/mixed_precision/bit_width_setter.py +1 -2
  12. model_compression_toolkit/common/network_editors/node_filters.py +1 -0
  13. model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py +1 -1
  14. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +3 -5
  15. model_compression_toolkit/common/quantization/quantize_graph_weights.py +4 -7
  16. model_compression_toolkit/common/quantization/quantize_node.py +3 -5
  17. model_compression_toolkit/keras/__init__.py +2 -0
  18. model_compression_toolkit/keras/back2framework/model_builder.py +24 -1
  19. model_compression_toolkit/{common → keras/back2framework}/model_collector.py +9 -18
  20. model_compression_toolkit/keras/default_framework_info.py +0 -1
  21. model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +57 -10
  22. model_compression_toolkit/keras/graph_substitutions/substituter.py +171 -0
  23. model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +26 -6
  24. model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +12 -5
  25. model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +3 -4
  26. model_compression_toolkit/keras/quantization_facade.py +524 -188
  27. model_compression_toolkit/keras/reader/connectivity_handler.py +4 -1
  28. model_compression_toolkit/keras/visualization/nn_visualizer.py +1 -2
  29. model_compression_toolkit/common/framework_implementation.py +0 -239
  30. model_compression_toolkit/common/gptq/__init__.py +0 -14
  31. model_compression_toolkit/common/gptq/gptq_config.py +0 -65
  32. model_compression_toolkit/common/model_builder_mode.py +0 -34
  33. model_compression_toolkit/common/post_training_quantization.py +0 -459
  34. model_compression_toolkit/common/substitutions/__init__.py +0 -14
  35. model_compression_toolkit/common/substitutions/apply_substitutions.py +0 -40
  36. model_compression_toolkit/keras/keras_implementation.py +0 -256
  37. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/LICENSE +0 -0
  38. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/WHEEL +0 -0
  39. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/top_level.txt +0 -0
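The bulk of this release is a refactor of the Keras quantization facade: the framework-agnostic driver in common/post_training_quantization.py (along with the common/gptq and common/substitutions packages and keras/keras_implementation.py) is removed, and its orchestration logic is inlined into model_compression_toolkit/keras/quantization_facade.py, with GradientPTQConfig now exported from model_compression_toolkit.keras.gradient_ptq.training_wrapper. The public entry points keep their signatures, except that the mixed-precision facade gains an optional bit_widths_config argument. A minimal usage sketch, assembled from the docstring examples visible in the diff below (the model and the random dataset generator are placeholders, not part of this diff):

    import numpy as np
    from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2

    import model_compression_toolkit as mct

    model = MobileNetV2()

    def repr_datagen():
        # One random calibration batch per call; shape matches MobileNetV2 input.
        return [np.random.random((1, 224, 224, 3))]

    # Plain PTQ with power-of-two thresholds (entry point unchanged by this release).
    quantized_model, quantization_info = mct.keras_post_training_quantization(
        model, repr_datagen, n_iter=10)

    # Mixed-precision PTQ: each layer's weights may use 2, 4, or 8 bits. When the
    # new bit_widths_config argument is left as None (default), an ILP search
    # driven by target_kpi picks the per-layer widths.
    config = mct.MixedPrecisionQuantizationConfig(weights_n_bits=[4, 2, 8])
    kpi = mct.KPI(model.count_params() * 0.75)  # ~0.75x the 8-bit model size, in bytes
    quantized_model, quantization_info = mct.keras_post_training_quantization_mixed_precision(
        model, repr_datagen, n_iter=10, quant_config=config, target_kpi=kpi)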
model_compression_toolkit/keras/quantization_facade.py
@@ -13,198 +13,534 @@
  # limitations under the License.
  # ==============================================================================

- from typing import Callable, List
+
+ import copy
+ import os
+ from functools import partial
+ from typing import Callable, List, Tuple
+
+ from tensorflow.keras.models import Model
+ from tqdm import tqdm

  from model_compression_toolkit import common
- from model_compression_toolkit.common import Logger
- from model_compression_toolkit.common.gptq.gptq_config import GradientPTQConfig
  from model_compression_toolkit.common.mixed_precision.kpi import KPI
- from model_compression_toolkit.common.framework_info import FrameworkInfo
+ from model_compression_toolkit.common import FrameworkInfo
+ from model_compression_toolkit.common.constants import NUM_SAMPLES_CS_TENSORBOARD
+ from model_compression_toolkit.common.graph.base_graph import Graph
+ from model_compression_toolkit.common.mixed_precision.bit_width_setter import set_bit_widths
+
+ from model_compression_toolkit.common.mixed_precision.mixed_precision_search_facade import search_bit_width
  from model_compression_toolkit.common.network_editors.actions import EditRule
+ from model_compression_toolkit.common.network_editors.edit_network import edit_network_graph
  from model_compression_toolkit.common.mixed_precision.mixed_precision_quantization_config import \
      MixedPrecisionQuantizationConfig, DEFAULT_MIXEDPRECISION_CONFIG
- from model_compression_toolkit.common.post_training_quantization import post_training_quantization
- from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
- from model_compression_toolkit.common.quantization.quantization_config import DEFAULTCONFIG
-
- import importlib
-
- if importlib.util.find_spec("tensorflow") is not None \
-         and importlib.util.find_spec("tensorflow_model_optimization") is not None:
-     from model_compression_toolkit.keras.default_framework_info import DEFAULT_KERAS_INFO
-     from model_compression_toolkit.keras.keras_implementation import KerasImplementation
-     from tensorflow.keras.models import Model
-
-     def keras_post_training_quantization(in_model: Model,
-                                          representative_data_gen: Callable,
-                                          n_iter: int = 500,
-                                          quant_config: QuantizationConfig = DEFAULTCONFIG,
-                                          fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
-                                          network_editor: List[EditRule] = [],
-                                          gptq_config: GradientPTQConfig = None,
-                                          analyze_similarity: bool = False):
-         """
-         Quantize a trained Keras model using post-training quantization. The model is quantized using a
-         symmetric constraint quantization thresholds (power of two).
-         The model is first optimized using several transformations (e.g. BatchNormalization folding to
-         preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
-         being collected for each layer's output (and input, depends on the quantization configuration).
-         Thresholds are then being calculated using the collected statistics and the model is quantized
-         (both coefficients and activations by default).
-         If a gptq configuration is passed, the quantized weights are optimized using gradient based post
-         training quantization by comparing points between the float and quantized models, and minimizing the observed loss.
-
-         Args:
-             in_model (Model): Keras model to quantize.
-             representative_data_gen (Callable): Dataset used for calibration.
-             n_iter (int): Number of calibration iterations to run.
-             quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized. `Default configuration. <https://github.com/sony/model_optimization/blob/21e21c95ca25a31874a5be7af9dd2dd5da8f3a10/model_compression_toolkit/common/quantization/quantization_config.py#L154>`_
-             fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/21e21c95ca25a31874a5be7af9dd2dd5da8f3a10/model_compression_toolkit/keras/default_framework_info.py#L113>`_
-             network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
-             gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
-             analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
-
-         Returns:
-             A quantized model and information the user may need to handle the quantized model.
-
-         Examples:
-             Import a Keras model:
-
-             >>> from tensorflow.keras.applications.mobilenet import MobileNet
-             >>> model = MobileNet()
-
-             Create a random dataset generator:
-
-             >>> import numpy as np
-             >>> def repr_datagen(): return [np.random.random((1,224,224,3))]
-
-             Import mct and pass the model with the representative dataset generator to get a quantized model:
-
-             >>> import model_compression_toolkit as mct
-             >>> quantized_model, quantization_info = mct.keras_post_training_quantization(model, repr_datagen)
-
-         """
-
-         return post_training_quantization(in_model,
-                                           representative_data_gen,
-                                           n_iter,
-                                           quant_config,
-                                           fw_info,
-                                           KerasImplementation(),
-                                           network_editor,
-                                           gptq_config,
-                                           analyze_similarity)
-
-
-     def keras_post_training_quantization_mixed_precision(in_model: Model,
-                                                          representative_data_gen: Callable,
-                                                          n_iter: int = 500,
-                                                          quant_config: MixedPrecisionQuantizationConfig = DEFAULT_MIXEDPRECISION_CONFIG,
-                                                          fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
-                                                          network_editor: List[EditRule] = [],
-                                                          gptq_config: GradientPTQConfig = None,
-                                                          analyze_similarity: bool = False,
-                                                          target_kpi: KPI = None):
-         """
-         Quantize a trained Keras model using post-training quantization. The model is quantized using a
-         symmetric constraint quantization thresholds (power of two).
-         The model is first optimized using several transformations (e.g. BatchNormalization folding to
-         preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
-         being collected for each layer's output (and input, depends on the quantization configuration).
-         For each possible bit width (per layer) a threshold is then being calculated using the collected
-         statistics. Then, using an ILP solver we find a mixed-precision configuration, and set a bit width
-         for each layer. The model is then quantized (both coefficients and activations by default).
-         In order to limit the maximal model's size, a target KPI can be passed after weights_memory
-         is set (in bytes).
-         For now, mixed precision is supported for weights only.
-         If a gptq configuration is passed, the quantized weights are optimized using gradient based post
-         training quantization by comparing points between the float and quantized models, and minimizing the observed loss.
-         Notice that this feature is experimental.
-
-         Args:
-             in_model (Model): Keras model to quantize.
-             representative_data_gen (Callable): Dataset used for calibration.
-             n_iter (int): Number of calibration iterations to run.
-             quant_config (MixedPrecisionQuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized.
-             fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/keras/default_framework_info.py#L100>`_
-             network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
-             gptq_config (GradientPTQConfig): Configuration for using GPTQ (e.g. optimizer).
-             analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
-             target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
-
-         Returns:
-             A quantized model and information the user may need to handle the quantized model.
-
-         Examples:
-             Import MCT:
-
-             >>> import model_compression_toolkit as mct
-
-             Import a Keras model:
-
-             >>> from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
-             >>> model = MobileNetV2()
-
-             Create a random dataset generator:
-
-             >>> import numpy as np
-             >>> def repr_datagen(): return [np.random.random((1,224,224,3))]
-
-             Create a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
-             Here, each layer can be quantized by 2, 4 or 8 bits:
-
-             >>> config = mct.MixedPrecisionQuantizationConfig(weights_n_bits=[4, 2, 8])
-
-             Create a KPI object to limit our returned model's size. Note that this value affects only coefficients that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value, while the bias will not):
-
-             >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
-
-             Pass the model, the representative dataset generator, the configuration and the target KPI to get a quantized model:
-
-             >>> quantized_model, quantization_info = mct.keras_post_training_quantization_mixed_precision(model, repr_datagen, n_iter=10, quant_config=config, target_kpi=kpi)
-
-             For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
-
-         """
-
-         if target_kpi is None:
-             common.Logger.warning("No KPI was passed. Using non mixed-precision compression process...")
-             # Before starting non-mixed-precision process, we need to set only single bit width, so we take the best
-             # option which is the maximal number of bits.
-             quant_config.weights_n_bits = [max(quant_config.weights_n_bits)]
-             return keras_post_training_quantization(in_model,
-                                                     representative_data_gen,
-                                                     n_iter,
-                                                     quant_config,
-                                                     fw_info,
-                                                     network_editor,
-                                                     gptq_config,
-                                                     analyze_similarity)
-
-         common.Logger.info("Using experimental mixed-precision quantization. "
-                            "If you encounter an issue please file a bug.")
-
-         return post_training_quantization(in_model,
-                                           representative_data_gen,
-                                           n_iter,
-                                           quant_config,
-                                           fw_info,
-                                           KerasImplementation(),
-                                           network_editor,
-                                           gptq_config,
-                                           analyze_similarity,
-                                           target_kpi)
-
- else:
-     # If tensorflow or tensorflow_model_optimization are not installed,
-     # we raise an exception when trying to use these functions.
-     def keras_post_training_quantization(*args, **kwargs):
-         Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
-                         'when using keras_post_training_quantization. '
-                         'Could not find Tensorflow package.')
-
-     def keras_post_training_quantization_mixed_precision(*args, **kwargs):
-         Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
-                         'when using keras_post_training_quantization_mixed_precision. '
-                         'Could not find Tensorflow package.')
+ from model_compression_toolkit.common.quantization.quantize_graph_weights import quantize_graph_weights
+ from model_compression_toolkit.common.bias_correction.compute_bias_correction_of_graph import compute_bias_correction_of_graph

+ from model_compression_toolkit.common.quantization.quantization_analyzer import analyzer_graph
+ from model_compression_toolkit.common.quantization.quantization_config import DEFAULTCONFIG
+ from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
+ from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_computation import \
+     calculate_quantization_params
+
+ from model_compression_toolkit.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
+ from model_compression_toolkit.common.user_info import UserInformation
+ from model_compression_toolkit.keras.back2framework.model_builder import model_builder, ModelBuilderMode
+ from model_compression_toolkit.keras.back2framework.model_collector import ModelCollector
+ from model_compression_toolkit.keras.default_framework_info import DEFAULT_KERAS_INFO
+ from model_compression_toolkit.keras.graph_substitutions.substituter import graph_marking_substitute
+ from model_compression_toolkit.keras.graph_substitutions.substituter import post_statistics_collection_substitute
+ from model_compression_toolkit.keras.graph_substitutions.substituter import pre_statistics_collection_substitute
+ from model_compression_toolkit.keras.gradient_ptq.training_wrapper import GradientPTQConfig
+ from model_compression_toolkit.keras.gradient_ptq.training_wrapper import gptq_training_wrapper
+ from model_compression_toolkit.keras.mixed_precision.sensitivity_evaluation import get_sensitivity_evaluation
+ from model_compression_toolkit.keras.reader.reader import model_reader
+ from model_compression_toolkit.keras.tensor_marking import get_node_stats_collector
+ from model_compression_toolkit.common.visualization.tensorboard_writer import TensorboardWriter
+ from model_compression_toolkit.common.bias_correction.apply_bias_correction_to_graph import apply_bias_correction_to_graph
+ from model_compression_toolkit.keras.visualization.nn_visualizer import KerasNNVisualizer
+
+
+ def _prepare_model_for_quantization(in_model: Model,
+                                     representative_data_gen: Callable,
+                                     network_editor: List[EditRule] = [],
+                                     n_iter: int = 500,
+                                     quant_config: QuantizationConfig = DEFAULTCONFIG,
+                                     fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
+                                     tb_w: TensorboardWriter = None) -> Graph:
+     """
+     Prepare a trained Keras model for post-training quantization. The model is prepared to be quantized using
+     symmetric constraint quantization thresholds (power of two).
+     The model is first read into a graph object and optimized using several transformations (e.g.
+     BatchNormalization folding to preceding layers). Then, using a given dataset, statistics (e.g. min/max,
+     histogram, etc.) are being collected for each layer's output (and input, depends on the quantization configuration).
+     Thresholds are then being calculated using the collected statistics. Finally, more transformations (based on
+     statistics) are applied to increase the model's performance.
+
+     Args:
+         in_model (Model): Keras model to optimize and prepare for quantization.
+         representative_data_gen (Callable): Dataset used for calibration.
+         network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to
+             change quantization settings of the filtered nodes.
+         n_iter (int): Number of calibration iterations to run.
+         quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
+             quantized.
+         fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g.,
+             kernel channels indices, groups of layers by how they should be quantized, etc.)
+         tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
+
+     Returns:
+         Graph object that represents the Keras model, contains thresholds, and is ready for quantization.
+     """
+
+     ######################################
+     # Represent model in a graph
+     ######################################
+     graph = model_reader(in_model)  # model reading
+
+     if tb_w is not None:
+         tb_w.add_graph(graph, 'initial_graph')
+
+     ######################################
+     # Graph substitution (pre statistics collection)
+     ######################################
+     transformed_graph = pre_statistics_collection_substitute(graph)
+
+     if tb_w is not None:
+         tb_w.add_graph(transformed_graph, 'pre_statistics_collection_substitutions')
+
+     ######################################
+     # Graph marking points
+     ######################################
+     transformed_graph = graph_marking_substitute(transformed_graph)
+
+     if tb_w is not None:
+         tb_w.add_graph(transformed_graph, 'after_graph_marking')
+
+     ######################################
+     # Graph analyzing (attaching statistics collectors)
+     ######################################
+     analyzer_graph(get_node_stats_collector,
+                    transformed_graph,
+                    fw_info,
+                    quant_config)  # Mark points for statistics collection
+
+     if tb_w is not None:
+         tb_w.add_graph(transformed_graph, 'after_analyzer_graph')
+
+     ######################################
+     # Statistic collection
+     ######################################
+     mi = ModelCollector(transformed_graph)
+     for _ in tqdm(range(n_iter)):
+         mi.infer(representative_data_gen())
+
+     ######################################
+     # Add quantization configurations
+     ######################################
+     transformed_graph = set_quantization_configuration_to_graph(transformed_graph, quant_config, fw_info)
+
+     ######################################
+     # Edit network according to user specific settings
+     ######################################
+     edit_network_graph(transformed_graph, fw_info, network_editor)
+
+     ######################################
+     # Calculate quantization params
+     ######################################
+     calculate_quantization_params(transformed_graph, fw_info)
+
+     if tb_w is not None:
+         tb_w.add_graph(transformed_graph, 'thresholds_selection')
+         tb_w.add_all_statistics(transformed_graph, 'thresholds_selection')
+
+     ######################################
+     # Graph substitution (post statistics collection)
+     ######################################
+     transformed_graph = post_statistics_collection_substitute(transformed_graph,
+                                                               quant_config,
+                                                               fw_info)
+
+     if tb_w is not None:
+         tb_w.add_graph(transformed_graph, 'post_statistics_collection_substitutions')
+         tb_w.add_all_statistics(transformed_graph, 'post_statistics_collection_substitutions')
+
+     ########################################################
+     # Compute bias correction to nodes' config candidates
+     ########################################################
+     tg_with_bias = compute_bias_correction_of_graph(transformed_graph, fw_info)
+
+     if tb_w is not None:
+         tb_w.add_graph(tg_with_bias, 'bias_correction_computation')
+
+     for n in tg_with_bias.nodes:
+         assert n.final_weights_quantization_cfg is None
+
+     return tg_with_bias
+
+
+ def keras_post_training_quantization(in_model: Model,
+                                      representative_data_gen: Callable,
+                                      n_iter: int = 500,
+                                      quant_config: QuantizationConfig = DEFAULTCONFIG,
+                                      fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
+                                      network_editor: List[EditRule] = [],
+                                      gptq_config: GradientPTQConfig = None,
+                                      analyze_similarity: bool = False):
+     """
+     Quantize a trained Keras model using post-training quantization. The model is quantized using a
+     symmetric constraint quantization thresholds (power of two).
+     The model is first optimized using several transformations (e.g. BatchNormalization folding to
+     preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
+     being collected for each layer's output (and input, depends on the quantization configuration).
+     Thresholds are then being calculated using the collected statistics and the model is quantized
+     (both coefficients and activations by default).
+     If a gptq configuration is passed, the quantized weights are optimized using gradient based post
+     training quantization by comparing points between the float and quantized models, and minimizing the observed loss.
+
+     Args:
+         in_model (Model): Keras model to quantize.
+         representative_data_gen (Callable): Dataset used for calibration.
+         n_iter (int): Number of calibration iterations to run.
+         quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized. `Default configuration. <https://github.com/sony/model_optimization/blob/21e21c95ca25a31874a5be7af9dd2dd5da8f3a10/model_compression_toolkit/common/quantization/quantization_config.py#L154>`_
+         fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/21e21c95ca25a31874a5be7af9dd2dd5da8f3a10/model_compression_toolkit/keras/default_framework_info.py#L113>`_
+         network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
+         gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
+         analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
+
+     Returns:
+         A quantized model and information the user may need to handle the quantized model.
+
+     Examples:
+         Import a Keras model:
+
+         >>> from tensorflow.keras.applications.mobilenet import MobileNet
+         >>> model = MobileNet()
+
+         Create a random dataset generator:
+
+         >>> import numpy as np
+         >>> def repr_datagen(): return [np.random.random((1,224,224,3))]
+
+         Import mct and pass the model with the representative dataset generator to get a quantized model:
+
+         >>> import model_compression_toolkit as mct
+         >>> quantized_model, quantization_info = mct.keras_post_training_quantization(model, repr_datagen)
+
+     """
+
+     if quant_config.weights_bias_correction and gptq_config is not None:
+         common.Logger.error('weights_bias_correction should be disabled in GPTQ mode')
+
+     tb_w = _init_tensorboard_writer()
+
+     tg = _prepare_model_for_quantization(in_model,
+                                          representative_data_gen,
+                                          network_editor,
+                                          n_iter,
+                                          quant_config,
+                                          fw_info,
+                                          tb_w)
+
+     ######################################
+     # Finalize bit widths
+     ######################################
+     tg = set_bit_widths(quant_config,
+                         tg,
+                         fw_info)
+
+     quantized_model, user_info = _quantize_fixed_bit_widths_graph(analyze_similarity,
+                                                                   fw_info,
+                                                                   gptq_config,
+                                                                   representative_data_gen,
+                                                                   tb_w,
+                                                                   tg)
+
+     return quantized_model, user_info
+
+
+ def keras_post_training_quantization_mixed_precision(in_model: Model,
+                                                      representative_data_gen: Callable,
+                                                      n_iter: int = 500,
+                                                      quant_config: MixedPrecisionQuantizationConfig = DEFAULT_MIXEDPRECISION_CONFIG,
+                                                      fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
+                                                      network_editor: List[EditRule] = [],
+                                                      gptq_config: GradientPTQConfig = None,
+                                                      bit_widths_config: List[int] = None,
+                                                      analyze_similarity: bool = False,
+                                                      target_kpi: KPI = None):
+     """
+     Quantize a trained Keras model using post-training quantization. The model is quantized using a
+     symmetric constraint quantization thresholds (power of two).
+     The model is first optimized using several transformations (e.g. BatchNormalization folding to
+     preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
+     being collected for each layer's output (and input, depends on the quantization configuration).
+     For each possible bit width (per layer) a threshold is then being calculated using the collected
+     statistics. Then, using an ILP solver we find a mixed-precision configuration, and set a bit width
+     for each layer. The model is then quantized (both coefficients and activations by default).
+     In order to limit the maximal model's size, a target KPI can be passed after weights_memory
+     is set (in bytes).
+     For now, mixed precision is supported for weights only.
+     If a gptq configuration is passed, the quantized weights are optimized using gradient based post
+     training quantization by comparing points between the float and quantized models, and minimizing the observed loss.
+     Notice that this feature is experimental.
+
+     Args:
+         in_model (Model): Keras model to quantize.
+         representative_data_gen (Callable): Dataset used for calibration.
+         n_iter (int): Number of calibration iterations to run.
+         quant_config (MixedPrecisionQuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized.
+         fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/keras/default_framework_info.py#L100>`_
+         network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
+         gptq_config (GradientPTQConfig): Configuration for using GPTQ (e.g. optimizer).
+         bit_widths_config (List[int]): Mixed-precision configuration to set bit widths for different layers.
+         analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
+         target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
+
+     Returns:
+         A quantized model and information the user may need to handle the quantized model.
+
+     Examples:
+         Import MCT:
+
+         >>> import model_compression_toolkit as mct
+
+         Import a Keras model:
+
+         >>> from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
+         >>> model = MobileNetV2()
+
+         Create a random dataset generator:
+
+         >>> import numpy as np
+         >>> def repr_datagen(): return [np.random.random((1,224,224,3))]
+
+         Create a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
+         Here, each layer can be quantized by 2, 4 or 8 bits:
+
+         >>> config = mct.MixedPrecisionQuantizationConfig(weights_n_bits=[4, 2, 8])
+
+         Create a KPI object to limit our returned model's size. Note that this value affects only coefficients that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value, while the bias will not):
+
+         >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
+
+         Pass the model, the representative dataset generator, the configuration and the target KPI to get a quantized model:
+
+         >>> quantized_model, quantization_info = mct.keras_post_training_quantization_mixed_precision(model, repr_datagen, n_iter=10, quant_config=config, target_kpi=kpi)
+
+         For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
+
+     """
+
+     if quant_config.weights_bias_correction and gptq_config is not None:
+         common.Logger.error('weights_bias_correction should be disabled in GPTQ mode')
+
+     common.Logger.info("Using experimental mixed-precision quantization. "
+                        "If you encounter an issue please file a bug.")
+
+     if target_kpi is None:
+         common.Logger.warning("No KPI was passed. Using non mixed-precision compression process...")
+         # Before starting non-mixed-precision process, we need to set only single bit width, so we take the best
+         # option which is the maximal number of bits.
+         quant_config.weights_n_bits = [max(quant_config.weights_n_bits)]
+         return keras_post_training_quantization(in_model,
+                                                 representative_data_gen,
+                                                 n_iter,
+                                                 quant_config,
+                                                 fw_info,
+                                                 network_editor,
+                                                 gptq_config,
+                                                 analyze_similarity)
+
+     tb_w = _init_tensorboard_writer()
+
+     tg = _prepare_model_for_quantization(in_model,
+                                          representative_data_gen,
+                                          network_editor,
+                                          n_iter,
+                                          quant_config,
+                                          fw_info,
+                                          tb_w)
+
+     ######################################
+     # Finalize bit widths
+     ######################################
+
+     if bit_widths_config is None:
+         bit_widths_config = search_bit_width(tg,
+                                              quant_config,
+                                              fw_info,
+                                              target_kpi,
+                                              partial(get_sensitivity_evaluation,
+                                                      representative_data_gen=representative_data_gen,
+                                                      fw_info=fw_info))
+
+     tg = set_bit_widths(quant_config,
+                         tg,
+                         fw_info,
+                         bit_widths_config)
+
+     quantized_model, user_info = _quantize_fixed_bit_widths_graph(analyze_similarity,
+                                                                   fw_info,
+                                                                   gptq_config,
+                                                                   representative_data_gen,
+                                                                   tb_w,
+                                                                   tg)
+     user_info.mixed_precision_cfg = bit_widths_config
+
+     return quantized_model, user_info
+
+
+ def _quantize_fixed_bit_widths_graph(analyze_similarity: bool,
+                                      fw_info: FrameworkInfo,
+                                      gptq_config: GradientPTQConfig,
+                                      representative_data_gen: Callable,
+                                      tb_w: TensorboardWriter,
+                                      tg: Graph) -> Tuple[Model, UserInformation]:
+     """
+     Quantize a graph that has final weights candidates quantization configurations.
+     Before we quantize the graph weights, we apply GPTQ to get an improved graph.
+
+     Args:
+         analyze_similarity: Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
+         fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.)
+         gptq_config: Configuration for using GPTQ (e.g. optimizer).
+         representative_data_gen: Dataset used for GPTQ fine tuning.
+         tb_w: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
+         tg: Graph to apply GPTQ and to quantize.
+
+     Returns:
+         A tuple of the quantized model and an object of UserInformation.
+
+     """
+
+
+     #############################################
+     # Gradient Based Post Training Quantization
+     #############################################
+     tg = _apply_gptq(gptq_config,
+                      representative_data_gen,
+                      tb_w,
+                      tg,
+                      fw_info)
+
+     tg_float = copy.deepcopy(tg)  # Copy graph before quantization (for similarity analyzer)
+     ######################################
+     # Model Quantization
+     ######################################
+     quantized_model, user_info = _quantize_model(fw_info,
+                                                  tb_w,
+                                                  tg)
+     if analyze_similarity:
+         _analyze_similarity(representative_data_gen,
+                             tb_w,
+                             tg,
+                             tg_float)
+
+     return quantized_model, user_info
+
+
+ def _init_tensorboard_writer() -> TensorboardWriter:
+     """
+
+     Returns: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
+
+     """
+     tb_w = None
+     if common.Logger.LOG_PATH is not None:
+         tb_log_dir = os.path.join(os.getcwd(), common.Logger.LOG_PATH, 'tensorboard_logs')
+         common.Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
+         tb_w = TensorboardWriter(tb_log_dir)
+     return tb_w
+
+
+ def _quantize_model(fw_info: FrameworkInfo,
+                     tb_w: TensorboardWriter,
+                     tg: Graph) -> Tuple[Model, UserInformation]:
+     """
+     Quantize graph's weights, and build a quantized Keras model from it.
+
+     Args:
+         fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
+         tb_w: TensorBoardWriter object to log events.
+         tg: A graph prepared for quantization.
+
+     Returns:
+         A quantized Keras model, and information the user may need to use the quantized model.
+     """
+
+     quantized_tg = quantize_graph_weights(tg, fw_info=fw_info)
+     if tb_w is not None:
+         tb_w.add_graph(quantized_tg, 'after_quantization')
+
+     quantized_graph_with_bias_correction = apply_bias_correction_to_graph(quantized_tg,
+                                                                           fw_info=fw_info)
+     if tb_w is not None:
+         tb_w.add_graph(quantized_graph_with_bias_correction, 'after_bias_correction')
+
+     ######################################
+     # Back2Framework
+     ######################################
+     quantized_model, user_info = model_builder(quantized_graph_with_bias_correction,
+                                                mode=ModelBuilderMode.QUANTIZED)
+
+     return quantized_model, user_info
+
+
+ def _analyze_similarity(representative_data_gen: Callable,
+                         tb_w: TensorboardWriter,
+                         tg: Graph,
+                         tg_float: Graph):
+     """
+     Plot the cosine similarity of different points on the graph between the float and quantized
+     graphs. Add them to the passed TensorboardWriter object and close all of the TensorBoard
+     writer's open files.
+
+     Args:
+         representative_data_gen: Dataset used for calibration.
+         tb_w: TensorBoardWriter object to log events.
+         tg: Graph of quantized model.
+         tg_float: Graph of float model.
+
+     """
+     if tb_w is not None:
+         visual = KerasNNVisualizer(tg_float, tg)
+         for i in range(NUM_SAMPLES_CS_TENSORBOARD):
+             figure = visual.plot_cs_graph(representative_data_gen())
+             tb_w.add_figure(figure, f'cosine_similarity_sample_{i}')
+         tb_w.close()
+
+
+ def _apply_gptq(gptq_config: GradientPTQConfig,
+                 representative_data_gen: Callable,
+                 tb_w: TensorboardWriter,
+                 tg: Graph,
+                 fw_info: FrameworkInfo) -> Graph:
+     """
+     Apply GPTQ to improve the accuracy of the quantized model.
+     Build two models from the graph: a teacher network (float model) and a student network (quantized model),
+     and use the dataset generator to pass images through the teacher and student networks to get intermediate
+     layers' outputs and maximize their similarity.
+
+     Args:
+         gptq_config: Configuration for using GPTQ (e.g. optimizer).
+         representative_data_gen: Dataset used for calibration.
+         tb_w: TensorBoardWriter object to log events.
+         tg: Graph of quantized model.
+         fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
+
+     Returns:
+         The graph after GPTQ training (the input graph, unchanged, if gptq_config is None).
+     """
+     if gptq_config is not None:
+         common.Logger.info("Using experimental Gradient Based PTQ: If you encounter an issue "
+                            "please file a bug. To disable it, do not pass a gptq configuration.")
+
+         tg = gptq_training_wrapper(tg,
+                                    representative_data_gen,
+                                    gptq_config,
+                                    fw_info)
+
+         if tb_w is not None:
+             tb_w.add_graph(tg, 'after_gptq')
+     return tg
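The TensorBoard hooks above (including the analyze_similarity path) only activate when _init_tensorboard_writer() finds common.Logger.LOG_PATH set. A minimal sketch of wiring that up, reusing the model and repr_datagen placeholders from the earlier example; assigning LOG_PATH directly simply mirrors the check in _init_tensorboard_writer() and may not be the intended public way to enable logging:

    import model_compression_toolkit as mct
    from model_compression_toolkit import common

    # _init_tensorboard_writer() creates a writer only when this attribute is set;
    # direct assignment here is an illustration, not necessarily the supported API.
    common.Logger.LOG_PATH = './mct_logs'

    quantized_model, info = mct.keras_post_training_quantization(
        model, repr_datagen, n_iter=10,
        analyze_similarity=True)  # adds cosine-similarity figures to TensorBoard

    # Then inspect the logs with: tensorboard --logdir ./mct_logs/tensorboard_logs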