mct-nightly 1.8.0.20052023.post401-py3-none-any.whl → 1.8.0.20230610.post356-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/METADATA +10 -7
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/RECORD +68 -115
- model_compression_toolkit/__init__.py +23 -3
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +16 -9
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +8 -34
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +103 -28
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +39 -44
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py +20 -18
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +3 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +36 -9
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +24 -32
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +31 -8
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +5 -5
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +34 -8
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -16
- model_compression_toolkit/gptq/keras/graph_info.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -5
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +5 -7
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +6 -6
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -7
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +6 -6
- model_compression_toolkit/gptq/pytorch/gptq_training.py +30 -10
- model_compression_toolkit/gptq/pytorch/graph_info.py +5 -2
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +4 -2
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +4 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +5 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +7 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -8
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +7 -8
- model_compression_toolkit/qat/common/__init__.py +2 -1
- model_compression_toolkit/qat/common/qat_config.py +2 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +18 -8
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +1 -1
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +11 -11
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +11 -12
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +12 -13
- model_compression_toolkit/qat/pytorch/quantization_facade.py +27 -16
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +31 -4
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +10 -9
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +11 -10
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +2 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +1 -25
- model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py → trainable_infrastructure/__init__.py} +3 -10
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/base_trainable_quantizer.py +3 -3
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizer_config.py +1 -1
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizers.py +3 -3
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/base_keras_quantizer.py +4 -4
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/config_serialization.py +2 -2
- model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure}/keras/load_model.py +16 -23
- model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/pytorch/base_pytorch_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/__init__.py +0 -23
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +0 -87
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +0 -46
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +0 -31
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +0 -53
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +0 -49
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +0 -147
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +0 -345
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +0 -85
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +0 -27
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -148
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -65
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -86
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -111
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +0 -56
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -79
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -179
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -67
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -87
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -163
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +0 -66
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +0 -269
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +0 -152
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +0 -35
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -96
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -62
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -83
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -100
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +0 -95
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +0 -48
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +0 -70
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +0 -57
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +0 -26
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -77
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -106
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -66
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -104
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -109
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +0 -14
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/WHEEL +0 -0
- {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure/common}/__init__.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure → trainable_infrastructure/common}/constants.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/quant_utils.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/trainable_quantizer_config.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/common → trainable_infrastructure/keras}/__init__.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/quantizer_utils.py +0 -0
- /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras → trainable_infrastructure/pytorch}/__init__.py +0 -0
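The file list above shows this release's central refactor: the internal quantizers_infrastructure package is removed, its inferable quantizers now come from the external mct_quantizers package, and its trainable quantizers move to a new top-level trainable_infrastructure package. A minimal import-migration sketch, using only paths that appear in the diffs below (treat anything beyond them as an assumption):

    # Before (1.8.0.20052023.post401): everything lived under quantizers_infrastructure.
    # from model_compression_toolkit import quantizers_infrastructure as qi
    # from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model \
    #     import keras_load_quantized_model

    # After (1.8.0.20230610.post356): inferable pieces ship in mct_quantizers,
    # trainable pieces under model_compression_toolkit.trainable_infrastructure.
    from mct_quantizers import KerasQuantizationWrapper, KerasActivationQuantizationHolder
    from model_compression_toolkit.trainable_infrastructure.keras.load_model import \
        keras_load_quantized_model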
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
CHANGED

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 from abc import abstractmethod
+from functools import partial
 from typing import Tuple, Any, Dict, List, Union, Callable
 
 import torch
@@ -30,6 +31,7 @@ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder, BufferHolder
 from model_compression_toolkit.core.pytorch.utils import get_working_device
 from model_compression_toolkit.core.pytorch.constants import BUFFER
+from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER
 
 
 def _build_input_tensors_list(node: BaseNode,
@@ -66,7 +68,7 @@ def _run_operation(n: BaseNode,
                    input_tensors: List,
                    op_func: Any,
                    quantize_node_activation_fn,
-
+                   use_activation_quantization: bool) -> Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]]:
     """
     Applying the layer (op_func) to the input tensors (input_tensors).
     If quantized is set to True, and the layer's corresponding node (n) has quantization
@@ -77,7 +79,7 @@ def _run_operation(n: BaseNode,
         input_tensors: List of Pytorch tensors that are the layer's inputs.
         op_func: Module/functional to apply to the input tensors.
         quantize_node_activation_fn: quantization function
-
+        use_activation_quantization: Flag to indicate if we have an activation function.
     Returns:
         A tuple of Pytorch tensors. The Module/functional output tensors after applying the
         Module/functional to the input tensors.
@@ -92,10 +94,10 @@ def _run_operation(n: BaseNode,
 
     # Add a fake quant node if the node has an activation threshold.
     out_tensors_of_n = out_tensors_of_n_float
-    if
+    if use_activation_quantization:
         if isinstance(out_tensors_of_n_float, list):
             out_tensors_of_n_float = torch.cat(out_tensors_of_n_float, dim=0)
-        out_tensors_of_n = quantize_node_activation_fn(
+        out_tensors_of_n = quantize_node_activation_fn(out_tensors_of_n_float)
 
     return out_tensors_of_n, out_tensors_of_n_float
 
@@ -145,7 +147,8 @@ class PytorchModel(torch.nn.Module):
                  append2output: List[Any] = None,
                  fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                  return_float_outputs: bool = False,
-                 wrapper: Callable =
+                 wrapper: Callable = None,
+                 get_activation_quantizer_holder_fn: Callable = None):
         """
         Construct a Pytorch model.
 
@@ -155,17 +158,31 @@ class PytorchModel(torch.nn.Module):
             fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).
             return_float_outputs: Whether the model returns float tensors or not.
             wrapper: A function wrapper Pytorch Layers.
+            get_activation_quantizer_holder_fn: Function to retrieve a quantization holder for a node.
+
         """
         super(PytorchModel, self).__init__()
         self.graph = graph
         self.node_sort = list(topological_sort(graph))
-        self.
+        self.node_to_activation_quantization_holder = {}
         self.append2output = append2output
         self.return_float_outputs = return_float_outputs
         self.fw_info = fw_info
         self.wrapper = wrapper
+        self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn
         self._add_modules()
 
+    # todo: Move to parent class BaseModelBuilder
+    @property
+    def use_activation_holder_during_model_building(self) -> bool:
+        """
+        Returns: Whether or not the model builder uses a PytorchActivationQuantizationHolder during
+        model building (by adding it as a module when converting the graph to a Pytorch model).
+        If so - the model builder expects the activation quantizers not to be wrapped
+        in a PytorchQuantizeWrapper.
+        """
+        return self.get_activation_quantizer_holder is not None
+
     @abstractmethod
     def _quantize_node_activations(self,
                                    node: BaseNode,
@@ -184,18 +201,50 @@ class PytorchModel(torch.nn.Module):
         raise NotImplemented(f'{self.__class__.__name__} '
                              f'have to implement a method for quantization activation nodes.') # pragma: no cover
 
+    def wrap(self, node):
+        """
+        Wraps a node operation with a wrapper, if one is available.
+
+        Args:
+            node: node to wrap its operation.
+
+        Returns: the node's operation. If a wrapper is available, the operation is wrapped.
+        """
+        if isinstance(node, FunctionalNode):
+            if self.wrapper is None:
+                node_op = node.type
+            else:
+                node_op = self.wrapper(node, node.type)
+        else:
+            if self.wrapper is None or node.type == BufferHolder:
+                node_op = node_builder(node)
+            else:
+                node_op = self.wrapper(node, node_builder(node))
+        return node_op
+
     def _add_modules(self):
-
-
+        """
+        Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel
+        """
+        for node in self.node_sort:
+            node_op = self.wrap(node)
+            if isinstance(node, FunctionalNode):
                 # for functional layers
-                setattr(self,
+                setattr(self, node.name, node_op)
             else:
-
-
-                self.get_submodule(
-                register_buffer(
-
-
+                self.add_module(node.name, node_op)
+                if node.type == BufferHolder:
+                    self.get_submodule(node.name). \
+                        register_buffer(node.name,
+                                        torch.Tensor(node.get_weights_by_keys(BUFFER)).to(get_working_device()))
+
+            # Add activation quantization modules if an activation holder is configured for this node
+            if node.is_activation_quantization_enabled() and self.get_activation_quantizer_holder is not None:
+                activation_quantizer_holder = self.get_activation_quantizer_holder(node)
+                if activation_quantizer_holder is not None:
+                    self.add_module(node.name + '_' + ACTIVATION_HOLDER_QUANTIZER, activation_quantizer_holder)
+                    self.node_to_activation_quantization_holder.update(
+                        {node.name: node.name + '_' + ACTIVATION_HOLDER_QUANTIZER})
 
     def forward(self,
                 *args: Any) -> Any:
@@ -208,28 +257,28 @@ class PytorchModel(torch.nn.Module):
         node_to_output_tensors_dict = dict()
         node_to_output_tensors_dict_float = dict()
         configurable_nodes = self.graph.get_configurable_sorted_nodes_names()
-        for
-            input_tensors = _build_input_tensors_list(
+        for node in self.node_sort:
+            input_tensors = _build_input_tensors_list(node,
                                                       self.graph,
                                                       args,
                                                       node_to_output_tensors_dict)
 
-            op_func = self._get_op_func(
+            op_func = self._get_op_func(node, configurable_nodes)
+            use_activation_quantization, activation_quantization_fn = self._get_activation_quantization_fn(node)
 
             # Run node operation and fetch outputs
-            out_tensors_of_n, out_tensors_of_n_float = _run_operation(
+            out_tensors_of_n, out_tensors_of_n_float = _run_operation(node,
                                                                       input_tensors,
                                                                       op_func=op_func,
-                                                                      quantize_node_activation_fn=
-
+                                                                      quantize_node_activation_fn=activation_quantization_fn,
+                                                                      use_activation_quantization=use_activation_quantization)
 
             if isinstance(out_tensors_of_n, list):
-                node_to_output_tensors_dict.update({
-                node_to_output_tensors_dict_float.update({
+                node_to_output_tensors_dict.update({node: out_tensors_of_n})
+                node_to_output_tensors_dict_float.update({node: out_tensors_of_n_float})
             else:
-                node_to_output_tensors_dict.update({
-                node_to_output_tensors_dict_float.update({
-
+                node_to_output_tensors_dict.update({node: [out_tensors_of_n]})
+                node_to_output_tensors_dict_float.update({node: [out_tensors_of_n_float]})
 
         if self.append2output:
             outputs = _generate_outputs(self.append2output,
@@ -256,6 +305,28 @@ class PytorchModel(torch.nn.Module):
         """
         return getattr(self, node.name)
 
+    def _get_activation_quantization_fn(self, node) -> Tuple[bool, bool, Callable]:
+        """
+        Get activation quantization parameters for this node.
+
+        Args:
+            node: Node from which to extract the activation quantization parameters.
+
+        Returns: Flag to indicate if we quantize activations, flag to indicate if we quantize activations
+        using a quantization holder and a quantization function to use for the node's activations.
+        """
+        activation_quantization_holder = self.node_to_activation_quantization_holder.get(node.name)
+        use_activation_quantization = node.is_activation_quantization_enabled()
+        if use_activation_quantization:
+            if activation_quantization_holder is None:
+                activation_quantization_fn = partial(self._quantize_node_activations, node)
+                use_activation_quantization = self.wrapper is None
+            else:
+                activation_quantization_fn = getattr(self, activation_quantization_holder)
+        else:
+            activation_quantization_fn = None
+        return use_activation_quantization, activation_quantization_fn
+
 
 class PyTorchModelBuilder(BaseModelBuilder):
     """
@@ -267,7 +338,8 @@ class PyTorchModelBuilder(BaseModelBuilder):
                  append2output=None,
                  fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                  return_float_outputs: bool = False,
-                 wrapper: Callable =
+                 wrapper: Callable = None,
+                 get_activation_quantizer_holder_fn: Callable = None):
         """
 
         Args:
@@ -276,6 +348,7 @@ class PyTorchModelBuilder(BaseModelBuilder):
             fw_info: Information about the specific framework of the model that is built.
             return_float_outputs: Whether the model returns float tensors or not.
             wrapper: A function wrapper Pytorch Layers.
+            get_activation_quantizer_holder_fn: Function to retrieve a quantization holder for a node.
         """
 
         super().__init__(graph,
@@ -284,6 +357,7 @@ class PyTorchModelBuilder(BaseModelBuilder):
                          return_float_outputs)
 
         self.wrapper = wrapper
+        self.get_activation_quantizer_holder_fn = get_activation_quantizer_holder_fn
 
     def build_model(self) -> Tuple[PytorchModel, UserInformation]:
         """
@@ -294,4 +368,5 @@ class PyTorchModelBuilder(BaseModelBuilder):
         return PytorchModel(self.graph,
                             self.append2output,
                             return_float_outputs=self.return_float_outputs,
-                            wrapper=self.wrapper
+                            wrapper=self.wrapper,
+                            get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder_fn), self.graph.user_info
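These changes let the PyTorch builder attach activation quantization as standalone holder modules (registered under node.name + '_' + ACTIVATION_HOLDER_QUANTIZER) instead of folding it into a wrapped layer. A hedged sketch of the two callbacks a caller can supply; the quantizer-lookup helpers are hypothetical placeholders, and the mct_quantizers constructor signatures are assumptions:

    from mct_quantizers import (PytorchQuantizationWrapper,
                                PytorchActivationQuantizationHolder)

    def my_wrapper(node, module):
        # Hypothetical helper: map a graph node to its weights-quantizers dict.
        weights_quantizers = lookup_weights_quantizers(node)
        if weights_quantizers:
            return PytorchQuantizationWrapper(module, weights_quantizers)
        return module

    def my_activation_holder(node):
        # Hypothetical helper: map a graph node to a single activation quantizer.
        return PytorchActivationQuantizationHolder(lookup_activation_quantizer(node))

    model, user_info = PyTorchModelBuilder(
        graph,
        wrapper=my_wrapper,
        get_activation_quantizer_holder_fn=my_activation_holder).build_model()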
model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py
CHANGED

@@ -14,6 +14,7 @@
 # ==============================================================================
 from typing import Dict, Callable
 
+import keras
 import keras.models
 import keras.models
 import tensorflow as tf
@@ -22,9 +23,9 @@ from keras.engine.base_layer import Layer
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.exporter.model_exporter.keras.base_keras_exporter import \
     BaseKerasExporter
-from
-
+from mct_quantizers import KerasQuantizationWrapper
 
+layers = keras.layers
 
 class FakelyQuantKerasExporter(BaseKerasExporter):
     """
@@ -69,51 +70,45 @@ class FakelyQuantKerasExporter(BaseKerasExporter):
                 Layer after unwrapping.
 
             """
-
+
+            # Assert each layer is exportable
+            self.is_layer_exportable_fn(layer)
 
             # If weights are quantized, use the quantized weight for the new built layer.
-            if layer
-            [... 31 removed lines whose content was not captured in this view ...]
-            if layer.is_activation_quantization:
-                new_layer = KerasQuantizationWrapper(layer=new_layer,
-                                                     activation_quantizers=layer.activation_quantizers)
-
-                return new_layer
-
-            # If this is a layer with activation quantization only, just return it
-            # as activation quantization in the fake-quant case uses the wrapper for quantization.
-            return layer
+            if isinstance(layer, KerasQuantizationWrapper):
+                if layer.is_weights_quantization:
+                    new_layer = layer.layer.__class__.from_config(layer.layer.get_config())
+
+                    # Build a list of the layer's new weights.
+                    weights_list = []
+
+                    # Create a list of weights for the new created layer
+                    if isinstance(layer.layer, layers.DepthwiseConv2D):
+                        weights_list.append(layer.get_quantized_weights()['depthwise_kernel'])
+                    elif isinstance(layer.layer, (layers.Conv2D, layers.Dense, layers.Conv2DTranspose)):
+                        weights_list.append(layer.get_quantized_weights()['kernel'])
+                    else:
+                        Logger.error(f'KerasQuantizationWrapper should wrap only DepthwiseConv2D, Conv2D, Dense'
+                                     f' and Conv2DTranspose layers but wrapped layer is {layer.layer}')
+
+                    if layer.layer.bias is not None:
+                        weights_list.append(layer.layer.bias)
+
+                    # In order to add the weights of the layer, we need to build it. To build it
+                    # we need to pass its input shape. Not every layer has input_shape since some
+                    # layers may have multiple inputs with different input shapes (reused layers for
+                    # example). For this reason, we take input shape at index 0 (any input shape
+                    # should work since the weights are dependent only at some dimensions which have to
+                    # be the same for all inputs).
+                    with tf.name_scope(new_layer.name):
+                        new_layer.build(layer.get_input_shape_at(0))
+
+                    new_layer.set_weights(weights_list)
+                    new_layer.trainable = False
+
+                    return new_layer
 
+            return layer
 
         # clone each layer in the model and apply _unwrap_quantize_wrapper to layers wrapped with a QuantizeWrapper.
         self.exported_model = tf.keras.models.clone_model(self.model,
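The unwrap pass plugs into tf.keras.models.clone_model, whose clone_function runs on every layer; the diff above is cut off at that call. A minimal sketch of the clone step under the same assumptions (a model whose weight-quantized layers are KerasQuantizationWrapper instances; unwrap_quantize_wrapper is a hypothetical stand-in for the method shown above):

    import tensorflow as tf
    from mct_quantizers import KerasQuantizationWrapper

    def _clone_fn(layer):
        if isinstance(layer, KerasQuantizationWrapper):
            # Rebuild the inner layer from its config and load the quantized
            # weights, as _unwrap_quantize_wrapper does above.
            return unwrap_quantize_wrapper(layer)   # hypothetical stand-in
        return layer

    exported_model = tf.keras.models.clone_model(wrapped_model,
                                                 clone_function=_clone_fn)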
model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py
CHANGED

@@ -19,9 +19,9 @@ from typing import Callable
 import keras.models
 import tensorflow as tf
 
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
+from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
 
 class FakelyQuantTFLiteExporter(FakelyQuantKerasExporter):
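The next file, int8_tflite_exporter.py, substitutes each Dense layer with an equivalent point-wise convolution so per-channel weight quantization survives the int8 TFLite conversion; the input is reshaped to rank 4 before the convolution and back afterwards. The target_shape arithmetic from that diff, checked with concrete numbers (the shapes here are illustrative):

    import numpy as np

    input_shape = (None, 5, 4, 20)      # wrapped_layer.input_shape: (batch, x0, x1, features)
    dense_kernel_shape = (20, 3)        # (input features, units)

    dim = input_shape[1:-1]             # (5, 4)
    target_shape = (1, int(np.prod(dim))) + (dense_kernel_shape[0],)
    print(target_shape)                 # (1, 20, 20): the conv now sees a 4-rank input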
model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py
CHANGED

@@ -22,11 +22,9 @@ from keras import Sequential
 from keras.layers import Dense, Conv2D, Reshape
 from keras.models import clone_model
 
-from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
-from
-    constants as keras_inferable_constants
+from mct_quantizers import constants as keras_inferable_constants, KerasQuantizationWrapper
 
 BIAS_INITIALIZER = 'bias_initializer'
 BIAS_REGULARIZER = 'bias_regularizer'
@@ -50,6 +48,7 @@ KERNEL = 'kernel'
 CONV_KERNEL_CHANNEL_AXIS = 3
 CONV_INPUT_CHANNELS_DIM = 4
 
+
 class INT8TFLiteExporter(FakelyQuantKerasExporter):
     """
     Exporter for INT8 TFLite models.
@@ -75,7 +74,7 @@ class INT8TFLiteExporter(FakelyQuantKerasExporter):
 
         self.exported_model = None
 
-    def _get_pointwise_layer_to_replace_dense(self, wrapped_layer:
+    def _get_pointwise_layer_to_replace_dense(self, wrapped_layer: KerasQuantizationWrapper) -> keras.layers.Layer:
         # First we create a pointwise configuration based on the Dense layer's configuration
         dense_cfg = wrapped_layer.layer.get_config()
 
@@ -94,7 +93,7 @@ class INT8TFLiteExporter(FakelyQuantKerasExporter):
 
         # Create the point-wise layer
         pw_layer = Conv2D(**pw_cfg)
-        pw_layer.build(wrapped_layer.
+        pw_layer.build(wrapped_layer.input_shape)
 
         # Create and set the point-wise weights to assign
         dense_kernel = wrapped_layer.layer.kernel
@@ -110,7 +109,7 @@ class INT8TFLiteExporter(FakelyQuantKerasExporter):
         pw_layer.set_weights(pw_weights)
 
         # Now that we have the point-wise to replace the dense layer,
-        # we need to wrap it using
+        # we need to wrap it using KerasQuantizationWrapper with a new
         # relevant quantizers.
         # Create new kernel quantizer
         pw_kernel_quantizer_cfg = wrapped_layer.weights_quantizers[KERNEL].get_config()
@@ -121,8 +120,10 @@ class INT8TFLiteExporter(FakelyQuantKerasExporter):
         # Unquantized weight to conv layer has 4 dimensions (unlike dense which varies)
         pw_kernel_quantizer_cfg[keras_inferable_constants.INPUT_RANK] = CONV_INPUT_CHANNELS_DIM
 
-        assert isinstance(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD],
-
+        assert isinstance(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD],
+                          np.ndarray), f'Expected to find threshold which is a numpy array, but found: {type(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD])}'
+        pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD] = list(
+            pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD])
 
         # Now that we have the point-wise quantizer we can instantiate it
         quantizer_class = type(wrapped_layer.weights_quantizers[KERNEL])
@@ -131,21 +132,21 @@ class INT8TFLiteExporter(FakelyQuantKerasExporter):
         pw_weights_quantizers[KERNEL] = pw_quantizer
 
         # Wrap pw with the new quantizers (the activation is not affected thus we take the Dense quantizers)
-        wrapped_pw =
-
-
+        wrapped_pw = KerasQuantizationWrapper(pw_layer,
+                                              pw_weights_quantizers,
+                                              wrapped_layer.activation_quantizers)
 
         # Compute the shape that the input to the new layer should be reshaped into
         # Example: Dense kernel with the following shape (3, 20) expects to have input with the
         # next dimensions (BATCH_SIZE, x0, x1, ..., xn, 20).
         # Conv layer expects 4-rank input. Thus, the input is reshaped to (BATCH_SIZE, 1, x0*x1*...*xn, 20)
-        dim = wrapped_layer.
+        dim = wrapped_layer.input_shape[1:-1]
         target_shape = (1, int(np.prod(dim))) + (dense_kernel.get_shape()[0],)
 
         return Sequential([
             Reshape(target_shape=target_shape),
             wrapped_pw,
-            Reshape(wrapped_layer.
+            Reshape(wrapped_layer.output_shape[1:])
         ])
 
     def export(self) -> None:
@@ -153,17 +154,18 @@ class INT8TFLiteExporter(FakelyQuantKerasExporter):
         Export a fully quantized model to its int8 tflite model.
         """
 
-        def _substitute_model(
+        def _substitute_model(layer_to_substitue: keras.layers.Layer) -> keras.layers.Layer:
             assert self.is_layer_exportable_fn(
-
+                layer_to_substitue), f'Layer {layer_to_substitue.get_config()} did not pass validation'
 
             # In order to support dense quantization using per-channel quantization (which is
             # unsupported in TFLITE int models) we substitute each dense layer to its equivalent
             # point-wise convolution.
-            if isinstance(
-
+            if isinstance(layer_to_substitue, KerasQuantizationWrapper):
+                if isinstance(layer_to_substitue.layer, Dense):
+                    return self._get_pointwise_layer_to_replace_dense(layer_to_substitue)
 
-            return
+            return layer_to_substitue
 
         # Transform the model to a new model that can be converted to int8 models.
         # For example: replace dense layers with point-wise layers (to support per-channel quantization)
model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py
CHANGED

@@ -21,8 +21,8 @@ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
 from packaging import version
 
-from
-from
+from mct_quantizers import PytorchQuantizationWrapper
+from mct_quantizers.common.constants import LAYER
 
 # ONNX opset version 16 is supported from PyTorch 1.12
 if version.parse(torch.__version__) < version.parse("1.12"):
@@ -68,7 +68,7 @@ class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
             Fake-quant PyTorch model.
         """
         for layer in self.model.children():
-
+            self.is_layer_exportable_fn(layer)
 
         model_input = to_torch_tensor(next(self.repr_dataset())[0])
 
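Both PyTorch export paths validate every child module with is_layer_exportable_fn and then serialize the fake-quant model; the TorchScript variant appears in the next diff. A sketch of the two serialization calls with a toy model (the fallback opset value of 15 and the input shape are assumptions; only the "opset 16 requires PyTorch >= 1.12" rule comes from the code above):

    import torch
    from packaging import version

    # Opset gate, as in fakely_quant_onnx_pytorch_exporter.py (fallback value assumed).
    OPSET_VERSION = 15 if version.parse(torch.__version__) < version.parse("1.12") else 16

    def repr_dataset():
        yield [torch.randn(1, 4)]        # one representative batch (shape illustrative)

    model = torch.nn.Linear(4, 2)
    model_input = next(repr_dataset())[0]

    # ONNX path.
    torch.onnx.export(model, model_input, 'model.onnx', opset_version=OPSET_VERSION)

    # TorchScript path (see the following diff).
    torch_traced = torch.jit.trace(model, (model_input,))
    torch.jit.save(torch_traced, 'model.pt')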
model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py
CHANGED

@@ -57,7 +57,7 @@ class FakelyQuantTorchScriptPyTorchExporter(BasePyTorchExporter):
             Fake-quant PyTorch model.
         """
         for layer in self.model.children():
-
+            self.is_layer_exportable_fn(layer)
 
         torch_traced = torch.jit.trace(self.model,
                                        to_torch_tensor(next(self.repr_dataset())),

model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
CHANGED

@@ -12,36 +12,62 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Tuple
 
-
-from model_compression_toolkit import quantizers_infrastructure as qi
+from typing import Tuple, Callable
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.constants import FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.logger import Logger
+from mct_quantizers import KerasActivationQuantizationHolder
 
 if FOUND_TF:
     import tensorflow as tf
     from tensorflow.keras.layers import Layer
     from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
     from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizers import get_quantization_quantizers
+    from mct_quantizers import KerasQuantizationWrapper
 
     def _get_wrapper(node: common.BaseNode,
-                     layer: Layer) ->
+                     layer: Layer) -> Layer:
        """
        A function which takes a computational graph node and a keras layer and perform the quantization wrapping
        Args:
-
+            node: A node of mct graph.
            layer: A keras layer
-            include_activation_quantizers: Whether to use the wrapper for the activation quantizer or not
 
        Returns: Wrapped layer with weights quantizers and activation quantizers
 
        """
-        weights_quantizers,
-
+        weights_quantizers, _ = get_quantization_quantizers(node)
+        if len(weights_quantizers) > 0:
+            return KerasQuantizationWrapper(layer,
+                                            weights_quantizers)
+        return layer
+
+
+    def get_activation_quantizer_holder(node: common.BaseNode) -> Callable:
+        """
+        Retrieve a ActivationQuantizationHolder layer to use for activation quantization for a node.
+
+        Args:
+            node: Node to get ActivationQuantizationHolder to attach in its output.
+
+        Returns:
+            A ActivationQuantizationHolder layer for the node activation quantization.
+        """
+        _, activation_quantizers = get_quantization_quantizers(node)
+
+        # Holder by definition uses a single quantizer for the activation quantization
+        # thus we make sure this is the only possible case (unless it's a node with no activation
+        # quantization, which in this case has an empty list).
+        if len(activation_quantizers) == 1:
+            return KerasActivationQuantizationHolder(activation_quantizers[0])
+
+        Logger.error(
+            f'ActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
+            f'were found for node {node}')
 
 
     def get_exportable_keras_model(graph: Graph) -> Tuple[tf.keras.models.Model, UserInformation]:
@@ -57,7 +83,8 @@ if FOUND_TF:
             Exportable Keras model and user information.
         """
         exportable_model, user_info = KerasModelBuilder(graph=graph,
-                                                        wrapper=_get_wrapper
+                                                        wrapper=_get_wrapper,
+                                                        get_activation_quantizer_holder_fn=get_activation_quantizer_holder).build_model()
         exportable_model.trainable = False
         return exportable_model, user_info
 else:
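The net effect of get_exportable_keras_model is a model in which weight-quantized layers are wrapped by KerasQuantizationWrapper and each quantized activation runs through its own KerasActivationQuantizationHolder layer. A hedged sketch of the holder in isolation; the quantizer class name, its import path, and its constructor arguments are assumptions inferred from the quantizer modules listed at the top of this diff:

    import numpy as np
    from mct_quantizers import KerasActivationQuantizationHolder
    # Assumed class/location; mirrors the removed activation_pot_inferable_quantizer.py module.
    from mct_quantizers.keras.quantizers import ActivationPOTInferableQuantizer

    quantizer = ActivationPOTInferableQuantizer(num_bits=8,
                                                threshold=[4.0],   # power-of-two threshold
                                                signed=True)
    holder = KerasActivationQuantizationHolder(quantizer)

    # The holder acts as a regular Keras layer that fake-quantizes its input.
    y = holder(np.random.randn(1, 16).astype('float32'))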
model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py
CHANGED

@@ -19,10 +19,10 @@ from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX,
 
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
-from
-from
-from
-from
+from mct_quantizers import QuantizationTarget
+from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
+from mct_quantizers.keras.quantizers import BaseKerasInferableQuantizer
+from mct_quantizers import constants as qi_keras_consts
 
 def get_inferable_quantizer_kwargs(node: BaseNode,
                                    quantization_target: QuantizationTarget) -> Dict[str, Any]: