mct-nightly 0.0.0__py3-none-any.whl → 1.1.0.02122021-003117__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (39)
  1. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/METADATA +3 -2
  2. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/RECORD +31 -38
  3. model_compression_toolkit/__init__.py +2 -6
  4. model_compression_toolkit/common/base_substitutions.py +1 -0
  5. model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +9 -12
  6. model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +8 -21
  7. model_compression_toolkit/common/collectors/histogram_collector.py +1 -1
  8. model_compression_toolkit/common/graph/base_graph.py +2 -4
  9. model_compression_toolkit/common/graph/graph_matchers.py +3 -1
  10. model_compression_toolkit/common/graph/graph_searches.py +3 -1
  11. model_compression_toolkit/common/mixed_precision/bit_width_setter.py +1 -2
  12. model_compression_toolkit/common/network_editors/node_filters.py +1 -0
  13. model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py +1 -1
  14. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +3 -5
  15. model_compression_toolkit/common/quantization/quantize_graph_weights.py +4 -7
  16. model_compression_toolkit/common/quantization/quantize_node.py +3 -5
  17. model_compression_toolkit/keras/__init__.py +2 -0
  18. model_compression_toolkit/keras/back2framework/model_builder.py +24 -1
  19. model_compression_toolkit/{common → keras/back2framework}/model_collector.py +9 -18
  20. model_compression_toolkit/keras/default_framework_info.py +0 -1
  21. model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +57 -10
  22. model_compression_toolkit/keras/graph_substitutions/substituter.py +171 -0
  23. model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +26 -6
  24. model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +12 -5
  25. model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +3 -4
  26. model_compression_toolkit/keras/quantization_facade.py +524 -188
  27. model_compression_toolkit/keras/reader/connectivity_handler.py +4 -1
  28. model_compression_toolkit/keras/visualization/nn_visualizer.py +1 -2
  29. model_compression_toolkit/common/framework_implementation.py +0 -239
  30. model_compression_toolkit/common/gptq/__init__.py +0 -14
  31. model_compression_toolkit/common/gptq/gptq_config.py +0 -65
  32. model_compression_toolkit/common/model_builder_mode.py +0 -34
  33. model_compression_toolkit/common/post_training_quantization.py +0 -459
  34. model_compression_toolkit/common/substitutions/__init__.py +0 -14
  35. model_compression_toolkit/common/substitutions/apply_substitutions.py +0 -40
  36. model_compression_toolkit/keras/keras_implementation.py +0 -256
  37. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/LICENSE +0 -0
  38. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/WHEEL +0 -0
  39. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/top_level.txt +0 -0
@@ -1,459 +0,0 @@
- # Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
-
- import copy
- import os
- from functools import partial
- from typing import Callable, List, Tuple, Any
- from tqdm import tqdm
-
- from model_compression_toolkit import common
- from model_compression_toolkit.common.gptq.gptq_config import GradientPTQConfig
- from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
- from model_compression_toolkit.common.mixed_precision.kpi import KPI
- from model_compression_toolkit.common import FrameworkInfo
- from model_compression_toolkit.common.constants import NUM_SAMPLES_CS_TENSORBOARD
- from model_compression_toolkit.common.graph.base_graph import Graph
- from model_compression_toolkit.common.mixed_precision.bit_width_setter import set_bit_widths
-
- from model_compression_toolkit.common.mixed_precision.mixed_precision_search_facade import search_bit_width
- from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
- from model_compression_toolkit.common.network_editors.actions import EditRule
- from model_compression_toolkit.common.network_editors.edit_network import edit_network_graph
- from model_compression_toolkit.common.mixed_precision.mixed_precision_quantization_config import \
-     MixedPrecisionQuantizationConfig
- from model_compression_toolkit.common.quantization.quantize_graph_weights import quantize_graph_weights
- from model_compression_toolkit.common.bias_correction.compute_bias_correction_of_graph import compute_bias_correction_of_graph
-
- from model_compression_toolkit.common.quantization.quantization_analyzer import analyzer_graph
- from model_compression_toolkit.common.quantization.quantization_config import DEFAULTCONFIG
- from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
- from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_computation import \
-     calculate_quantization_params
-
- from model_compression_toolkit.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
-
- from model_compression_toolkit.common.substitutions.apply_substitutions import substitute
- from model_compression_toolkit.common.user_info import UserInformation
- from model_compression_toolkit.common.model_collector import ModelCollector
-
- from model_compression_toolkit.common.visualization.tensorboard_writer import TensorboardWriter
- from model_compression_toolkit.common.bias_correction.apply_bias_correction_to_graph import apply_bias_correction_to_graph
-
-
-
-
-
- def post_training_quantization(in_model: Any,
-                                representative_data_gen: Callable,
-                                n_iter: int,
-                                quant_config: QuantizationConfig,
-                                fw_info: FrameworkInfo,
-                                fw_impl: FrameworkImplementation,
-                                network_editor: List[EditRule] = [],
-                                gptq_config: GradientPTQConfig = None,
-                                analyze_similarity: bool = False,
-                                target_kpi: KPI = None):
-     """
-     Quantize a trained model using post-training quantization. The model is quantized using a
-     symmetric constraint quantization thresholds (power of two).
-     The model is first optimized using several transformations (e.g. BatchNormalization folding to
-     preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
-     being collected for each layer's output (and input, depends on the quantization configuration).
-     Thresholds are then being calculated using the collected statistics and the model is quantized
-     (both coefficients and activations by default).
-     If a gptq configuration is passed, the quantized weights are optimized using knowledge
-     distillation by comparing points between the float and quantized models, and minimizing the observed loss.
-
-     Args:
-         in_model: Model to quantize.
-         representative_data_gen: Dataset used for calibration.
-         n_iter: Number of calibration iterations to run.
-         quant_config: QuantizationConfig containing parameters of how the model should be quantized. `Default configuration. <https://github.com/sony/model_optimization/blob/21e21c95ca25a31874a5be7af9dd2dd5da8f3a10/model_compression_toolkit/common/quantization/quantization_config.py#L163>`_
-         fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/21e21c95ca25a31874a5be7af9dd2dd5da8f3a10/model_compression_toolkit/keras/default_framework_info.py#L114>`_
-         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
-         network_editor: List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
-         gptq_config: Configuration for using gradient-based PTQ (e.g. optimizer).
-         analyze_similarity: Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
-         target_kpi: KPI to constraint the search of the mixed-precision configuration for the model.
-
-     Returns:
-         A quantized model and information the user may need to handle the quantized model.
-
-     """
-     if quant_config.weights_bias_correction and gptq_config is not None:
-         common.Logger.error('weights_bias_correction should be disabled in GPTQ mode')
-
-     tb_w = _init_tensorboard_writer()
-
-     tg = _prepare_model_for_quantization(in_model,
-                                          representative_data_gen,
-                                          network_editor,
-                                          n_iter,
-                                          quant_config,
-                                          fw_info,
-                                          tb_w,
-                                          fw_impl)
-
-
-     ######################################
-     # Finalize bit widths
-     ######################################
-     if target_kpi is not None:
-         assert isinstance(quant_config, MixedPrecisionQuantizationConfig)
-         bit_widths_config = search_bit_width(tg,
-                                              quant_config,
-                                              fw_info,
-                                              target_kpi,
-                                              partial(fw_impl.get_sensitivity_evaluation_fn,
-                                                      representative_data_gen=representative_data_gen,
-                                                      fw_info=fw_info))
-     else:
-         bit_widths_config = None
-
-     tg = set_bit_widths(quant_config,
-                         tg,
-                         fw_info,
-                         bit_widths_config)
-
-     quantized_model, user_info = _quantize_fixed_bit_widths_graph(analyze_similarity,
-                                                                   fw_info,
-                                                                   gptq_config,
-                                                                   representative_data_gen,
-                                                                   tb_w,
-                                                                   tg,
-                                                                   fw_impl)
-     user_info.mixed_precision_cfg = bit_widths_config
-
-     return quantized_model, user_info
-
-
-
-
-
- def _init_tensorboard_writer() -> TensorboardWriter:
-     """
-
-     Returns: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
-
-     """
-     tb_w = None
-     if common.Logger.LOG_PATH is not None:
-         tb_log_dir = os.path.join(os.getcwd(), common.Logger.LOG_PATH, 'tensorboard_logs')
-         common.Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
-         tb_w = TensorboardWriter(tb_log_dir)
-     return tb_w
-
-
- def _quantize_model(fw_info: FrameworkInfo,
-                     tb_w: TensorboardWriter,
-                     tg: Graph,
-                     fw_impl: FrameworkImplementation) -> Tuple[Any, UserInformation]:
-     """
-     Quantize graph's weights, and build a quantized Keras model from it.
-
-     Args:
-         fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
-         tb_w: TensorBoardWriter object to log events.
-         tg: A prepared for quantization graph.
-
-     Returns:
-         Quantize Keras model, and informat the user may need to use the quantized model.
-     """
-
-     quantized_tg = quantize_graph_weights(tg,
-                                           fw_info=fw_info,
-                                           fw_impl=fw_impl)
-     if tb_w is not None:
-         tb_w.add_graph(quantized_tg, 'after_quantization')
-
-     quantized_graph_with_bias_correction = apply_bias_correction_to_graph(quantized_tg,
-                                                                           fw_info=fw_info,
-                                                                           fw_impl=fw_impl)
-     if tb_w is not None:
-         tb_w.add_graph(quantized_graph_with_bias_correction, 'after_bias_correction')
-
-     ######################################
-     # Back2Framework
-     ######################################
-     # Before building a quantized model, first apply some substitutions.
-     quantized_graph_with_bias_correction = substitute(quantized_graph_with_bias_correction,
-                                                       fw_impl.get_substitutions_pre_build())
-     quantized_model, user_info = fw_impl.model_builder(quantized_graph_with_bias_correction,
-                                                        mode=ModelBuilderMode.QUANTIZED,
-                                                        fw_info=fw_info)
-     return quantized_model, user_info
-
-
- def _analyze_similarity(representative_data_gen: Callable,
-                         tb_w: TensorboardWriter,
-                         tg: Graph,
-                         tg_float: Graph):
-     """
-     Plot the cosine similarity of different points on the graph between the float and quantized
-     graphs. Add them to the passed TensorboardWriter object and close all tensorboard writer open
-     files.
-
-     Args:
-         representative_data_gen: Dataset used for calibration.
-         tb_w: TensorBoardWriter object to log events.
-         tg: Graph of quantized model.
-         tg_float: Graph of float model.
-
-     """
-     if tb_w is not None:
-         visual = KerasNNVisualizer(tg_float, tg)
-         for i in range(NUM_SAMPLES_CS_TENSORBOARD):
-             figure = visual.plot_cs_graph(representative_data_gen())
-             tb_w.add_figure(figure, f'cosine_similarity_sample_{i}')
-         tb_w.close()
-
-
- def _apply_gptq(gptq_config: GradientPTQConfig,
-                 representative_data_gen: Callable,
-                 tb_w: TensorboardWriter,
-                 tg: Graph,
-                 fw_info: FrameworkInfo,
-                 fw_impl: FrameworkImplementation) -> Graph:
-     """
-     Apply GPTQ to improve accuracy of quantized model.
-     Build two models from a graph: A teacher network (float model) and a student network (quantized model).
-     and use the dataset generator to pass images through the teacher and student networks to get intermediate
-     layers outputs and maximize their similarity.
-
-     Args:
-         gptq_config: Configuration for using GPTQ (e.g. optimizer).
-         representative_data_gen: Dataset used for calibration.
-         tb_w: TensorBoardWriter object to log events.
-         tg: Graph of quantized model.
-         fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
-
-     Returns:
-
-     """
-     if gptq_config is not None:
-         common.Logger.info("Using experimental Gradient Based PTQ: If you encounter an issue "
-                            "please file a bug. To disable it, do not pass a gptq configuration.")
-
-         tg = fw_impl.gptq_training(tg,
-                                    representative_data_gen,
-                                    gptq_config,
-                                    fw_info)
-
-         if tb_w is not None:
-             tb_w.add_graph(tg, 'after_gptq')
-     return tg
-
-
- def _quantize_fixed_bit_widths_graph(analyze_similarity: bool,
-                                      fw_info: FrameworkInfo,
-                                      gptq_config: GradientPTQConfig,
-                                      representative_data_gen: Callable,
-                                      tb_w: TensorboardWriter,
-                                      tg: Graph,
-                                      fw_impl: FrameworkImplementation) -> Tuple[Any, UserInformation]:
-     """
-     Quantize a graph that has final weights candidates quantization configurations.
-     Before we quantize the graph weights, we apply GPTQ to get an improved graph.
-
-     Args:
-         analyze_similarity: Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
-         fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.)
-         gptq_config: Configuration for using GPTQ (e.g. optimizer).
-         representative_data_gen: Dataset used for GPTQ fine tuning.
-         tb_w: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
-         tg: Graph to apply GPTQ and to quantize.
-         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
-
-     Returns:
-         A tuple of the quantized model and an object of UserInformation.
-
-     """
-
-
-     #############################################
-     # Gradient Based Post Training Quantization
-     #############################################
-     tg = _apply_gptq(gptq_config,
-                      representative_data_gen,
-                      tb_w,
-                      tg,
-                      fw_info,
-                      fw_impl)
-
-     tg_float = copy.deepcopy(tg) # Copy graph before quantization (for similarity analyzer)
-     ######################################
-     # Model Quantization
-     ######################################
-     quantized_model, user_info = _quantize_model(fw_info,
-                                                  tb_w,
-                                                  tg,
-                                                  fw_impl)
-     if analyze_similarity:
-         _analyze_similarity(representative_data_gen,
-                             tb_w,
-                             tg,
-                             tg_float)
-
-     return quantized_model, user_info
-
-
-
- def _prepare_model_for_quantization(in_model: Any,
-                                     representative_data_gen: Callable,
-                                     network_editor: List[EditRule] = [],
-                                     n_iter: int = 500,
-                                     quant_config: QuantizationConfig = DEFAULTCONFIG,
-                                     fw_info: FrameworkInfo = None,
-                                     tb_w: TensorboardWriter = None,
-                                     fw_impl: FrameworkImplementation = None) -> Graph:
-     """
-     Prepare a trained Keras model for post-training quantization. The model is prepared to be quantized using a
-     symmetric constraint quantization thresholds (power of two).
-     The model is first read into a graph object and being optimized using several transformations (e.g.
-     BatchNormalization folding to preceding layers). Then, using a given dataset, statistics (e.g. min/max,
-     histogram, etc.) are being collected for each layer's output (and input, depends on the quantization configuration).
-     Thresholds are then being calculated using the collected statistics. Finally, more transformations (based on
-     statistics) are applied to increase model's performance.
-
-     Args:
-         in_model (Model): Keras model to optimize and prepare for quantization.
-         representative_data_gen (Callable): Dataset used for calibration.
-         network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to
-         change quantization settings of the filtered nodes.
-         n_iter (int): Number of calibration iterations to run.
-         quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
-         quantized.
-         fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g.,
-         kernel channels indices, groups of layers by how they should be quantized, etc.)
-         tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
-         fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
-
-     Returns:
-         Graph object that represents the Keras model, contains thresholds, and ready for quantization.
-     """
-
-     ######################################
-     # Represent model in a graph
-     ######################################
-     graph = fw_impl.model_reader(in_model) # model reading
-
-     if tb_w is not None:
-         tb_w.add_graph(graph, 'initial_graph')
-
-     ######################################
-     # Graph substitution (pre statistics collection)
-     ######################################
-     transformed_graph = substitute(graph, fw_impl.get_substitutions_pre_statistics_collection())
-
-     if tb_w is not None:
-         tb_w.add_graph(transformed_graph, 'pre_statistics_collection_substitutions')
-
-     ######################################
-     # Graph marking points
-     ######################################
-     transformed_graph = substitute(transformed_graph, fw_impl.get_substitutions_marking())
-
-     if tb_w is not None:
-         tb_w.add_graph(transformed_graph, 'after_graph_marking')
-
-     ######################################
-     # Graph analyzing (attaching statistics collectors)
-     ######################################
-     analyzer_graph(fw_impl.attach_sc_to_node,
-                    transformed_graph,
-                    fw_info,
-                    quant_config) # Mark points for statistics collection
-
-     if tb_w is not None:
-         tb_w.add_graph(transformed_graph, 'after_analyzer_graph')
-
-     ######################################
-     # Statistic collection
-     ######################################
-     mi = ModelCollector(transformed_graph,
-                         fw_impl,
-                         fw_info)
-
-     for _ in tqdm(range(n_iter)):
-         mi.infer(representative_data_gen())
-
-     ######################################
-     # Add quantization configurations
-     ######################################
-     transformed_graph = set_quantization_configuration_to_graph(transformed_graph,
-                                                                 quant_config,
-                                                                 fw_info)
-
-     ######################################
-     # Edit network according to user specific settings
-     ######################################
-     edit_network_graph(transformed_graph, fw_info, network_editor)
-
-     ######################################
-     # Calculate quantization params
-     ######################################
-     calculate_quantization_params(transformed_graph,
-                                   fw_info,
-                                   fw_impl=fw_impl)
-
-     if tb_w is not None:
-         tb_w.add_graph(transformed_graph, 'thresholds_selection')
-         tb_w.add_all_statistics(transformed_graph, 'thresholds_selection')
-
-     ######################################
-     # Graph substitution (post statistics collection)
-     ######################################
-     transformed_graph = substitute(transformed_graph,
-                                    fw_impl.get_substitutions_post_statistics_collection(quant_config))
-
-
-     ######################################
-     # Channel equalization
-     ######################################
-     transformed_graph = substitute(transformed_graph,
-                                    fw_impl.get_substitutions_channel_equalization(quant_config,
-                                                                                   fw_info))
-
-     ######################################
-     # Shift Negative Activations
-     ######################################
-     if quant_config.shift_negative_activation_correction:
-         transformed_graph = fw_impl.shift_negative_correction(transformed_graph,
-                                                               quant_config,
-                                                               fw_info)
-         if tb_w is not None:
-             tb_w.add_graph(transformed_graph, 'after_shift_negative_correction')
-             tb_w.add_all_statistics(transformed_graph, 'after_shift_negative_correction')
-
-     if tb_w is not None:
-         tb_w.add_graph(transformed_graph, 'post_statistics_collection_substitutions')
-         tb_w.add_all_statistics(transformed_graph, 'post_statistics_collection_substitutions')
-
-     ########################################################
-     # Compute bias correction to nodes' config candidates
-     ########################################################
-     tg_with_bias = compute_bias_correction_of_graph(transformed_graph,
-                                                     fw_info,
-                                                     fw_impl)
-
-     if tb_w is not None:
-         tb_w.add_graph(tg_with_bias, 'bias_correction_computation')
-
-     for n in tg_with_bias.nodes:
-         assert n.final_weights_quantization_cfg is None
-
-     return tg_with_bias
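The orchestration removed above from common/post_training_quantization.py is absorbed into model_compression_toolkit/keras/quantization_facade.py (+524 lines in this diff). For orientation, below is a minimal usage sketch of the calibration pattern the deleted docstring describes. It assumes the Keras facade exposes a keras_post_training_quantization entry point taking the same representative_data_gen and n_iter arguments; the MobileNetV2 model and random data are illustrative only.

```python
# Hedged sketch, not taken from this diff: the facade name and defaults are assumptions.
import numpy as np
from tensorflow.keras.applications import MobileNetV2
import model_compression_toolkit as mct


def representative_data_gen():
    # Called once per calibration iteration; return a list with one batch per model input.
    # Replace the random tensor with real, preprocessed calibration images.
    return [np.random.randn(1, 224, 224, 3)]


float_model = MobileNetV2()

# Returns the quantized Keras model plus user info (e.g., the chosen mixed-precision config).
quantized_model, quantization_info = mct.keras_post_training_quantization(
    float_model,
    representative_data_gen,
    n_iter=10)
```

The generator is consumed exactly as in the deleted _prepare_model_for_quantization: it is invoked n_iter times inside the statistics-collection loop (for _ in tqdm(range(n_iter)): mi.infer(representative_data_gen())).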
@@ -1,14 +0,0 @@
- # Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
@@ -1,40 +0,0 @@
- # Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- import copy
-
- from typing import List
-
- from model_compression_toolkit import common
-
-
- def substitute(graph_to_substitute: common.Graph,
-                substitutions_list: List[common.BaseSubstitution]) -> common.Graph:
-     """
-     Apply a list of substitutions on a graph.
-     Args:
-         graph: Graph to transform.
-         substitutions_list: List of substitutions to apply on the graph.
-
-     Returns:
-         Transformed graph after applying all substitutions in substitutions_list.
-     """
-
-     graph = copy.deepcopy(graph_to_substitute)
-     for substitution in substitutions_list:
-         matched_nodes = graph.filter(substitution.matcher_instance)
-         for idn in matched_nodes:
-             graph = substitution.substitute(graph, idn)
-     return graph
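The substitute helper removed above defines the match-and-replace contract used by the substitution passes: graph.filter is called with a substitution's matcher_instance, and each matched node is handed to substitution.substitute(graph, node). The sketch below illustrates that contract with stand-in classes; ToyGraph and RenameReLU6 are hypothetical and do not reflect MCT's real Graph or BaseSubstitution API.

```python
# Illustrative stand-ins only; MCT's Graph/BaseSubstitution interfaces are richer than this.
import copy
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ToyGraph:
    nodes: List[str] = field(default_factory=list)

    def filter(self, matcher: Callable[[str], bool]) -> List[str]:
        # Return the nodes accepted by the substitution's matcher.
        return [n for n in self.nodes if matcher(n)]


class RenameReLU6:
    # The object graph.filter() is called with (a plain predicate in this toy version).
    matcher_instance = staticmethod(lambda node: node == 'relu6')

    def substitute(self, graph: ToyGraph, node: str) -> ToyGraph:
        # Rewrite the matched node and return the (possibly new) graph object.
        graph.nodes[graph.nodes.index(node)] = 'relu'
        return graph


def apply_substitutions(graph_to_substitute: ToyGraph, substitutions_list) -> ToyGraph:
    # Same control flow as the deleted helper: deep-copy, match, then apply per matched node.
    graph = copy.deepcopy(graph_to_substitute)
    for substitution in substitutions_list:
        for node in graph.filter(substitution.matcher_instance):
            graph = substitution.substitute(graph, node)
    return graph


print(apply_substitutions(ToyGraph(['conv', 'relu6']), [RenameReLU6()]).nodes)  # ['conv', 'relu']
```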