mct-nightly 0.0.0__py3-none-any.whl → 1.1.0.01122021-003325__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/METADATA +3 -2
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/RECORD +31 -38
- model_compression_toolkit/__init__.py +2 -6
- model_compression_toolkit/common/base_substitutions.py +1 -0
- model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +9 -12
- model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +8 -21
- model_compression_toolkit/common/collectors/histogram_collector.py +1 -1
- model_compression_toolkit/common/graph/base_graph.py +2 -4
- model_compression_toolkit/common/graph/graph_matchers.py +3 -1
- model_compression_toolkit/common/graph/graph_searches.py +3 -1
- model_compression_toolkit/common/mixed_precision/bit_width_setter.py +1 -2
- model_compression_toolkit/common/network_editors/node_filters.py +1 -0
- model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py +1 -1
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +3 -5
- model_compression_toolkit/common/quantization/quantize_graph_weights.py +4 -7
- model_compression_toolkit/common/quantization/quantize_node.py +3 -5
- model_compression_toolkit/keras/__init__.py +2 -0
- model_compression_toolkit/keras/back2framework/model_builder.py +24 -1
- model_compression_toolkit/{common → keras/back2framework}/model_collector.py +9 -18
- model_compression_toolkit/keras/default_framework_info.py +0 -1
- model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +57 -10
- model_compression_toolkit/keras/graph_substitutions/substituter.py +171 -0
- model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +26 -6
- model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +12 -5
- model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +3 -4
- model_compression_toolkit/keras/quantization_facade.py +524 -188
- model_compression_toolkit/keras/reader/connectivity_handler.py +4 -1
- model_compression_toolkit/keras/visualization/nn_visualizer.py +1 -2
- model_compression_toolkit/common/framework_implementation.py +0 -239
- model_compression_toolkit/common/gptq/__init__.py +0 -14
- model_compression_toolkit/common/gptq/gptq_config.py +0 -65
- model_compression_toolkit/common/model_builder_mode.py +0 -34
- model_compression_toolkit/common/post_training_quantization.py +0 -459
- model_compression_toolkit/common/substitutions/__init__.py +0 -14
- model_compression_toolkit/common/substitutions/apply_substitutions.py +0 -40
- model_compression_toolkit/keras/keras_implementation.py +0 -256
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/LICENSE +0 -0
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/WHEEL +0 -0
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/top_level.txt +0 -0
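The bulk of this diff is the rewrite of model_compression_toolkit/keras/quantization_facade.py (+524 -188), reproduced below. The importlib-guarded facade that delegated to the removed common post_training_quantization entry point is replaced by a module that orchestrates the whole PTQ pipeline itself: model reading, graph substitutions, statistics collection, threshold calculation, bias correction, optional GPTQ, and building the quantized model. The public entry points keep their names; the snippet below is a minimal usage sketch assembled from the docstring examples in the diff, and assumes tensorflow and tensorflow_model_optimization are installed (the requirement the removed guard used to enforce):

>>> import numpy as np
>>> import model_compression_toolkit as mct
>>> from tensorflow.keras.applications.mobilenet import MobileNet
>>> model = MobileNet()
>>> def repr_datagen(): return [np.random.random((1,224,224,3))]  # representative calibration data
>>> quantized_model, quantization_info = mct.keras_post_training_quantization(model, repr_datagen)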
model_compression_toolkit/keras/quantization_facade.py
@@ -13,198 +13,534 @@
 # limitations under the License.
 # ==============================================================================
 
-
+
+import copy
+import os
+from functools import partial
+from typing import Callable, List, Tuple
+
+from tensorflow.keras.models import Model
+from tqdm import tqdm
 
 from model_compression_toolkit import common
-from model_compression_toolkit.common import Logger
-from model_compression_toolkit.common.gptq.gptq_config import GradientPTQConfig
 from model_compression_toolkit.common.mixed_precision.kpi import KPI
-from model_compression_toolkit.common
+from model_compression_toolkit.common import FrameworkInfo
+from model_compression_toolkit.common.constants import NUM_SAMPLES_CS_TENSORBOARD
+from model_compression_toolkit.common.graph.base_graph import Graph
+from model_compression_toolkit.common.mixed_precision.bit_width_setter import set_bit_widths
+
+from model_compression_toolkit.common.mixed_precision.mixed_precision_search_facade import search_bit_width
 from model_compression_toolkit.common.network_editors.actions import EditRule
+from model_compression_toolkit.common.network_editors.edit_network import edit_network_graph
 from model_compression_toolkit.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfig, DEFAULT_MIXEDPRECISION_CONFIG
-from model_compression_toolkit.common.
-from model_compression_toolkit.common.
-from model_compression_toolkit.common.quantization.quantization_config import DEFAULTCONFIG
-
-import importlib
-
-if importlib.util.find_spec("tensorflow") is not None\
-        and importlib.util.find_spec("tensorflow_model_optimization") is not None:
-    from model_compression_toolkit.keras.default_framework_info import DEFAULT_KERAS_INFO
-    from model_compression_toolkit.keras.keras_implementation import KerasImplementation
-    from tensorflow.keras.models import Model
-
-    def keras_post_training_quantization(in_model: Model,
-                                         representative_data_gen: Callable,
-                                         n_iter: int = 500,
-                                         quant_config: QuantizationConfig = DEFAULTCONFIG,
-                                         fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
-                                         network_editor: List[EditRule] = [],
-                                         gptq_config: GradientPTQConfig = None,
-                                         analyze_similarity: bool = False):
-        """
-        Quantize a trained Keras model using post-training quantization. The model is quantized using a
-        symmetric constraint quantization thresholds (power of two).
-        The model is first optimized using several transformations (e.g. BatchNormalization folding to
-        preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
-        being collected for each layer's output (and input, depends on the quantization configuration).
-        Thresholds are then being calculated using the collected statistics and the model is quantized
-        (both coefficients and activations by default).
-        If a gptq configuration is passed, the quantized weights are optimized using gradient based post
-        training quantization by comparing points between the float and quantized models, and minimizing the observed loss.
-
-        Args:
-            in_model (Model): Keras model to quantize.
-            representative_data_gen (Callable): Dataset used for calibration.
-            n_iter (int): Number of calibration iterations to run.
-            quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized. `Default configuration. <https://github.com/sony/model_optimization/blob/21e21c95ca25a31874a5be7af9dd2dd5da8f3a10/model_compression_toolkit/common/quantization/quantization_config.py#L154>`_
-            fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/21e21c95ca25a31874a5be7af9dd2dd5da8f3a10/model_compression_toolkit/keras/default_framework_info.py#L113>`_
-            network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
-            gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
-            analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
-
-        Returns:
-            A quantized model and information the user may need to handle the quantized model.
-
-        Examples:
-            Import a Keras model:
-
-            >>> from tensorflow.keras.applications.mobilenet import MobileNet
-            >>> model = MobileNet()
-
-            Create a random dataset generator:
-
-            >>> import numpy as np
-            >>> def repr_datagen(): return [np.random.random((1,224,224,3))]
-
-            Import mct and pass the model with the representative dataset generator to get a quantized model:
-
-            >>> import model_compression_toolkit as mct
-            >>> quantized_model, quantization_info = mct.keras_post_training_quantization(model, repr_datagen)
-
-        """
-
-        return post_training_quantization(in_model,
-                                          representative_data_gen,
-                                          n_iter,
-                                          quant_config,
-                                          fw_info,
-                                          KerasImplementation(),
-                                          network_editor,
-                                          gptq_config,
-                                          analyze_similarity)
-
-
-    def keras_post_training_quantization_mixed_precision(in_model: Model,
-                                                         representative_data_gen: Callable,
-                                                         n_iter: int = 500,
-                                                         quant_config: MixedPrecisionQuantizationConfig = DEFAULT_MIXEDPRECISION_CONFIG,
-                                                         fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
-                                                         network_editor: List[EditRule] = [],
-                                                         gptq_config: GradientPTQConfig = None,
-                                                         analyze_similarity: bool = False,
-                                                         target_kpi: KPI = None):
-        """
-        Quantize a trained Keras model using post-training quantization. The model is quantized using a
-        symmetric constraint quantization thresholds (power of two).
-        The model is first optimized using several transformations (e.g. BatchNormalization folding to
-        preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
-        being collected for each layer's output (and input, depends on the quantization configuration).
-        For each possible bit width (per layer) a threshold is then being calculated using the collected
-        statistics. Then, using an ILP solver we find a mixed-precision configuration, and set a bit width
-        for each layer. The model is then quantized (both coefficients and activations by default).
-        In order to limit the maximal model's size, a target KPI can be passed after weights_memory
-        is set (in bytes).
-        For now, mixed precision is supported for weights only.
-        If a gptq configuration is passed, the quantized weights are optimized using gradient based post
-        training quantization by comparing points between the float and quantized models, and minimizing the observed loss.
-        Notice that this feature is experimental.
-
-        Args:
-            in_model (Model): Keras model to quantize.
-            representative_data_gen (Callable): Dataset used for calibration.
-            n_iter (int): Number of calibration iterations to run.
-            quant_config (MixedPrecisionQuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized.
-            fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/keras/default_framework_info.py#L100>`_
-            network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
-            gptq_config (GradientPTQConfig): Configuration for using GPTQ (e.g. optimizer).
-            analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
-            target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
-
-        Returns:
-            A quantized model and information the user may need to handle the quantized model.
-
-        Examples:
-            Import MCT:
-
-            >>> import model_compression_toolkit as mct
-
-            Import a Keras model:
-
-            >>> from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
-            >>> model = MobileNetV2()
-
-            Create a random dataset generator:
-
-            >>> import numpy as np
-            >>> def repr_datagen(): return [np.random.random((1,224,224,3))]
-
-            Create a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
-            Here, each layer can be quantized by 2, 4 or 8 bits:
-
-            >>> config = mct.MixedPrecisionQuantizationConfig(weights_n_bits=[4, 2, 8])
-
-            Create a KPI object to limit our returned model's size. Note that this value affects only coefficients that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value, while the bias will not):
-
-            >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
-
-            Pass the model, the representative dataset generator, the configuration and the target KPI to get a quantized model:
-
-            >>> quantized_model, quantization_info = mct.keras_post_training_quantization_mixed_precision(model, repr_datagen, n_iter=10, quant_config=config, target_kpi=kpi)
-
-            For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
-
-        """
-
-        if target_kpi is None:
-            common.Logger.warning("No KPI was passed. Using non mixed-precision compression process...")
-            # Before starting non-mixed-precision process, we need to set only single bit width, so we take the best
-            # option which is the maximal number of bits.
-            quant_config.weights_n_bits = [max(quant_config.weights_n_bits)]
-            return keras_post_training_quantization(in_model,
-                                                    representative_data_gen,
-                                                    n_iter,
-                                                    quant_config,
-                                                    fw_info,
-                                                    network_editor,
-                                                    gptq_config,
-                                                    analyze_similarity)
-
-        common.Logger.info("Using experimental mixed-precision quantization. "
-                           "If you encounter an issue please file a bug.")
-
-        return post_training_quantization(in_model,
-                                          representative_data_gen,
-                                          n_iter,
-                                          quant_config,
-                                          fw_info,
-                                          KerasImplementation(),
-                                          network_editor,
-                                          gptq_config,
-                                          analyze_similarity,
-                                          target_kpi)
-
-else:
-    # If tensorflow or tensorflow_model_optimization are not installed,
-    # we raise an exception when trying to use these functions.
-    def keras_post_training_quantization(*args, **kwargs):
-        Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
-                        'when using keras_post_training_quantization. '
-                        'Could not find Tensorflow package.')
-
-    def keras_post_training_quantization_mixed_precision(*args, **kwargs):
-        Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
-                        'when using keras_post_training_quantization_mixed_precision. '
-                        'Could not find Tensorflow package.')
+from model_compression_toolkit.common.quantization.quantize_graph_weights import quantize_graph_weights
+from model_compression_toolkit.common.bias_correction.compute_bias_correction_of_graph import compute_bias_correction_of_graph
 
+from model_compression_toolkit.common.quantization.quantization_analyzer import analyzer_graph
+from model_compression_toolkit.common.quantization.quantization_config import DEFAULTCONFIG
+from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
+from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_computation import \
+    calculate_quantization_params
+
+from model_compression_toolkit.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
+from model_compression_toolkit.common.user_info import UserInformation
+from model_compression_toolkit.keras.back2framework.model_builder import model_builder, ModelBuilderMode
+from model_compression_toolkit.keras.back2framework.model_collector import ModelCollector
+from model_compression_toolkit.keras.default_framework_info import DEFAULT_KERAS_INFO
+from model_compression_toolkit.keras.graph_substitutions.substituter import graph_marking_substitute
+from model_compression_toolkit.keras.graph_substitutions.substituter import post_statistics_collection_substitute
+from model_compression_toolkit.keras.graph_substitutions.substituter import pre_statistics_collection_substitute
+from model_compression_toolkit.keras.gradient_ptq.training_wrapper import GradientPTQConfig
+from model_compression_toolkit.keras.gradient_ptq.training_wrapper import gptq_training_wrapper
+from model_compression_toolkit.keras.mixed_precision.sensitivity_evaluation import get_sensitivity_evaluation
+from model_compression_toolkit.keras.reader.reader import model_reader
+from model_compression_toolkit.keras.tensor_marking import get_node_stats_collector
+from model_compression_toolkit.common.visualization.tensorboard_writer import TensorboardWriter
+from model_compression_toolkit.common.bias_correction.apply_bias_correction_to_graph import apply_bias_correction_to_graph
+from model_compression_toolkit.keras.visualization.nn_visualizer import KerasNNVisualizer
+
+
+def _prepare_model_for_quantization(in_model: Model,
+                                    representative_data_gen: Callable,
+                                    network_editor: List[EditRule] = [],
+                                    n_iter: int = 500,
+                                    quant_config: QuantizationConfig = DEFAULTCONFIG,
+                                    fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
+                                    tb_w: TensorboardWriter = None) -> Graph:
+    """
+    Prepare a trained Keras model for post-training quantization. The model is prepared to be quantized using a
+    symmetric constraint quantization thresholds (power of two).
+    The model is first read into a graph object and being optimized using several transformations (e.g.
+    BatchNormalization folding to preceding layers). Then, using a given dataset, statistics (e.g. min/max,
+    histogram, etc.) are being collected for each layer's output (and input, depends on the quantization configuration).
+    Thresholds are then being calculated using the collected statistics. Finally, more transformations (based on
+    statistics) are applied to increase model's performance.
+
+    Args:
+        in_model (Model): Keras model to optimize and prepare for quantization.
+        representative_data_gen (Callable): Dataset used for calibration.
+        network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to
+        change quantization settings of the filtered nodes.
+        n_iter (int): Number of calibration iterations to run.
+        quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
+        quantized.
+        fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g.,
+        kernel channels indices, groups of layers by how they should be quantized, etc.)
+        tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
+
+    Returns:
+        Graph object that represents the Keras model, contains thresholds, and ready for quantization.
+    """
+
+    ######################################
+    # Represent model in a graph
+    ######################################
+    graph = model_reader(in_model)  # model reading
+
+    if tb_w is not None:
+        tb_w.add_graph(graph, 'initial_graph')
+
+    ######################################
+    # Graph substitution (pre statistics collection)
+    ######################################
+    transformed_graph = pre_statistics_collection_substitute(graph)
+
+    if tb_w is not None:
+        tb_w.add_graph(transformed_graph, 'pre_statistics_collection_substitutions')
+
+    ######################################
+    # Graph marking points
+    ######################################
+    transformed_graph = graph_marking_substitute(transformed_graph)
+
+    if tb_w is not None:
+        tb_w.add_graph(transformed_graph, 'after_graph_marking')
+
+    ######################################
+    # Graph analyzing (attaching statistics collectors)
+    ######################################
+    analyzer_graph(get_node_stats_collector,
+                   transformed_graph,
+                   fw_info,
+                   quant_config)  # Mark points for statistics collection
+
+    if tb_w is not None:
+        tb_w.add_graph(transformed_graph, 'after_analyzer_graph')
+
+    ######################################
+    # Statistic collection
+    ######################################
+    mi = ModelCollector(transformed_graph)
+    for _ in tqdm(range(n_iter)):
+        mi.infer(representative_data_gen())
+
+    ######################################
+    # Add quantization configurations
+    ######################################
+    transformed_graph = set_quantization_configuration_to_graph(transformed_graph, quant_config, fw_info)
+
+    ######################################
+    # Edit network according to user specific settings
+    ######################################
+    edit_network_graph(transformed_graph, fw_info, network_editor)
+
+    ######################################
+    # Calculate quantization params
+    ######################################
+    calculate_quantization_params(transformed_graph, fw_info)
+
+    if tb_w is not None:
+        tb_w.add_graph(transformed_graph, 'thresholds_selection')
+        tb_w.add_all_statistics(transformed_graph, 'thresholds_selection')
+
+    ######################################
+    # Graph substitution (post statistics collection)
+    ######################################
+    transformed_graph = post_statistics_collection_substitute(transformed_graph,
+                                                              quant_config,
+                                                              fw_info)
+
+    if tb_w is not None:
+        tb_w.add_graph(transformed_graph, 'post_statistics_collection_substitutions')
+        tb_w.add_all_statistics(transformed_graph, 'post_statistics_collection_substitutions')
+
+    ########################################################
+    # Compute bias correction to nodes' config candidates
+    ########################################################
+    tg_with_bias = compute_bias_correction_of_graph(transformed_graph, fw_info)
+
+    if tb_w is not None:
+        tb_w.add_graph(tg_with_bias, 'bias_correction_computation')
+
+    for n in tg_with_bias.nodes:
+        assert n.final_weights_quantization_cfg is None
+
+    return tg_with_bias
+
+
+def keras_post_training_quantization(in_model: Model,
+                                     representative_data_gen: Callable,
+                                     n_iter: int = 500,
+                                     quant_config: QuantizationConfig = DEFAULTCONFIG,
+                                     fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
+                                     network_editor: List[EditRule] = [],
+                                     gptq_config: GradientPTQConfig = None,
+                                     analyze_similarity: bool = False):
+    """
+    Quantize a trained Keras model using post-training quantization. The model is quantized using a
+    symmetric constraint quantization thresholds (power of two).
+    The model is first optimized using several transformations (e.g. BatchNormalization folding to
+    preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
+    being collected for each layer's output (and input, depends on the quantization configuration).
+    Thresholds are then being calculated using the collected statistics and the model is quantized
+    (both coefficients and activations by default).
+    If a gptq configuration is passed, the quantized weights are optimized using gradient based post
+    training quantization by comparing points between the float and quantized models, and minimizing the observed loss.
+
+    Args:
+        in_model (Model): Keras model to quantize.
+        representative_data_gen (Callable): Dataset used for calibration.
+        n_iter (int): Number of calibration iterations to run.
+        quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized. `Default configuration. <https://github.com/sony/model_optimization/blob/21e21c95ca25a31874a5be7af9dd2dd5da8f3a10/model_compression_toolkit/common/quantization/quantization_config.py#L154>`_
+        fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/21e21c95ca25a31874a5be7af9dd2dd5da8f3a10/model_compression_toolkit/keras/default_framework_info.py#L113>`_
+        network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
+        gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
+        analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
+
+    Returns:
+        A quantized model and information the user may need to handle the quantized model.
+
+    Examples:
+        Import a Keras model:
+
+        >>> from tensorflow.keras.applications.mobilenet import MobileNet
+        >>> model = MobileNet()
+
+        Create a random dataset generator:
+
+        >>> import numpy as np
+        >>> def repr_datagen(): return [np.random.random((1,224,224,3))]
+
+        Import mct and pass the model with the representative dataset generator to get a quantized model:
+
+        >>> import model_compression_toolkit as mct
+        >>> quantized_model, quantization_info = mct.keras_post_training_quantization(model, repr_datagen)
+
+    """
+
+    if quant_config.weights_bias_correction and gptq_config is not None:
+        common.Logger.error('weights_bias_correction should be disabled in GPTQ mode')
+
+    tb_w = _init_tensorboard_writer()
+
+    tg = _prepare_model_for_quantization(in_model,
+                                         representative_data_gen,
+                                         network_editor,
+                                         n_iter,
+                                         quant_config,
+                                         fw_info,
+                                         tb_w)
+
+    ######################################
+    # Finalize bit widths
+    ######################################
+    tg = set_bit_widths(quant_config,
+                        tg,
+                        fw_info)
+
+    quantized_model, user_info = _quantize_fixed_bit_widths_graph(analyze_similarity,
+                                                                  fw_info,
+                                                                  gptq_config,
+                                                                  representative_data_gen,
+                                                                  tb_w,
+                                                                  tg)
+
+    return quantized_model, user_info
+
+
+def keras_post_training_quantization_mixed_precision(in_model: Model,
+                                                     representative_data_gen: Callable,
+                                                     n_iter: int = 500,
+                                                     quant_config: MixedPrecisionQuantizationConfig = DEFAULT_MIXEDPRECISION_CONFIG,
+                                                     fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
+                                                     network_editor: List[EditRule] = [],
+                                                     gptq_config: GradientPTQConfig = None,
+                                                     bit_widths_config: List[int] = None,
+                                                     analyze_similarity: bool = False,
+                                                     target_kpi: KPI = None):
+    """
+    Quantize a trained Keras model using post-training quantization. The model is quantized using a
+    symmetric constraint quantization thresholds (power of two).
+    The model is first optimized using several transformations (e.g. BatchNormalization folding to
+    preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
+    being collected for each layer's output (and input, depends on the quantization configuration).
+    For each possible bit width (per layer) a threshold is then being calculated using the collected
+    statistics. Then, using an ILP solver we find a mixed-precision configuration, and set a bit width
+    for each layer. The model is then quantized (both coefficients and activations by default).
+    In order to limit the maximal model's size, a target KPI can be passed after weights_memory
+    is set (in bytes).
+    For now, mixed precision is supported for weights only.
+    If a gptq configuration is passed, the quantized weights are optimized using gradient based post
+    training quantization by comparing points between the float and quantized models, and minimizing the observed loss.
+    Notice that this feature is experimental.
+
+    Args:
+        in_model (Model): Keras model to quantize.
+        representative_data_gen (Callable): Dataset used for calibration.
+        n_iter (int): Number of calibration iterations to run.
+        quant_config (MixedPrecisionQuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized.
+        fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/keras/default_framework_info.py#L100>`_
+        network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
+        gptq_config (GradientPTQConfig): Configuration for using GPTQ (e.g. optimizer).
+        bit_widths_config (List[int]): Mixed-precision configuration to set bit widths for different layers.
+        analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
+        target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
+
+    Returns:
+        A quantized model and information the user may need to handle the quantized model.
+
+    Examples:
+        Import MCT:
+
+        >>> import model_compression_toolkit as mct
+
+        Import a Keras model:
+
+        >>> from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
+        >>> model = MobileNetV2()
+
+        Create a random dataset generator:
+
+        >>> import numpy as np
+        >>> def repr_datagen(): return [np.random.random((1,224,224,3))]
+
+        Create a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
+        Here, each layer can be quantized by 2, 4 or 8 bits:
+
+        >>> config = mct.MixedPrecisionQuantizationConfig(weights_n_bits=[4, 2, 8])
+
+        Create a KPI object to limit our returned model's size. Note that this value affects only coefficients that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value, while the bias will not):
+
+        >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
+
+        Pass the model, the representative dataset generator, the configuration and the target KPI to get a quantized model:
+
+        >>> quantized_model, quantization_info = mct.keras_post_training_quantization_mixed_precision(model, repr_datagen, n_iter=10, quant_config=config, target_kpi=kpi)
+
+        For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
+
+    """
+
+    if quant_config.weights_bias_correction and gptq_config is not None:
+        common.Logger.error('weights_bias_correction should be disabled in GPTQ mode')
+
+    common.Logger.info("Using experimental mixed-precision quantization. "
+                       "If you encounter an issue please file a bug.")
+
+    if target_kpi is None:
+        common.Logger.warning("No KPI was passed. Using non mixed-precision compression process...")
+        # Before starting non-mixed-precision process, we need to set only single bit width, so we take the best
+        # option which is the maximal number of bits.
+        quant_config.weights_n_bits = [max(quant_config.weights_n_bits)]
+        return keras_post_training_quantization(in_model,
+                                                representative_data_gen,
+                                                n_iter,
+                                                quant_config,
+                                                fw_info,
+                                                network_editor,
+                                                gptq_config,
+                                                analyze_similarity)
+
+    tb_w = _init_tensorboard_writer()
+
+    tg = _prepare_model_for_quantization(in_model,
+                                         representative_data_gen,
+                                         network_editor,
+                                         n_iter,
+                                         quant_config,
+                                         fw_info,
+                                         tb_w)
+
+    ######################################
+    # Finalize bit widths
+    ######################################
+
+    if bit_widths_config is None:
+        bit_widths_config = search_bit_width(tg,
+                                             quant_config,
+                                             fw_info,
+                                             target_kpi,
+                                             partial(get_sensitivity_evaluation,
+                                                     representative_data_gen=representative_data_gen,
+                                                     fw_info=fw_info))
+
+    tg = set_bit_widths(quant_config,
+                        tg,
+                        fw_info,
+                        bit_widths_config)
+
+    quantized_model, user_info = _quantize_fixed_bit_widths_graph(analyze_similarity,
+                                                                  fw_info,
+                                                                  gptq_config,
+                                                                  representative_data_gen,
+                                                                  tb_w,
+                                                                  tg)
+    user_info.mixed_precision_cfg = bit_widths_config
+
+    return quantized_model, user_info
+
+
+def _quantize_fixed_bit_widths_graph(analyze_similarity: bool,
+                                     fw_info: FrameworkInfo,
+                                     gptq_config: GradientPTQConfig,
+                                     representative_data_gen: Callable,
+                                     tb_w: TensorboardWriter,
+                                     tg: Graph) -> Tuple[Model, UserInformation]:
+    """
+    Quantize a graph that has final weights candidates quantization configurations.
+    Before we quantize the graph weights, we apply GPTQ to get an improved graph.
+
+    Args:
+        analyze_similarity: Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.)
+        gptq_config: Configuration for using GPTQ (e.g. optimizer).
+        representative_data_gen: Dataset used for GPTQ fine tuning.
+        tb_w: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
+        tg: Graph to apply GPTQ and to quantize.
+
+    Returns:
+        A tuple of the quantized model and an object of UserInformation.
+
+    """
+
+
+    #############################################
+    # Gradient Based Post Training Quantization
+    #############################################
+    tg = _apply_gptq(gptq_config,
+                     representative_data_gen,
+                     tb_w,
+                     tg,
+                     fw_info)
+
+    tg_float = copy.deepcopy(tg)  # Copy graph before quantization (for similarity analyzer)
+    ######################################
+    # Model Quantization
+    ######################################
+    quantized_model, user_info = _quantize_model(fw_info,
+                                                 tb_w,
+                                                 tg)
+    if analyze_similarity:
+        _analyze_similarity(representative_data_gen,
+                            tb_w,
+                            tg,
+                            tg_float)
+
+    return quantized_model, user_info
+
+
+def _init_tensorboard_writer() -> TensorboardWriter:
+    """
+
+    Returns: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
+
+    """
+    tb_w = None
+    if common.Logger.LOG_PATH is not None:
+        tb_log_dir = os.path.join(os.getcwd(), common.Logger.LOG_PATH, 'tensorboard_logs')
+        common.Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
+        tb_w = TensorboardWriter(tb_log_dir)
+    return tb_w
+
+
+def _quantize_model(fw_info: FrameworkInfo,
+                    tb_w: TensorboardWriter,
+                    tg: Graph) -> Tuple[Model, UserInformation]:
+    """
+    Quantize graph's weights, and build a quantized Keras model from it.
+
+    Args:
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
+        tb_w: TensorBoardWriter object to log events.
+        tg: A prepared for quantization graph.
+
+    Returns:
+        Quantized Keras model, and information the user may need to use the quantized model.
+    """
+
+    quantized_tg = quantize_graph_weights(tg, fw_info=fw_info)
+    if tb_w is not None:
+        tb_w.add_graph(quantized_tg, 'after_quantization')
+
+    quantized_graph_with_bias_correction = apply_bias_correction_to_graph(quantized_tg,
+                                                                          fw_info=fw_info)
+    if tb_w is not None:
+        tb_w.add_graph(quantized_graph_with_bias_correction, 'after_bias_correction')
+
+    ######################################
+    # Back2Framework
+    ######################################
+    quantized_model, user_info = model_builder(quantized_graph_with_bias_correction,
+                                               mode=ModelBuilderMode.QUANTIZED)
+
+    return quantized_model, user_info
+
+
+def _analyze_similarity(representative_data_gen: Callable,
+                        tb_w: TensorboardWriter,
+                        tg: Graph,
+                        tg_float: Graph):
+    """
+    Plot the cosine similarity of different points on the graph between the float and quantized
+    graphs. Add them to the passed TensorboardWriter object and close all tensorboard writer open
+    files.
+
+    Args:
+        representative_data_gen: Dataset used for calibration.
+        tb_w: TensorBoardWriter object to log events.
+        tg: Graph of quantized model.
+        tg_float: Graph of float model.
+
+    """
+    if tb_w is not None:
+        visual = KerasNNVisualizer(tg_float, tg)
+        for i in range(NUM_SAMPLES_CS_TENSORBOARD):
+            figure = visual.plot_cs_graph(representative_data_gen())
+            tb_w.add_figure(figure, f'cosine_similarity_sample_{i}')
+        tb_w.close()
+
+
+def _apply_gptq(gptq_config: GradientPTQConfig,
+                representative_data_gen: Callable,
+                tb_w: TensorboardWriter,
+                tg: Graph,
+                fw_info: FrameworkInfo) -> Graph:
+    """
+    Apply GPTQ to improve accuracy of quantized model.
+    Build two models from a graph: A teacher network (float model) and a student network (quantized model),
+    and use the dataset generator to pass images through the teacher and student networks to get intermediate
+    layers outputs and maximize their similarity.
+
+    Args:
+        gptq_config: Configuration for using GPTQ (e.g. optimizer).
+        representative_data_gen: Dataset used for calibration.
+        tb_w: TensorBoardWriter object to log events.
+        tg: Graph of quantized model.
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
+
+    Returns:
+
+    """
+    if gptq_config is not None:
+        common.Logger.info("Using experimental Gradient Based PTQ: If you encounter an issue "
+                           "please file a bug. To disable it, do not pass a gptq configuration.")
+
+        tg = gptq_training_wrapper(tg,
+                                   representative_data_gen,
+                                   gptq_config,
+                                   fw_info)
+
+        if tb_w is not None:
+            tb_w.add_graph(tg, 'after_gptq')
+    return tg