mct-nightly 0.0.0__py3-none-any.whl → 1.1.0.01122021-003325__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/METADATA +3 -2
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/RECORD +31 -38
- model_compression_toolkit/__init__.py +2 -6
- model_compression_toolkit/common/base_substitutions.py +1 -0
- model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +9 -12
- model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +8 -21
- model_compression_toolkit/common/collectors/histogram_collector.py +1 -1
- model_compression_toolkit/common/graph/base_graph.py +2 -4
- model_compression_toolkit/common/graph/graph_matchers.py +3 -1
- model_compression_toolkit/common/graph/graph_searches.py +3 -1
- model_compression_toolkit/common/mixed_precision/bit_width_setter.py +1 -2
- model_compression_toolkit/common/network_editors/node_filters.py +1 -0
- model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py +1 -1
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +3 -5
- model_compression_toolkit/common/quantization/quantize_graph_weights.py +4 -7
- model_compression_toolkit/common/quantization/quantize_node.py +3 -5
- model_compression_toolkit/keras/__init__.py +2 -0
- model_compression_toolkit/keras/back2framework/model_builder.py +24 -1
- model_compression_toolkit/{common → keras/back2framework}/model_collector.py +9 -18
- model_compression_toolkit/keras/default_framework_info.py +0 -1
- model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +57 -10
- model_compression_toolkit/keras/graph_substitutions/substituter.py +171 -0
- model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +26 -6
- model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +12 -5
- model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +3 -4
- model_compression_toolkit/keras/quantization_facade.py +524 -188
- model_compression_toolkit/keras/reader/connectivity_handler.py +4 -1
- model_compression_toolkit/keras/visualization/nn_visualizer.py +1 -2
- model_compression_toolkit/common/framework_implementation.py +0 -239
- model_compression_toolkit/common/gptq/__init__.py +0 -14
- model_compression_toolkit/common/gptq/gptq_config.py +0 -65
- model_compression_toolkit/common/model_builder_mode.py +0 -34
- model_compression_toolkit/common/post_training_quantization.py +0 -459
- model_compression_toolkit/common/substitutions/__init__.py +0 -14
- model_compression_toolkit/common/substitutions/apply_substitutions.py +0 -40
- model_compression_toolkit/keras/keras_implementation.py +0 -256
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/LICENSE +0 -0
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/WHEEL +0 -0
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/top_level.txt +0 -0
|
@@ -1,459 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
import copy
|
|
18
|
-
import os
|
|
19
|
-
from functools import partial
|
|
20
|
-
from typing import Callable, List, Tuple, Any
|
|
21
|
-
from tqdm import tqdm
|
|
22
|
-
|
|
23
|
-
from model_compression_toolkit import common
|
|
24
|
-
from model_compression_toolkit.common.gptq.gptq_config import GradientPTQConfig
|
|
25
|
-
from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
|
|
26
|
-
from model_compression_toolkit.common.mixed_precision.kpi import KPI
|
|
27
|
-
from model_compression_toolkit.common import FrameworkInfo
|
|
28
|
-
from model_compression_toolkit.common.constants import NUM_SAMPLES_CS_TENSORBOARD
|
|
29
|
-
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
30
|
-
from model_compression_toolkit.common.mixed_precision.bit_width_setter import set_bit_widths
|
|
31
|
-
|
|
32
|
-
from model_compression_toolkit.common.mixed_precision.mixed_precision_search_facade import search_bit_width
|
|
33
|
-
from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
|
|
34
|
-
from model_compression_toolkit.common.network_editors.actions import EditRule
|
|
35
|
-
from model_compression_toolkit.common.network_editors.edit_network import edit_network_graph
|
|
36
|
-
from model_compression_toolkit.common.mixed_precision.mixed_precision_quantization_config import \
|
|
37
|
-
MixedPrecisionQuantizationConfig
|
|
38
|
-
from model_compression_toolkit.common.quantization.quantize_graph_weights import quantize_graph_weights
|
|
39
|
-
from model_compression_toolkit.common.bias_correction.compute_bias_correction_of_graph import compute_bias_correction_of_graph
|
|
40
|
-
|
|
41
|
-
from model_compression_toolkit.common.quantization.quantization_analyzer import analyzer_graph
|
|
42
|
-
from model_compression_toolkit.common.quantization.quantization_config import DEFAULTCONFIG
|
|
43
|
-
from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
|
|
44
|
-
from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_computation import \
|
|
45
|
-
calculate_quantization_params
|
|
46
|
-
|
|
47
|
-
from model_compression_toolkit.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
|
|
48
|
-
|
|
49
|
-
from model_compression_toolkit.common.substitutions.apply_substitutions import substitute
|
|
50
|
-
from model_compression_toolkit.common.user_info import UserInformation
|
|
51
|
-
from model_compression_toolkit.common.model_collector import ModelCollector
|
|
52
|
-
|
|
53
|
-
from model_compression_toolkit.common.visualization.tensorboard_writer import TensorboardWriter
|
|
54
|
-
from model_compression_toolkit.common.bias_correction.apply_bias_correction_to_graph import apply_bias_correction_to_graph
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def post_training_quantization(in_model: Any,
                               representative_data_gen: Callable,
                               n_iter: int,
                               quant_config: QuantizationConfig,
                               fw_info: FrameworkInfo,
                               fw_impl: FrameworkImplementation,
                               network_editor: List[EditRule] = None,
                               gptq_config: GradientPTQConfig = None,
                               analyze_similarity: bool = False,
                               target_kpi: KPI = None):
    """
    Quantize a trained model using post-training quantization. The model is quantized using
    symmetric, power-of-two constrained quantization thresholds.
    The model is first optimized using several transformations (e.g. folding BatchNormalization into
    preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
    collected for each layer's output (and input, depending on the quantization configuration).
    Thresholds are then calculated from the collected statistics and the model is quantized
    (both coefficients and activations by default).
    If a gptq configuration is passed, the quantized weights are optimized using knowledge
    distillation by comparing points between the float and quantized models, and minimizing the observed loss.

    Args:
        in_model: Model to quantize.
        representative_data_gen: Dataset used for calibration.
        n_iter: Number of calibration iterations to run.
        quant_config: QuantizationConfig containing parameters of how the model should be quantized. `Default configuration. <https://github.com/sony/model_optimization/blob/21e21c95ca25a31874a5be7af9dd2dd5da8f3a10/model_compression_toolkit/common/quantization/quantization_config.py#L163>`_
        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/21e21c95ca25a31874a5be7af9dd2dd5da8f3a10/model_compression_toolkit/keras/default_framework_info.py#L114>`_
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
        network_editor: List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes. None (the default) means no edit rules.
        gptq_config: Configuration for using gradient-based PTQ (e.g. optimizer). None disables GPTQ.
        analyze_similarity: Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
        target_kpi: KPI to constrain the search of the mixed-precision configuration for the model. When given,
            quant_config must be a MixedPrecisionQuantizationConfig.

    Returns:
        A quantized model and information the user may need to handle the quantized model.

    """
    # A fresh list per call: a `[]` default would be a single list shared by every call.
    if network_editor is None:
        network_editor = []

    if quant_config.weights_bias_correction and gptq_config is not None:
        common.Logger.error('weights_bias_correction should be disabled in GPTQ mode')

    tb_w = _init_tensorboard_writer()

    tg = _prepare_model_for_quantization(in_model,
                                         representative_data_gen,
                                         network_editor,
                                         n_iter,
                                         quant_config,
                                         fw_info,
                                         tb_w,
                                         fw_impl)

    ######################################
    # Finalize bit widths
    ######################################
    if target_kpi is not None:
        # The mixed-precision search needs the extra settings of MixedPrecisionQuantizationConfig.
        assert isinstance(quant_config, MixedPrecisionQuantizationConfig), \
            'target_kpi was passed, but quant_config is not a MixedPrecisionQuantizationConfig'
        bit_widths_config = search_bit_width(tg,
                                             quant_config,
                                             fw_info,
                                             target_kpi,
                                             partial(fw_impl.get_sensitivity_evaluation_fn,
                                                     representative_data_gen=representative_data_gen,
                                                     fw_info=fw_info))
    else:
        bit_widths_config = None

    tg = set_bit_widths(quant_config,
                        tg,
                        fw_info,
                        bit_widths_config)

    quantized_model, user_info = _quantize_fixed_bit_widths_graph(analyze_similarity,
                                                                  fw_info,
                                                                  gptq_config,
                                                                  representative_data_gen,
                                                                  tb_w,
                                                                  tg,
                                                                  fw_impl)
    # Expose the chosen mixed-precision configuration (None when no search was run).
    user_info.mixed_precision_cfg = bit_widths_config

    return quantized_model, user_info
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
def _init_tensorboard_writer() -> TensorboardWriter:
    """
    Create a TensorboardWriter when a logger directory has been configured.

    Returns:
        A TensorboardWriter writing to '<cwd>/<common.Logger.LOG_PATH>/tensorboard_logs' when
        common.Logger.LOG_PATH is set, or None otherwise.
    """
    log_path = common.Logger.LOG_PATH
    if log_path is None:
        # No logger directory configured: TensorBoard logging is disabled.
        return None

    tb_log_dir = os.path.join(os.getcwd(), log_path, 'tensorboard_logs')
    common.Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
    return TensorboardWriter(tb_log_dir)
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
def _quantize_model(fw_info: FrameworkInfo,
                    tb_w: TensorboardWriter,
                    tg: Graph,
                    fw_impl: FrameworkImplementation) -> Tuple[Any, UserInformation]:
    """
    Quantize the graph's weights, apply bias correction, and build a quantized model from the result
    using the framework implementation's model builder.

    Args:
        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
        tb_w: TensorBoardWriter object to log events, or None to skip logging.
        tg: A graph prepared for quantization.
        fw_impl: FrameworkImplementation object providing the pre-build substitutions and the model builder.

    Returns:
        A tuple of the quantized model and information the user may need to use the quantized model.
    """

    def _log_graph(graph, tag):
        # TensorBoard logging is optional; nothing is recorded without a writer.
        if tb_w is not None:
            tb_w.add_graph(graph, tag)

    weights_quantized_graph = quantize_graph_weights(tg,
                                                     fw_info=fw_info,
                                                     fw_impl=fw_impl)
    _log_graph(weights_quantized_graph, 'after_quantization')

    bias_corrected_graph = apply_bias_correction_to_graph(weights_quantized_graph,
                                                          fw_info=fw_info,
                                                          fw_impl=fw_impl)
    _log_graph(bias_corrected_graph, 'after_bias_correction')

    ######################################
    # Back2Framework
    ######################################
    # The framework may require some substitutions before a quantized model can be built.
    graph_to_build = substitute(bias_corrected_graph,
                                fw_impl.get_substitutions_pre_build())
    quantized_model, user_info = fw_impl.model_builder(graph_to_build,
                                                       mode=ModelBuilderMode.QUANTIZED,
                                                       fw_info=fw_info)
    return quantized_model, user_info
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
def _analyze_similarity(representative_data_gen: Callable,
                        tb_w: TensorboardWriter,
                        tg: Graph,
                        tg_float: Graph):
    """
    Plot the cosine similarity of different points on the graph between the float and quantized
    graphs. Add them to the passed TensorboardWriter object and close all tensorboard writer open
    files. Does nothing when tb_w is None.

    Args:
        representative_data_gen: Dataset used for calibration. Each call supplies one sample to plot.
        tb_w: TensorBoardWriter object to log events (closed by this function), or None.
        tg: Graph of quantized model.
        tg_float: Graph of float model.

    """
    if tb_w is None:
        return

    # KerasNNVisualizer is not imported at module level in this file; import it here so the
    # Keras-specific visualizer is only loaded when similarity analysis is actually requested.
    from model_compression_toolkit.keras.visualization.nn_visualizer import KerasNNVisualizer

    try:
        visual = KerasNNVisualizer(tg_float, tg)
        for i in range(NUM_SAMPLES_CS_TENSORBOARD):
            figure = visual.plot_cs_graph(representative_data_gen())
            tb_w.add_figure(figure, f'cosine_similarity_sample_{i}')
    finally:
        # Close the writer even if plotting fails, so its open files are released.
        tb_w.close()
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
def _apply_gptq(gptq_config: GradientPTQConfig,
                representative_data_gen: Callable,
                tb_w: TensorboardWriter,
                tg: Graph,
                fw_info: FrameworkInfo,
                fw_impl: FrameworkImplementation) -> Graph:
    """
    Apply gradient-based PTQ (GPTQ) to improve the accuracy of the quantized model.
    The training itself is delegated to fw_impl.gptq_training, which uses the representative
    dataset to fine-tune the graph according to gptq_config.

    Args:
        gptq_config: Configuration for using GPTQ (e.g. optimizer). None disables GPTQ.
        representative_data_gen: Dataset used for GPTQ fine tuning.
        tb_w: TensorBoardWriter object to log events, or None to skip logging.
        tg: Graph of quantized model.
        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
        fw_impl: FrameworkImplementation object whose gptq_training method performs the training.

    Returns:
        The graph returned by fw_impl.gptq_training, or tg unchanged when gptq_config is None.
    """
    # GPTQ is opt-in: without a configuration the graph passes through untouched.
    if gptq_config is None:
        return tg

    common.Logger.info("Using experimental Gradient Based PTQ: If you encounter an issue "
                       "please file a bug. To disable it, do not pass a gptq configuration.")

    trained_graph = fw_impl.gptq_training(tg,
                                          representative_data_gen,
                                          gptq_config,
                                          fw_info)

    if tb_w is not None:
        tb_w.add_graph(trained_graph, 'after_gptq')
    return trained_graph
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
def _quantize_fixed_bit_widths_graph(analyze_similarity: bool,
                                     fw_info: FrameworkInfo,
                                     gptq_config: GradientPTQConfig,
                                     representative_data_gen: Callable,
                                     tb_w: TensorboardWriter,
                                     tg: Graph,
                                     fw_impl: FrameworkImplementation) -> Tuple[Any, UserInformation]:
    """
    Quantize a graph that has final weights candidates quantization configurations.
    Before we quantize the graph weights, we apply GPTQ to get an improved graph.
    The TensorboardWriter (if any) is closed before returning.

    Args:
        analyze_similarity: Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.)
        gptq_config: Configuration for using GPTQ (e.g. optimizer).
        representative_data_gen: Dataset used for GPTQ fine tuning.
        tb_w: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
        tg: Graph to apply GPTQ and to quantize.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.

    Returns:
        A tuple of the quantized model and an object of UserInformation.

    """

    #############################################
    # Gradient Based Post Training Quantization
    #############################################
    tg = _apply_gptq(gptq_config,
                     representative_data_gen,
                     tb_w,
                     tg,
                     fw_info,
                     fw_impl)

    # The pre-quantization copy is only consumed by the similarity analysis, so skip the
    # deep copy of the whole graph when that analysis is disabled.
    tg_float = copy.deepcopy(tg) if analyze_similarity else None

    ######################################
    # Model Quantization
    ######################################
    quantized_model, user_info = _quantize_model(fw_info,
                                                 tb_w,
                                                 tg,
                                                 fw_impl)
    if analyze_similarity:
        # _analyze_similarity closes tb_w itself once the figures are written.
        _analyze_similarity(representative_data_gen,
                            tb_w,
                            tg,
                            tg_float)
    elif tb_w is not None:
        # Nothing else closes the writer on this path; close it to flush events and release files.
        tb_w.close()

    return quantized_model, user_info
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
def _prepare_model_for_quantization(in_model: Any,
                                    representative_data_gen: Callable,
                                    network_editor: List[EditRule] = None,
                                    n_iter: int = 500,
                                    quant_config: QuantizationConfig = DEFAULTCONFIG,
                                    fw_info: FrameworkInfo = None,
                                    tb_w: TensorboardWriter = None,
                                    fw_impl: FrameworkImplementation = None) -> Graph:
    """
    Prepare a trained model for post-training quantization. The model is prepared to be quantized using
    symmetric, power-of-two constrained quantization thresholds.
    The model is first read into a graph object (via fw_impl.model_reader) and optimized using several
    transformations (e.g. folding BatchNormalization into preceding layers). Then, using a given dataset,
    statistics (e.g. min/max, histogram, etc.) are collected for each layer's output (and input, depending
    on the quantization configuration). Thresholds are then calculated using the collected statistics.
    Finally, more transformations (based on statistics) are applied to increase the model's performance.

    Args:
        in_model (Any): Framework model (e.g. a Keras model) to optimize and prepare for quantization.
        representative_data_gen (Callable): Dataset used for calibration.
        network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to
            change quantization settings of the filtered nodes. None (the default) means no edit rules.
        n_iter (int): Number of calibration iterations to run.
        quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
            quantized.
        fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g.,
            kernel channels indices, groups of layers by how they should be quantized, etc.)
        tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
            None skips logging.
        fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.

    Returns:
        Graph object that represents the model, contains thresholds, and is ready for quantization.
    """
    # A fresh list per call: a `[]` default would be a single list shared by every call.
    if network_editor is None:
        network_editor = []

    ######################################
    # Represent model in a graph
    ######################################
    graph = fw_impl.model_reader(in_model)  # model reading

    if tb_w is not None:
        tb_w.add_graph(graph, 'initial_graph')

    ######################################
    # Graph substitution (pre statistics collection)
    ######################################
    transformed_graph = substitute(graph, fw_impl.get_substitutions_pre_statistics_collection())

    if tb_w is not None:
        tb_w.add_graph(transformed_graph, 'pre_statistics_collection_substitutions')

    ######################################
    # Graph marking points
    ######################################
    transformed_graph = substitute(transformed_graph, fw_impl.get_substitutions_marking())

    if tb_w is not None:
        tb_w.add_graph(transformed_graph, 'after_graph_marking')

    ######################################
    # Graph analyzing (attaching statistics collectors)
    ######################################
    analyzer_graph(fw_impl.attach_sc_to_node,
                   transformed_graph,
                   fw_info,
                   quant_config)  # Mark points for statistics collection

    if tb_w is not None:
        tb_w.add_graph(transformed_graph, 'after_analyzer_graph')

    ######################################
    # Statistic collection
    ######################################
    mi = ModelCollector(transformed_graph,
                        fw_impl,
                        fw_info)

    # Each iteration feeds one batch from the representative dataset through the collector.
    for _ in tqdm(range(n_iter)):
        mi.infer(representative_data_gen())

    ######################################
    # Add quantization configurations
    ######################################
    transformed_graph = set_quantization_configuration_to_graph(transformed_graph,
                                                                quant_config,
                                                                fw_info)

    ######################################
    # Edit network according to user specific settings
    ######################################
    edit_network_graph(transformed_graph, fw_info, network_editor)

    ######################################
    # Calculate quantization params
    ######################################
    calculate_quantization_params(transformed_graph,
                                  fw_info,
                                  fw_impl=fw_impl)

    if tb_w is not None:
        tb_w.add_graph(transformed_graph, 'thresholds_selection')
        tb_w.add_all_statistics(transformed_graph, 'thresholds_selection')

    ######################################
    # Graph substitution (post statistics collection)
    ######################################
    transformed_graph = substitute(transformed_graph,
                                   fw_impl.get_substitutions_post_statistics_collection(quant_config))

    ######################################
    # Channel equalization
    ######################################
    transformed_graph = substitute(transformed_graph,
                                   fw_impl.get_substitutions_channel_equalization(quant_config,
                                                                                  fw_info))

    ######################################
    # Shift Negative Activations
    ######################################
    if quant_config.shift_negative_activation_correction:
        transformed_graph = fw_impl.shift_negative_correction(transformed_graph,
                                                              quant_config,
                                                              fw_info)
        if tb_w is not None:
            tb_w.add_graph(transformed_graph, 'after_shift_negative_correction')
            tb_w.add_all_statistics(transformed_graph, 'after_shift_negative_correction')

    if tb_w is not None:
        tb_w.add_graph(transformed_graph, 'post_statistics_collection_substitutions')
        tb_w.add_all_statistics(transformed_graph, 'post_statistics_collection_substitutions')

    ########################################################
    # Compute bias correction to nodes' config candidates
    ########################################################
    tg_with_bias = compute_bias_correction_of_graph(transformed_graph,
                                                    fw_info,
                                                    fw_impl)

    if tb_w is not None:
        tb_w.add_graph(tg_with_bias, 'bias_correction_computation')

    # Sanity check: no node should have a final weights config yet at this stage
    # (presumably it is assigned later, e.g. by set_bit_widths — confirm).
    for n in tg_with_bias.nodes:
        assert n.final_weights_quantization_cfg is None

    return tg_with_bias
|
|
@@ -1,14 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
@@ -1,40 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
import copy
|
|
17
|
-
|
|
18
|
-
from typing import List
|
|
19
|
-
|
|
20
|
-
from model_compression_toolkit import common
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def substitute(graph_to_substitute: common.Graph,
               substitutions_list: List[common.BaseSubstitution]) -> common.Graph:
    """
    Apply a list of substitutions, in order, on a copy of a graph.

    Args:
        graph_to_substitute: Graph to transform. It is deep-copied first, so the caller's graph is not modified.
        substitutions_list: List of substitutions to apply on the graph. For each substitution, every node
            matched by its matcher_instance is passed to its substitute method.

    Returns:
        Transformed graph after applying all substitutions in substitutions_list.
    """
    result_graph = copy.deepcopy(graph_to_substitute)
    for sub in substitutions_list:
        # The matches are computed once per substitution (on the graph as it was before this
        # substitution ran); each call may return a new graph that the next call builds upon.
        for matched in result_graph.filter(sub.matcher_instance):
            result_graph = sub.substitute(result_graph, matched)
    return result_graph
|