mct-nightly 0.0.0__py3-none-any.whl → 1.1.0.02122021-003117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/METADATA +3 -2
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/RECORD +31 -38
- model_compression_toolkit/__init__.py +2 -6
- model_compression_toolkit/common/base_substitutions.py +1 -0
- model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +9 -12
- model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +8 -21
- model_compression_toolkit/common/collectors/histogram_collector.py +1 -1
- model_compression_toolkit/common/graph/base_graph.py +2 -4
- model_compression_toolkit/common/graph/graph_matchers.py +3 -1
- model_compression_toolkit/common/graph/graph_searches.py +3 -1
- model_compression_toolkit/common/mixed_precision/bit_width_setter.py +1 -2
- model_compression_toolkit/common/network_editors/node_filters.py +1 -0
- model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py +1 -1
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +3 -5
- model_compression_toolkit/common/quantization/quantize_graph_weights.py +4 -7
- model_compression_toolkit/common/quantization/quantize_node.py +3 -5
- model_compression_toolkit/keras/__init__.py +2 -0
- model_compression_toolkit/keras/back2framework/model_builder.py +24 -1
- model_compression_toolkit/{common → keras/back2framework}/model_collector.py +9 -18
- model_compression_toolkit/keras/default_framework_info.py +0 -1
- model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +57 -10
- model_compression_toolkit/keras/graph_substitutions/substituter.py +171 -0
- model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +26 -6
- model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +12 -5
- model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +3 -4
- model_compression_toolkit/keras/quantization_facade.py +524 -188
- model_compression_toolkit/keras/reader/connectivity_handler.py +4 -1
- model_compression_toolkit/keras/visualization/nn_visualizer.py +1 -2
- model_compression_toolkit/common/framework_implementation.py +0 -239
- model_compression_toolkit/common/gptq/__init__.py +0 -14
- model_compression_toolkit/common/gptq/gptq_config.py +0 -65
- model_compression_toolkit/common/model_builder_mode.py +0 -34
- model_compression_toolkit/common/post_training_quantization.py +0 -459
- model_compression_toolkit/common/substitutions/__init__.py +0 -14
- model_compression_toolkit/common/substitutions/apply_substitutions.py +0 -40
- model_compression_toolkit/keras/keras_implementation.py +0 -256
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/LICENSE +0 -0
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/WHEEL +0 -0
- {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/top_level.txt +0 -0
|
@@ -14,12 +14,13 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
|
|
17
|
+
from collections import namedtuple
|
|
18
|
+
|
|
17
19
|
import tensorflow as tf
|
|
18
20
|
from tensorflow.python.keras.engine.node import Node as KerasNode
|
|
19
21
|
from tensorflow.python.util.object_identity import Reference as TFReference
|
|
20
22
|
from typing import List, Tuple
|
|
21
23
|
|
|
22
|
-
from model_compression_toolkit.common.graph.base_graph import OutTensor
|
|
23
24
|
from model_compression_toolkit.common.graph.node import Node
|
|
24
25
|
from model_compression_toolkit.keras.reader.common import is_node_an_input_layer
|
|
25
26
|
from model_compression_toolkit.keras.reader.node_builder import build_node
|
|
@@ -27,6 +28,8 @@ from model_compression_toolkit.keras.reader.node_builder import build_node
|
|
|
27
28
|
keras = tf.keras
|
|
28
29
|
layers = keras.layers
|
|
29
30
|
|
|
31
|
+
OutTensor = namedtuple('OutTensor', 'node node_out_index')
|
|
32
|
+
|
|
30
33
|
|
|
31
34
|
class ConnectivityHandler(object):
|
|
32
35
|
"""
|
|
@@ -20,8 +20,7 @@ from matplotlib.figure import Figure
|
|
|
20
20
|
|
|
21
21
|
from model_compression_toolkit.common import Graph
|
|
22
22
|
from model_compression_toolkit.common.similarity_analyzer import compute_cs
|
|
23
|
-
from model_compression_toolkit.keras.back2framework.model_builder import model_builder
|
|
24
|
-
from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
|
|
23
|
+
from model_compression_toolkit.keras.back2framework.model_builder import model_builder, ModelBuilderMode
|
|
25
24
|
from model_compression_toolkit.common.graph.node import Node
|
|
26
25
|
|
|
27
26
|
|
|
@@ -1,239 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
from abc import ABC, abstractmethod
|
|
16
|
-
from typing import Callable, Any, List, Tuple
|
|
17
|
-
|
|
18
|
-
import numpy as np
|
|
19
|
-
|
|
20
|
-
from model_compression_toolkit import common, GradientPTQConfig, MixedPrecisionQuantizationConfig
|
|
21
|
-
from model_compression_toolkit.common import Node
|
|
22
|
-
from model_compression_toolkit.common.framework_info import FrameworkInfo
|
|
23
|
-
from model_compression_toolkit.common.graph.base_graph import Graph
|
|
24
|
-
from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
|
|
25
|
-
from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
|
|
26
|
-
from model_compression_toolkit.common.user_info import UserInformation
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class FrameworkImplementation(ABC):
    """
    An abstract class with abstract methods that should be implemented when supporting a new
    framework in MCT.

    Every abstract method raises NotImplementedError if its base implementation is reached
    (for example through a super() call from a subclass).
    """

    @property
    def constants(self):
        """

        Returns: Module of the framework constants.

        """
        raise Exception(f'{self.__class__.__name__} did not supply a constants module.')

    @abstractmethod
    def model_builder(self,
                      graph: Graph,
                      mode: ModelBuilderMode,
                      append2output: List[Any],
                      fw_info: FrameworkInfo) -> Tuple[Any, UserInformation]:
        """
        Build a framework model from a graph.
        The mode determines how the model should be built. append2output is a list of Nodes
        to set as the model outputs.

        Args:
            graph: Graph to build the model from it.
            mode: Mode for how to build the model.
            append2output: List of Nodes to set as the model's outputs.
            fw_info: FrameworkInfo object with information about the specific framework's model

        Returns:
            A tuple of the model that was built and an UserInformation object.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # NotImplemented is a non-callable constant; NotImplementedError is the exception type.
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s model_builder method.')

    @abstractmethod
    def shift_negative_correction(self,
                                  graph: Graph,
                                  qc: QuantizationConfig,
                                  fw_info: FrameworkInfo) -> Graph:
        """
        Apply shift negative correction (SNC) on a graph.

        Args:
            graph: Graph to apply SNC on.
            qc: Quantization configuration.
            fw_info: FrameworkInfo object with information about the specific framework's model.

        Returns:
            Graph after SNC.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s shift_negative_correction method.')

    @abstractmethod
    def to_numpy(self, tensor: Any) -> np.ndarray:
        """
        Convert framework's tensor to a Numpy array.

        Args:
            tensor: Framework's tensor.

        Returns:
            Numpy array converted from the input tensor.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s to_numpy method.')

    @abstractmethod
    def model_reader(self, model: Any) -> Graph:
        """
        Convert a framework's model into a graph.

        Args:
            model: Framework's model.

        Returns:
            Graph representing the input model.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s model_reader method.')

    @abstractmethod
    def attach_sc_to_node(self, node: Node,
                          fw_info: FrameworkInfo) -> common.statistics_collector.BaseStatsContainer:
        """
        Return a statistics collector that should be attached to a node's output
        during statistics collection.

        Args:
            node: Node to return its collector.
            fw_info: FrameworkInfo object with information about the specific framework's model

        Returns:
            Statistics collector for the node.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s attach_sc_to_node method.')

    @abstractmethod
    def get_substitutions_marking(self) -> List[common.BaseSubstitution]:
        """

        Returns: A list of the framework substitutions used for marking
        points we fuse.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s get_substitutions_marking method.')

    @abstractmethod
    def get_substitutions_pre_statistics_collection(self) -> List[common.BaseSubstitution]:
        """

        Returns: A list of the framework substitutions used before we collect statistics.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s get_substitutions_pre_statistics_collection method.')

    @abstractmethod
    def get_substitutions_pre_build(self) -> List[common.BaseSubstitution]:
        """

        Returns: A list of the framework substitutions used before we build a quantized model.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s get_substitutions_pre_build method.')

    @abstractmethod
    def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
        """
        Return a list of the framework substitutions used after we collect statistics.

        Args:
            quant_config: QuantizationConfig to determine which substitutions to return.

        Returns:
            A list of the framework substitutions used after we collect statistics.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s get_substitutions_post_statistics_collection method.')

    @abstractmethod
    def get_substitutions_channel_equalization(self,
                                               quant_config: QuantizationConfig,
                                               fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
        """
        Return a list of the framework substitutions used for channel equalization.

        Args:
            quant_config: QuantizationConfig to determine which substitutions to return.
            fw_info: FrameworkInfo object with information about the specific framework's model.

        Returns:
            A list of the framework substitutions used for channel equalization.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s get_substitutions_channel_equalization method.')

    @abstractmethod
    def gptq_training(self,
                      graph: Graph,
                      representative_data_gen: Callable,
                      gptq_config: GradientPTQConfig,
                      fw_info: FrameworkInfo) -> Graph:
        """
        Update a graph using GPTQ after minimizing the loss between the float model's output
        and the quantized model's outputs.

        Args:
            graph: Graph to fine-tune.
            representative_data_gen: Dataset to use for inputs of the models.
            gptq_config: GradientPTQConfig with configuration for the fine-tuning process.
            fw_info: FrameworkInfo object with information about the specific framework's model.

        Returns:
            Updated graph after GPTQ.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """

        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s gptq_training method.')

    @abstractmethod
    def get_sensitivity_evaluation_fn(self,
                                      graph: Graph,
                                      quant_config: MixedPrecisionQuantizationConfig,
                                      metrics_weights: np.ndarray,
                                      representative_data_gen: Callable,
                                      fw_info: FrameworkInfo) -> Callable:
        """
        Create and return a function to compute a sensitivity metric for a mixed-precision
        configuration (comparing to the float model).

        Args:
            graph: Graph to build its float and mixed-precision models.
            quant_config: QuantizationConfig of how the model should be quantized.
            metrics_weights: Array of weights to weight the sensitivity among different layers.
            representative_data_gen: Dataset to use for retrieving images for the models inputs.
            fw_info: FrameworkInfo object with information about the specific framework's model.

        Returns:
            A function that computes the metric.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s get_sensitivity_evaluation_fn method.')
|
|
238
|
-
|
|
239
|
-
|
|
@@ -1,14 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
@@ -1,65 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from typing import Callable, Any
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class GradientPTQConfig:
    """
    Configuration to use for quantization with GradientPTQ (experimental).
    """

    def __init__(self,
                 n_iter: int,
                 optimizer: Any,
                 loss: Callable = None,
                 log_function: Callable = None,
                 train_bias: bool = True,
                 representative_data_gen: Callable = None):
        """
        Initialize a GradientPTQConfig.

        Args:
            n_iter (int): Number of iterations to train.
            optimizer (OptimizerV2): Optimizer to use. This argument is required.
            loss (Callable): the loss to use. should accept 2 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors.
            log_function (Callable): Function to log information about the GPTQ process.
            train_bias (bool): Whether to update the bias during the training or not.
            representative_data_gen (Callable): Dataset generator.

        Examples:
            Create a GradientPTQConfig to run for 5 iterations, using an Adam optimizer and a random dataset generator:

            >>> import numpy as np
            >>> import tensorflow as tf
            >>> def repr_datagen(): return [np.random.random((1,224,224,3))]
            >>> gptq_conf = GradientPTQConfig(n_iter=5, optimizer=tf.keras.optimizers.Adam(), representative_data_gen=repr_datagen)

            A differently configured optimizer can be passed:

            >>> gptq_conf = GradientPTQConfig(n_iter=5, representative_data_gen=repr_datagen, optimizer=tf.keras.optimizers.Nadam(learning_rate=0.2))

            To disable the biases training, one may set train_bias to False (enabled by default):

            >>> gptq_conf = GradientPTQConfig(n_iter=5, optimizer=tf.keras.optimizers.Adam(), representative_data_gen=repr_datagen, train_bias=False)

            The configuration can then be passed to :func:`~model_compression_toolkit.keras_post_training_quantization`.

        """
        # The configuration is a plain value holder: every argument is stored unchanged.
        self.n_iter = n_iter
        self.optimizer = optimizer
        self.loss = loss
        self.log_function = log_function
        self.train_bias = train_bias
        self.representative_data_gen = representative_data_gen
|
|
@@ -1,34 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from enum import Enum
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class ModelBuilderMode(Enum):
    """
    Selects how a model is rebuilt from a graph.

    Members:
        FLOAT: Model for statistics collection; its outputs list contains all
            output tensors of all nodes in the graph.
        QUANTIZED: Quantized model, where the nodes' quantization attributes are
            used for adding quantization nodes to the model.
        GPTQ: Quantized model, where the nodes' quantization attributes are used
            for wrapping layers with QuantizeWrapper and output comparing points.
        MIXEDPRECISION: Quantized model in which layers whose weights should be
            quantized can use different quantized weights, according to the
            possible bitwidths of each layer.
    """

    FLOAT = 0
    QUANTIZED = 1
    GPTQ = 2
    MIXEDPRECISION = 3
|