ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1041 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Tests for params_generator."""
|
17
|
+
|
18
|
+
from collections.abc import Generator
|
19
|
+
import os
|
20
|
+
from typing import Any
|
21
|
+
|
22
|
+
from absl.testing import parameterized
|
23
|
+
import numpy as np
|
24
|
+
|
25
|
+
from tensorflow.python.platform import googletest
|
26
|
+
from ai_edge_quantizer import calibrator
|
27
|
+
from ai_edge_quantizer import params_generator
|
28
|
+
from ai_edge_quantizer import qtyping
|
29
|
+
from ai_edge_quantizer import recipe_manager
|
30
|
+
from ai_edge_quantizer.utils import test_utils
|
31
|
+
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
32
|
+
from ai_edge_quantizer.utils import tfl_interpreter_utils
|
33
|
+
|
34
|
+
|
35
|
+
# Short aliases for the quantizer types referenced throughout these tests.
_AlgorithmName = recipe_manager.AlgorithmName
_ComputePrecision = qtyping.ComputePrecision
_QuantGranularity = qtyping.QuantGranularity
_QuantTransformation = qtyping.QuantTransformation
_TensorDataType = qtyping.TensorDataType
_TensorQuantConfig = qtyping.TensorQuantizationConfig

# Root directory that the 'tests/models/...' paths below are joined onto.
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')
|
43
|
+
|
44
|
+
|
45
|
+
def _single_fc_model_representative_dataset_gen(num_samples=5):
  """Yields random float32 calibration samples keyed by 'input_1'.

  Args:
    num_samples: Number of samples to produce.

  Yields:
    A dict mapping 'input_1' to a float32 array of shape (1, 8) whose values
    are drawn uniformly from [0, 1).
  """
  input_shape = (1, 8)
  yield from (
      {'input_1': np.random.rand(*input_shape).astype(np.float32)}
      for _ in range(num_samples)
  )
|
48
|
+
|
49
|
+
|
50
|
+
def _int_transpose_model_representative_dataset_gen(num_samples=5):
  """Builds random int32 calibration samples keyed by 'input_2'.

  `np.random.rand` draws from [0, 1), so casting its output to int32 always
  produced all-zero tensors. Integers are sampled directly instead so that
  the calibration data is not degenerate.

  Args:
    num_samples: Number of samples to produce.

  Returns:
    A list of dicts, each mapping 'input_2' to an int32 array of shape
    (1, 2, 3, 4).
  """
  input_shape = (1, 2, 3, 4)
  return [
      {
          'input_2': np.random.randint(
              -128, 128, size=input_shape, dtype=np.int32
          )
      }
      for _ in range(num_samples)
  ]
|
55
|
+
|
56
|
+
|
57
|
+
def _get_calibration_data(
    dataset_gen: Generator[dict[str, Any], Any, None],
) -> dict[str, Any]:
  """Collects every sample of a dataset under the default signature key.

  Args:
    dataset_gen: Iterable of samples; it is fully consumed by this call.

  Returns:
    A single-entry dict that maps
    `tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY` to the list of all samples.
  """
  return {tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: list(dataset_gen)}
|
65
|
+
|
66
|
+
|
67
|
+
class ParamsGeneratorTest(parameterized.TestCase):
|
68
|
+
|
69
|
+
def setUp(self):
  """Creates the recipe manager and the generator for the conv/FC model."""
  super().setUp()
  model_relpath = 'tests/models/conv_fc_mnist.tflite'
  self._test_model_path = os.path.join(TEST_DATA_PREFIX_PATH, model_relpath)
  self._recipe_manager = recipe_manager.RecipeManager()
  generator = params_generator.ParamsGenerator(self._test_model_path)
  self._params_generator = generator
|
78
|
+
|
79
|
+
def test_update_model_quant_results(self):
  """Exercises adding, merging and rejecting per-tensor quantization results.

  Four steps run against the same tensor name: registering consumer (target
  op) params, attaching producer (source op) params, appending a second
  consumer, and rejecting a second set of producer params.
  """
  params_from_target = qtyping.TensorTransformationParams(
      tensor_name='test_tensor0',
      consumers=[
          qtyping.OpToTensorParams(
              subgraph_op_id=3,
              transformations=[
                  _QuantTransformation.ADD_QUANTIZE,
                  _QuantTransformation.ADD_DEQUANTIZE,
              ],
          )
      ],
  )
  # Test add new tensor from target.
  self._params_generator._update_model_quant_results([params_from_target])
  # assertIn checks real membership. The former
  # assertIsNotNone('test_tensor0' in ...) only verified that a bool is not
  # None, so it could never fail.
  self.assertIn('test_tensor0', self._params_generator.model_quant_results)
  tensor_params = self._params_generator.model_quant_results['test_tensor0']
  self.assertEqual(tensor_params, params_from_target)

  # Test update new tensor from source.
  params_from_source = qtyping.TensorTransformationParams(
      tensor_name='test_tensor0',
      producer=qtyping.OpToTensorParams(
          subgraph_op_id=3,
          transformations=[_QuantTransformation.ADD_DEQUANTIZE],
      ),
  )
  self._params_generator._update_model_quant_results([params_from_source])
  tensor_params = self._params_generator.model_quant_results['test_tensor0']
  self.assertEqual(tensor_params.producer, params_from_source.producer)

  # We can have multiple target op params...
  params_from_target2 = qtyping.TensorTransformationParams(
      tensor_name='test_tensor0',
      consumers=[
          qtyping.OpToTensorParams(
              subgraph_op_id=3,
              transformations=[_QuantTransformation.NO_QUANTIZE],
          )
      ],
  )
  self._params_generator._update_model_quant_results([params_from_target2])
  tensor_params = self._params_generator.model_quant_results['test_tensor0']
  self.assertSequenceEqual(
      tensor_params.consumers,
      params_from_target.consumers + params_from_target2.consumers,
  )

  # ...but only a single source op params.
  error_message = (
      'received multiple quantization parameters from the source op'
  )
  with self.assertRaisesWithPredicateMatch(
      RuntimeError, lambda err: error_message in str(err)
  ):
    self._params_generator._update_model_quant_results([params_from_source])
|
146
|
+
|
147
|
+
def test_generate_config_global(self):
  """Applies one weight-only recipe to every FULLY_CONNECTED op."""
  weight_config = {
      'dtype': _TensorDataType.INT,
      'num_bits': 8,
      'symmetric': False,
      'granularity': _QuantGranularity.CHANNELWISE,
  }
  # FLOAT compute with explicit dequantize is equivalent to WEIGHT_ONLY.
  op_config = {
      'weight_tensor_config': weight_config,
      'compute_precision': _ComputePrecision.FLOAT,
      'explicit_dequantize': True,
  }
  # Quantize all fully_connected.
  self._recipe_manager.load_quantization_recipe([{
      'regex': '.*',
      'operation': 'FULLY_CONNECTED',
      'algorithm_key': 'min_max_uniform_quantize',
      'op_config': op_config,
  }])
  all_params = self._params_generator.generate_quantization_parameters(
      self._recipe_manager
  )

  # Every tensor in the model must have its params.
  model = tfl_flatbuffer_utils.read_model(self._test_model_path)
  self.assertLen(all_params, len(model.subgraphs[0].tensors))

  # Model input tensor: it is produced by the virtual Input op (id -1).
  model_input = 'serving_default_conv2d_input:0'
  self._test_tensor_transformation_params(
      0,
      all_params,
      model_input,
      [_QuantTransformation.NO_QUANTIZE],
      is_inbounding_tensor=True,
  )
  input_producer = all_params[model_input].producer
  self.assertIsNotNone(input_producer)
  self.assertEqual(input_producer.subgraph_op_id, -1)

  avg_pool = 'sequential/average_pooling2d/AvgPool'
  dense_out = (
      'sequential/dense/MatMul;sequential/dense/Relu;sequential/dense/BiasAdd'
  )
  no_quantize = (_QuantTransformation.NO_QUANTIZE,)
  add_dequantize = (_QuantTransformation.ADD_DEQUANTIZE,)
  weight_checks = {
      'num_bits': 8,
      'granularity': _QuantGranularity.CHANNELWISE,
      'symmetric': False,
  }
  # Rows: (op id, tensor name, transformations, is inbounding, extra checks).
  expectations = [
      # Intermediate tensor has no_quantize at both ends.
      (1, avg_pool, no_quantize, False, {}),  # output from average pool
      (2, avg_pool, no_quantize, True, {}),  # input to Reshape
      # First FC.
      (3, 'sequential/flatten/Reshape', no_quantize, True, {}),  # input
      (3, 'arith.constant1', add_dequantize, True, weight_checks),  # weight
      (3, 'arith.constant2', no_quantize, True, {}),  # bias
      (3, dense_out, no_quantize, False, {}),  # output
      # Second FC.
      (4, dense_out, no_quantize, True, {}),  # input
      (4, 'arith.constant', add_dequantize, True, weight_checks),  # weight
      (4, 'sequential/dense_1/MatMul', no_quantize, False, {}),  # output
  ]
  for op_id, name, transformations, inbounding, extras in expectations:
    self._test_tensor_transformation_params(
        op_id,
        all_params,
        name,
        list(transformations),
        is_inbounding_tensor=inbounding,
        **extras,
    )

  # Model output tensor: it is consumed by the virtual Output op (id -1).
  model_output = 'StatefulPartitionedCall:0'
  self._test_tensor_transformation_params(
      5,
      all_params,
      model_output,
      [_QuantTransformation.NO_QUANTIZE],
      is_inbounding_tensor=False,
  )
  output_consumers = all_params[model_output].consumers
  self.assertLen(output_consumers, 1)
  self.assertEqual(output_consumers[0].subgraph_op_id, -1)
|
287
|
+
|
288
|
+
# TODO: b/330770656 - expand the test to cover mixed activation precision.
|
289
|
+
def test_generate_config_selective(self):
  """Quantizes the two FC ops differently by selecting them via scope regex."""

  def fc_recipe(regex, weight_config, compute_precision, explicit_dequantize):
    """Returns one FULLY_CONNECTED min/max recipe entry for `regex`."""
    return {
        'regex': regex,
        'operation': 'FULLY_CONNECTED',
        'algorithm_key': 'min_max_uniform_quantize',
        'op_config': {
            'weight_tensor_config': {
                'dtype': _TensorDataType.INT,
                **weight_config,
            },
            'compute_precision': compute_precision,
            'explicit_dequantize': explicit_dequantize,
        },
    }

  dense_weights = {
      'num_bits': 8,
      'symmetric': True,
      'granularity': _QuantGranularity.CHANNELWISE,
  }
  dense_1_weights = {
      'num_bits': 4,
      'symmetric': False,
      'granularity': _QuantGranularity.TENSORWISE,
  }
  # Choose scope regex using Model Explorer.
  selective_quantization_recipe = [
      # Equivalent to DRQ.
      fc_recipe(
          '.*/dense/.*', dense_weights, _ComputePrecision.INTEGER, False
      ),
      # Equivalent to WEIGHT_ONLY.
      fc_recipe(
          '.*/dense_1/.*', dense_1_weights, _ComputePrecision.FLOAT, True
      ),
  ]
  self._recipe_manager.load_quantization_recipe(selective_quantization_recipe)
  all_params = self._params_generator.generate_quantization_parameters(
      self._recipe_manager
  )

  # FC weights for scope "dense".
  self._test_tensor_transformation_params(
      3,
      all_params,
      'arith.constant1',
      [_QuantTransformation.QUANTIZE_TENSOR],
      is_inbounding_tensor=True,
      **dense_weights,
  )

  # FC weights for scope "dense_1".
  self._test_tensor_transformation_params(
      4,
      all_params,
      'arith.constant',
      [_QuantTransformation.ADD_DEQUANTIZE],
      is_inbounding_tensor=True,
      **dense_1_weights,
  )
|
356
|
+
|
357
|
+
def test_generate_config_edge_cases(self):
  """Covers a full tensor name used as scope and a scope matching nothing."""
  # Use the tensor name as scope directly.
  tensor_name_scope = {
      'regex': 'sequential/dense_1/MatMul',
      'operation': 'FULLY_CONNECTED',
      'algorithm_key': 'min_max_uniform_quantize',
      'op_config': {
          'weight_tensor_config': {
              'num_bits': 8,
              'symmetric': True,
              'granularity': _QuantGranularity.CHANNELWISE,
          },
          # Equivalent to DRQ.
          'compute_precision': _ComputePrecision.INTEGER,
      },
  }
  # Scope that does not exist in the model.
  missing_scope = {
      'regex': '.*/dense_3/.*',
      'operation': 'FULLY_CONNECTED',
      'algorithm_key': 'min_max_uniform_quantize',
      'op_config': {
          'weight_tensor_config': {
              'num_bits': 4,
              'symmetric': False,
              'granularity': _QuantGranularity.TENSORWISE,
          },
          # Equivalent to WEIGHT_ONLY.
          'compute_precision': _ComputePrecision.FLOAT,
          'explicit_dequantize': True,
      },
  }
  self._recipe_manager.load_quantization_recipe(
      [tensor_name_scope, missing_scope]
  )
  all_params = self._params_generator.generate_quantization_parameters(
      self._recipe_manager
  )

  # Only the second FC will be quantized: the first FC weights stay float.
  self._test_tensor_transformation_params(
      3,
      all_params,
      'arith.constant1',
      [_QuantTransformation.NO_QUANTIZE],
      is_inbounding_tensor=True,
  )
  second_fc_weight_checks = {
      'num_bits': 8,
      'granularity': _QuantGranularity.CHANNELWISE,
      'symmetric': True,
  }
  self._test_tensor_transformation_params(
      4,
      all_params,
      'arith.constant',
      [_QuantTransformation.QUANTIZE_TENSOR],
      is_inbounding_tensor=True,
      **second_fc_weight_checks,
  )
|
417
|
+
|
418
|
+
@parameterized.parameters(
    (True, _QuantGranularity.CHANNELWISE),
    (True, _QuantGranularity.TENSORWISE),
    (False, _QuantGranularity.CHANNELWISE),
    (False, _QuantGranularity.TENSORWISE),
)
def test_generate_config_int8xint8_single_fc(
    self, act_symmetric, channelwise_weight
):
  """Checks int8 activation x int8 weight params for the single FC model.

  Args:
    act_symmetric: Whether activations are quantized symmetrically.
    channelwise_weight: Quantization granularity applied to the FC weights.
  """
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, 'tests/models/single_fc.tflite'
  )
  # Equivalent to SRQ.
  srq_config = qtyping.OpQuantizationConfig(
      activation_tensor_config=_TensorQuantConfig(
          num_bits=8, symmetric=act_symmetric
      ),
      weight_tensor_config=_TensorQuantConfig(
          num_bits=8, symmetric=True, granularity=channelwise_weight
      ),
      compute_precision=_ComputePrecision.INTEGER,
  )
  self._recipe_manager.add_quantization_config(
      regex='.*',
      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
      algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
      op_config=srq_config,
  )
  generator = params_generator.ParamsGenerator(model_path)

  # Raise error when missing QSVs.
  error_message = 'Model quantization statistics values (QSVs) are required'

  def mentions_missing_qsvs(err):
    return error_message in str(err)

  with self.assertRaisesWithPredicateMatch(
      RuntimeError, mentions_missing_qsvs
  ):
    generator.generate_quantization_parameters(self._recipe_manager)

  # Calibrate then quantize.
  model_calibrator = calibrator.Calibrator(model_path)
  calibration_data = _get_calibration_data(
      _single_fc_model_representative_dataset_gen()
  )
  model_calibrator.calibrate(calibration_data, self._recipe_manager)
  quant_params = generator.generate_quantization_parameters(
      self._recipe_manager,
      model_calibrator.get_model_qsvs(),
  )
  self.assertLen(quant_params, 4)

  model_input = 'serving_default_input_1:0'
  model_output = 'StatefulPartitionedCall:0'
  activation_checks = {
      'num_bits': 8,
      'granularity': _QuantGranularity.TENSORWISE,
      'symmetric': act_symmetric,
  }
  weight_checks = {
      'num_bits': 8,
      'granularity': channelwise_weight,
      'symmetric': True,
  }
  bias_checks = {**weight_checks, 'num_bits': 32}
  add_quantize = _QuantTransformation.ADD_QUANTIZE
  add_dequantize = _QuantTransformation.ADD_DEQUANTIZE
  quantize_tensor = _QuantTransformation.QUANTIZE_TENSOR
  # Rows: (op id, tensor name, transformation, checks, is inbounding).
  # Op id -1 denotes the virtual input/output op.
  expectations = [
      # Input tensor producer (from the virtual input op), then its consumer.
      (-1, model_input, add_dequantize, activation_checks, False),
      (0, model_input, add_quantize, activation_checks, True),
      # Output tensor producer, then its consumer (the virtual output op).
      (0, model_output, add_dequantize, activation_checks, False),
      (-1, model_output, add_quantize, activation_checks, True),
      # Weights.
      (0, 'sequential/dense/MatMul', quantize_tensor, weight_checks, True),
      # Bias.
      (
          0,
          'sequential/dense/BiasAdd/ReadVariableOp',
          quantize_tensor,
          bias_checks,
          True,
      ),
  ]
  for op_id, name, transformation, checks, inbounding in expectations:
    self._test_tensor_transformation_params(
        op_id,
        quant_params,
        name,
        [transformation],
        is_inbounding_tensor=inbounding,
        **checks,
    )
|
540
|
+
|
541
|
+
@parameterized.parameters('weight_only', 'DRQ')
def test_generate_params_buffer_sharing_graphs_succeeds(
    self, the_other_fc_difference
):
  """Generates params for the weight-sharing FC model without errors.

  Args:
    the_other_fc_difference: 'weight_only' keeps the global weight-only config
      for every op; 'DRQ' additionally registers an integer-compute config
      for the scope 'PartitionedCall_1:0'.
  """
  # Equivalent to WEIGHT_ONLY.
  weight_only_config = qtyping.OpQuantizationConfig(
      weight_tensor_config=_TensorQuantConfig(num_bits=8, symmetric=True),
      compute_precision=_ComputePrecision.FLOAT,
      explicit_dequantize=True,
  )
  self._recipe_manager.add_quantization_config(
      regex='.*',
      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
      op_config=weight_only_config,
  )
  if the_other_fc_difference == 'DRQ':
    # Equivalent to DRQ.
    drq_config = qtyping.OpQuantizationConfig(
        compute_precision=_ComputePrecision.INTEGER,
    )
    self._recipe_manager.add_quantization_config(
        regex='PartitionedCall_1:0',
        operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
        op_config=drq_config,
    )

  shared_weights_model = os.path.join(
      TEST_DATA_PREFIX_PATH, 'tests/models/weight_sharing_fcs.tflite'
  )
  generator = params_generator.ParamsGenerator(shared_weights_model)
  quant_params = generator.generate_quantization_parameters(
      self._recipe_manager,
  )
  self.assertLen(quant_params, 6)
|
572
|
+
|
573
|
+
@parameterized.parameters('no_quant', 'execution_mode', 'num_bits')
def test_generate_params_buffer_sharing_graphs_fails(
    self, the_other_fc_difference
):
  """Expects a RuntimeError when weight-sharing FCs are configured differently.

  Args:
    the_other_fc_difference: How the second FC differs from the first one.
      Only 'num_bits' registers an extra config; every other value leaves the
      second FC without a config.
  """
  # Setup the quantization config for the first FC.
  self._recipe_manager.add_quantization_config(
      regex='PartitionedCall:0',
      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
      op_config=qtyping.OpQuantizationConfig(
          weight_tensor_config=_TensorQuantConfig(num_bits=8),
          compute_precision=_ComputePrecision.INTEGER,
      ),
  )
  # Setup the quantization config for the second FC (weight shared with the
  # first FC).
  # NOTE(review): 'execution_mode' has no dedicated setup here, so it runs the
  # same scenario as 'no_quant' - confirm whether a config with a different
  # compute precision was intended.
  if the_other_fc_difference == 'num_bits':
    self._recipe_manager.add_quantization_config(
        regex='PartitionedCall_1:0',
        operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=4),
            compute_precision=_ComputePrecision.INTEGER,
        ),
    )

  shared_weights_model = os.path.join(
      TEST_DATA_PREFIX_PATH, 'tests/models/weight_sharing_fcs.tflite'
  )
  generator = params_generator.ParamsGenerator(shared_weights_model)
  error_message = 'do not have the same quantization parameters'

  def mentions_mismatch(err):
    return error_message in str(err)

  with self.assertRaisesWithPredicateMatch(RuntimeError, mentions_mismatch):
    generator.generate_quantization_parameters(
        self._recipe_manager,
    )
|
610
|
+
|
611
|
+
@parameterized.named_parameters(
|
612
|
+
dict(
|
613
|
+
testcase_name='producer_incompatible',
|
614
|
+
param1=qtyping.TensorTransformationParams(
|
615
|
+
tensor_name='tfl.quantize',
|
616
|
+
producer=qtyping.OpToTensorParams(
|
617
|
+
subgraph_op_id=0,
|
618
|
+
transformations=[qtyping.QuantTransformation.ADD_DEQUANTIZE],
|
619
|
+
parameters=qtyping.UniformQuantParams(
|
620
|
+
8, None, np.array([1]), np.array([0])
|
621
|
+
),
|
622
|
+
),
|
623
|
+
consumers=[
|
624
|
+
qtyping.OpToTensorParams(
|
625
|
+
subgraph_op_id=1,
|
626
|
+
transformations=[
|
627
|
+
qtyping.QuantTransformation.ADD_QUANTIZE
|
628
|
+
],
|
629
|
+
parameters=qtyping.UniformQuantParams(
|
630
|
+
8, None, np.array([1]), np.array([0])
|
631
|
+
),
|
632
|
+
),
|
633
|
+
qtyping.OpToTensorParams(
|
634
|
+
subgraph_op_id=2,
|
635
|
+
transformations=[
|
636
|
+
qtyping.QuantTransformation.ADD_QUANTIZE,
|
637
|
+
qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
638
|
+
],
|
639
|
+
parameters=qtyping.UniformQuantParams(
|
640
|
+
8, None, np.array([1]), np.array([0])
|
641
|
+
),
|
642
|
+
),
|
643
|
+
qtyping.OpToTensorParams(
|
644
|
+
subgraph_op_id=3,
|
645
|
+
transformations=[
|
646
|
+
qtyping.QuantTransformation.ADD_QUANTIZE,
|
647
|
+
qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
648
|
+
],
|
649
|
+
parameters=qtyping.UniformQuantParams(
|
650
|
+
8, None, np.array([1]), np.array([0])
|
651
|
+
),
|
652
|
+
),
|
653
|
+
qtyping.OpToTensorParams(
|
654
|
+
subgraph_op_id=4,
|
655
|
+
transformations=[
|
656
|
+
qtyping.QuantTransformation.NO_QUANTIZE,
|
657
|
+
],
|
658
|
+
parameters=qtyping.UniformQuantParams(
|
659
|
+
8, None, np.array([1]), np.array([0])
|
660
|
+
),
|
661
|
+
),
|
662
|
+
],
|
663
|
+
),
|
664
|
+
param2=qtyping.TensorTransformationParams(
|
665
|
+
'tfl.other_quantize',
|
666
|
+
qtyping.OpToTensorParams(
|
667
|
+
subgraph_op_id=0,
|
668
|
+
transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
|
669
|
+
parameters=qtyping.UniformQuantParams(
|
670
|
+
8, None, np.array([1]), np.array([0])
|
671
|
+
),
|
672
|
+
),
|
673
|
+
[
|
674
|
+
qtyping.OpToTensorParams(
|
675
|
+
subgraph_op_id=1,
|
676
|
+
transformations=[
|
677
|
+
qtyping.QuantTransformation.ADD_QUANTIZE
|
678
|
+
],
|
679
|
+
parameters=qtyping.UniformQuantParams(
|
680
|
+
8, None, np.array([1]), np.array([0])
|
681
|
+
),
|
682
|
+
),
|
683
|
+
qtyping.OpToTensorParams(
|
684
|
+
subgraph_op_id=2,
|
685
|
+
transformations=[
|
686
|
+
qtyping.QuantTransformation.ADD_QUANTIZE,
|
687
|
+
qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
688
|
+
],
|
689
|
+
parameters=qtyping.UniformQuantParams(
|
690
|
+
8, None, np.array([1]), np.array([0])
|
691
|
+
),
|
692
|
+
),
|
693
|
+
qtyping.OpToTensorParams(
|
694
|
+
subgraph_op_id=3,
|
695
|
+
transformations=[
|
696
|
+
qtyping.QuantTransformation.ADD_QUANTIZE,
|
697
|
+
qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
698
|
+
],
|
699
|
+
parameters=qtyping.UniformQuantParams(
|
700
|
+
8, None, np.array([1]), np.array([0])
|
701
|
+
),
|
702
|
+
),
|
703
|
+
],
|
704
|
+
),
|
705
|
+
expected=False,
|
706
|
+
),
|
707
|
+
dict(
|
708
|
+
testcase_name='param2_consumer_incompatible',
|
709
|
+
param1=qtyping.TensorTransformationParams(
|
710
|
+
tensor_name='tfl.quantize',
|
711
|
+
producer=qtyping.OpToTensorParams(
|
712
|
+
subgraph_op_id=0,
|
713
|
+
transformations=[qtyping.QuantTransformation.ADD_QUANTIZE],
|
714
|
+
parameters=qtyping.UniformQuantParams(
|
715
|
+
8, None, np.array([1]), np.array([0])
|
716
|
+
),
|
717
|
+
),
|
718
|
+
consumers=[
|
719
|
+
qtyping.OpToTensorParams(
|
720
|
+
subgraph_op_id=1,
|
721
|
+
transformations=[
|
722
|
+
qtyping.QuantTransformation.ADD_QUANTIZE
|
723
|
+
],
|
724
|
+
parameters=qtyping.UniformQuantParams(
|
725
|
+
8, None, np.array([1]), np.array([0])
|
726
|
+
),
|
727
|
+
),
|
728
|
+
qtyping.OpToTensorParams(
|
729
|
+
subgraph_op_id=2,
|
730
|
+
transformations=[
|
731
|
+
qtyping.QuantTransformation.ADD_QUANTIZE,
|
732
|
+
qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
733
|
+
],
|
734
|
+
parameters=qtyping.UniformQuantParams(
|
735
|
+
8, None, np.array([1]), np.array([0])
|
736
|
+
),
|
737
|
+
),
|
738
|
+
qtyping.OpToTensorParams(
|
739
|
+
subgraph_op_id=3,
|
740
|
+
transformations=[
|
741
|
+
qtyping.QuantTransformation.ADD_QUANTIZE,
|
742
|
+
qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
743
|
+
],
|
744
|
+
parameters=qtyping.UniformQuantParams(
|
745
|
+
8, None, np.array([1]), np.array([0])
|
746
|
+
),
|
747
|
+
),
|
748
|
+
],
|
749
|
+
),
|
750
|
+
param2=qtyping.TensorTransformationParams(
|
751
|
+
'tfl.other_quantize',
|
752
|
+
qtyping.OpToTensorParams(
|
753
|
+
subgraph_op_id=0,
|
754
|
+
transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
|
755
|
+
parameters=qtyping.UniformQuantParams(
|
756
|
+
8, None, np.array([1]), np.array([0])
|
757
|
+
),
|
758
|
+
),
|
759
|
+
[
|
760
|
+
qtyping.OpToTensorParams(
|
761
|
+
subgraph_op_id=1,
|
762
|
+
transformations=[
|
763
|
+
qtyping.QuantTransformation.ADD_QUANTIZE
|
764
|
+
],
|
765
|
+
parameters=qtyping.UniformQuantParams(
|
766
|
+
8, None, np.array([1]), np.array([0])
|
767
|
+
),
|
768
|
+
),
|
769
|
+
qtyping.OpToTensorParams(
|
770
|
+
subgraph_op_id=2,
|
771
|
+
transformations=[
|
772
|
+
qtyping.QuantTransformation.ADD_QUANTIZE,
|
773
|
+
qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
774
|
+
],
|
775
|
+
parameters=qtyping.UniformQuantParams(
|
776
|
+
8, None, np.array([1]), np.array([0])
|
777
|
+
),
|
778
|
+
),
|
779
|
+
qtyping.OpToTensorParams(
|
780
|
+
subgraph_op_id=3,
|
781
|
+
transformations=[
|
782
|
+
qtyping.QuantTransformation.ADD_QUANTIZE,
|
783
|
+
qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
784
|
+
],
|
785
|
+
parameters=qtyping.UniformQuantParams(
|
786
|
+
8, None, np.array([1]), np.array([0])
|
787
|
+
),
|
788
|
+
),
|
789
|
+
qtyping.OpToTensorParams(
|
790
|
+
subgraph_op_id=4,
|
791
|
+
transformations=[
|
792
|
+
qtyping.QuantTransformation.QUANTIZE_TENSOR,
|
793
|
+
],
|
794
|
+
parameters=qtyping.UniformQuantParams(
|
795
|
+
8, None, np.array([1]), np.array([0])
|
796
|
+
),
|
797
|
+
),
|
798
|
+
],
|
799
|
+
),
|
800
|
+
expected=False,
|
801
|
+
),
|
802
|
+
dict(
|
803
|
+
testcase_name='compatible',
|
804
|
+
param1=qtyping.TensorTransformationParams(
|
805
|
+
tensor_name='tfl.quantize',
|
806
|
+
producer=None,
|
807
|
+
consumers=[
|
808
|
+
qtyping.OpToTensorParams(
|
809
|
+
subgraph_op_id=2,
|
810
|
+
transformations=[
|
811
|
+
qtyping.QuantTransformation.ADD_QUANTIZE,
|
812
|
+
],
|
813
|
+
parameters=qtyping.UniformQuantParams(
|
814
|
+
8, None, np.array([1]), np.array([0])
|
815
|
+
),
|
816
|
+
),
|
817
|
+
qtyping.OpToTensorParams(
|
818
|
+
subgraph_op_id=3,
|
819
|
+
transformations=[
|
820
|
+
qtyping.QuantTransformation.NO_QUANTIZE,
|
821
|
+
qtyping.QuantTransformation.ADD_QUANTIZE,
|
822
|
+
],
|
823
|
+
parameters=qtyping.UniformQuantParams(
|
824
|
+
8, None, np.array([1]), np.array([0])
|
825
|
+
),
|
826
|
+
),
|
827
|
+
qtyping.OpToTensorParams(
|
828
|
+
subgraph_op_id=4,
|
829
|
+
transformations=[
|
830
|
+
qtyping.QuantTransformation.NO_QUANTIZE,
|
831
|
+
],
|
832
|
+
),
|
833
|
+
],
|
834
|
+
),
|
835
|
+
param2=qtyping.TensorTransformationParams(
|
836
|
+
'tfl.other_quantize',
|
837
|
+
None,
|
838
|
+
[
|
839
|
+
qtyping.OpToTensorParams(
|
840
|
+
subgraph_op_id=1,
|
841
|
+
transformations=[
|
842
|
+
qtyping.QuantTransformation.ADD_QUANTIZE
|
843
|
+
],
|
844
|
+
parameters=qtyping.UniformQuantParams(
|
845
|
+
8, None, np.array([1]), np.array([0])
|
846
|
+
),
|
847
|
+
),
|
848
|
+
qtyping.OpToTensorParams(
|
849
|
+
subgraph_op_id=2,
|
850
|
+
transformations=[
|
851
|
+
qtyping.QuantTransformation.ADD_QUANTIZE,
|
852
|
+
qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
853
|
+
],
|
854
|
+
parameters=qtyping.UniformQuantParams(
|
855
|
+
8, None, np.array([1]), np.array([0])
|
856
|
+
),
|
857
|
+
),
|
858
|
+
qtyping.OpToTensorParams(
|
859
|
+
subgraph_op_id=3,
|
860
|
+
transformations=[
|
861
|
+
qtyping.QuantTransformation.ADD_QUANTIZE,
|
862
|
+
qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
863
|
+
],
|
864
|
+
parameters=qtyping.UniformQuantParams(
|
865
|
+
8, None, np.array([1]), np.array([0])
|
866
|
+
),
|
867
|
+
),
|
868
|
+
qtyping.OpToTensorParams(
|
869
|
+
subgraph_op_id=4,
|
870
|
+
transformations=[
|
871
|
+
qtyping.QuantTransformation.ADD_QUANTIZE,
|
872
|
+
],
|
873
|
+
parameters=qtyping.UniformQuantParams(
|
874
|
+
8, None, np.array([1]), np.array([0])
|
875
|
+
),
|
876
|
+
),
|
877
|
+
],
|
878
|
+
),
|
879
|
+
expected=True,
|
880
|
+
),
|
881
|
+
)
|
882
|
+
def test_params_compatible(self, param1, param2, expected):
  """Checks the compatibility verdict for two tensor transformation params.

  Args:
    param1: First argument handed to the compatibility check.
    param2: Second argument handed to the compatibility check.
    expected: Value the compatibility check must return for this pair.
  """
  # adding a test to make production coverage happy.
  is_compatible = params_generator._compatible_tensor_transformation_params(
      param1, param2
  )
  self.assertEqual(is_compatible, expected)
|
890
|
+
|
891
|
+
def test_model_with_duplicated_tensor_names_fails(self):
  """Building a ParamsGenerator raises ValueError on a duplicated name."""
  expected_message = 'Tensor name test_same_name is not unique in the model.'

  def _mentions_duplicate(err):
    # The raised error only has to contain the message, not equal it.
    return expected_message in str(err)

  duplicated_names_model = os.path.join(
      TEST_DATA_PREFIX_PATH, 'tests/models/duplicated_tensor_names.tflite'
  )
  with self.assertRaisesWithPredicateMatch(ValueError, _mentions_duplicate):
    params_generator.ParamsGenerator(duplicated_names_model)
|
900
|
+
|
901
|
+
def test_quantize_integer_input_output(self):
  """Expects NO_QUANTIZE for the tensors of single_transpose_int32.tflite.

  Registers an 8-bit integer-compute recipe for all supported ops, calibrates
  the model, generates quantization parameters, and then checks that the
  input, output and `perm` tensors all carry only NO_QUANTIZE.
  """
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, 'tests/models/single_transpose_int32.tflite'
  )
  # Equivalent to SRQ.
  integer_op_config = qtyping.OpQuantizationConfig(
      activation_tensor_config=_TensorQuantConfig(
          num_bits=8, symmetric=False
      ),
      weight_tensor_config=_TensorQuantConfig(num_bits=8, symmetric=True),
      compute_precision=_ComputePrecision.INTEGER,
  )
  self._recipe_manager.add_quantization_config(
      regex='.*',
      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
      algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
      op_config=integer_op_config,
  )
  generator = params_generator.ParamsGenerator(model_path)

  # Calibrate then quantize.
  model_calibrator = calibrator.Calibrator(model_path)
  model_calibrator.calibrate(
      _get_calibration_data(_int_transpose_model_representative_dataset_gen()),
      self._recipe_manager,
  )
  quant_params = generator.generate_quantization_parameters(
      self._recipe_manager,
      model_calibrator.get_model_qsvs(),
  )
  self.assertLen(quant_params, 3)

  # Each row: (subgraph_op_id, tensor_name, is_inbounding_tensor).
  expectations = (
      # virtual input op.
      (-1, 'serving_default_input_2:0', False),
      # Input tensor consumer.
      (0, 'serving_default_input_2:0', True),
      # Output tensor producer.
      (0, 'PartitionedCall:0', False),
      # output tensor consumer (into the virtual output op).
      (-1, 'PartitionedCall:0', True),
      # perm
      (0, 'sequential_1/permute_1/transpose/perm', True),
  )
  for subgraph_op_id, tensor_name, is_inbounding in expectations:
    self._test_tensor_transformation_params(
        subgraph_op_id,
        quant_params,
        tensor_name,
        [_QuantTransformation.NO_QUANTIZE],
        is_inbounding_tensor=is_inbounding,
    )
|
974
|
+
|
975
|
+
def _test_tensor_transformation_params(
    self,
    subgraph_op_id,
    quant_params,
    tensor_name,
    transformations,
    is_inbounding_tensor,
    num_bits=8,
    granularity=_QuantGranularity.TENSORWISE,
    symmetric=True,
    quantized_dimension=0,
):
  """Asserts that one tensor's transformation parameters match expectations.

  Args:
    subgraph_op_id: Expected `subgraph_op_id` of the inspected op entry.
    quant_params: Mapping from tensor name to its transformation params.
    tensor_name: Key to look up in `quant_params`.
    transformations: Expected sequence of transformations for the op entry.
    is_inbounding_tensor: If truthy, inspect the single consumer entry;
      otherwise inspect the producer entry.
    num_bits: Expected bit width of the quantization parameters.
    granularity: When CHANNELWISE, the quantized dimension is compared with
      `quantized_dimension`; otherwise it must be None.
    symmetric: If truthy, every zero point must be zero; otherwise only the
      lengths of scale and zero point are compared.
    quantized_dimension: Expected quantized dimension for CHANNELWISE.
  """
  self.assertIn(tensor_name, quant_params)
  tensor_params = quant_params[tensor_name]
  self.assertEqual(tensor_params.tensor_name, tensor_name)

  # Pick the op entry under test: the producer, or the only consumer.
  if not is_inbounding_tensor:
    op_params = tensor_params.producer
  else:
    self.assertLen(tensor_params.consumers, 1)
    op_params = tensor_params.consumers[0]
  self.assertIsNotNone(op_params)
  self.assertEqual(op_params.subgraph_op_id, subgraph_op_id)
  self.assertSequenceEqual(op_params.transformations, transformations)

  # A pure NO_QUANTIZE entry must not carry quantization parameters.
  if transformations == [_QuantTransformation.NO_QUANTIZE]:
    self.assertIsNone(op_params.parameters)
    return

  quant = op_params.parameters
  self.assertIsNotNone(quant)
  if granularity is not _QuantGranularity.CHANNELWISE:
    self.assertIsNone(quant.quantized_dimension)
  else:
    self.assertEqual(quant.quantized_dimension, quantized_dimension)
  self.assertEqual(quant.num_bits, num_bits)

  if not symmetric:
    self.assertEqual(len(quant.scale), len(quant.zero_point))
    return
  self.assertEqual(np.sum(abs(quant.zero_point)), 0)
|
1018
|
+
|
1019
|
+
|
1020
|
+
class ParamsGeneratorAlreadyQuantizedModelTest(googletest.TestCase):
  """Checks ParamsGenerator construction on float and quantized models."""

  def test_check_is_float_model_succeeds_when_model_is_float(self):
    """Construction with conv_fc_mnist.tflite completes without raising."""
    float_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/conv_fc_mnist.tflite'
    )
    params_generator.ParamsGenerator(float_model_path)

  def test_check_is_float_model_raises_error_when_model_is_quantized(self):
    """Construction with mnist_quantized.tflite raises a ValueError."""
    quantized_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/mnist_quantized.tflite'
    )
    expected_error = (
        'The input model for quantization parameters generation is not a float'
        ' model.'
    )
    with self.assertRaisesRegex(ValueError, expected_error):
      params_generator.ParamsGenerator(quantized_model_path)
|
1038
|
+
|
1039
|
+
|
1040
|
+
if __name__ == '__main__':
  # Only start the test runner when this file is executed as a script.
  googletest.main()
|