ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1041 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Tests for params_generator."""
17
+
18
+ from collections.abc import Generator
19
+ import os
20
+ from typing import Any
21
+
22
+ from absl.testing import parameterized
23
+ import numpy as np
24
+
25
+ from tensorflow.python.platform import googletest
26
+ from ai_edge_quantizer import calibrator
27
+ from ai_edge_quantizer import params_generator
28
+ from ai_edge_quantizer import qtyping
29
+ from ai_edge_quantizer import recipe_manager
30
+ from ai_edge_quantizer.utils import test_utils
31
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
32
+ from ai_edge_quantizer.utils import tfl_interpreter_utils
33
+
34
+
35
+ _ComputePrecision = qtyping.ComputePrecision
36
+ _TensorDataType = qtyping.TensorDataType
37
+ _TensorQuantConfig = qtyping.TensorQuantizationConfig
38
+ _QuantTransformation = qtyping.QuantTransformation
39
+ _AlgorithmName = recipe_manager.AlgorithmName
40
+ _QuantGranularity = qtyping.QuantGranularity
41
+
42
+ TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')
43
+
44
+
45
def _single_fc_model_representative_dataset_gen(num_samples=5):
  """Yields `num_samples` random calibration samples for the single FC model.

  Each sample is a dict with the single key 'input_1' mapped to a freshly
  drawn float32 array of shape (1, 8).

  Args:
    num_samples: Number of samples to produce.

  Yields:
    One `{'input_1': array}` dict per sample.
  """
  input_shape = (1, 8)
  for _ in range(num_samples):
    features = np.random.rand(*input_shape)
    yield {'input_1': features.astype(np.float32)}
48
+
49
+
50
def _int_transpose_model_representative_dataset_gen(num_samples=5):
  """Builds `num_samples` int32 calibration samples for the transpose model.

  Despite the `_gen` suffix this returns a fully materialized list, not a
  generator.

  Args:
    num_samples: Number of samples to produce.

  Returns:
    A list of `{'input_2': array}` dicts; each array is int32 with shape
    (1, 2, 3, 4).
  """
  # NOTE(review): np.random.rand draws from [0, 1), so casting to int32
  # truncates every element to 0 and all samples are all-zero tensors.
  # Confirm that this is intended before switching to e.g. np.random.randint.
  return [
      {'input_2': np.random.rand(1, 2, 3, 4).astype(np.int32)}
      for _ in range(num_samples)
  ]
55
+
56
+
57
def _get_calibration_data(
    dataset_gen: Generator[dict[str, Any], Any, None],
) -> dict[str, Any]:
  """Collects calibration samples under the default signature key.

  Args:
    dataset_gen: Iterable of calibration samples. It is consumed completely.

  Returns:
    A single-entry dict mapping
    `tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY` to the list of samples.
  """
  samples = list(dataset_gen)
  return {tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: samples}
65
+
66
+
67
+ class ParamsGeneratorTest(parameterized.TestCase):
68
+
69
def setUp(self):
  """Creates a recipe manager and a params generator for conv_fc_mnist."""
  super().setUp()
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, 'tests/models/conv_fc_mnist.tflite'
  )
  self._test_model_path = model_path
  self._recipe_manager = recipe_manager.RecipeManager()
  self._params_generator = params_generator.ParamsGenerator(model_path)
78
+
79
def test_update_model_quant_results(self):
  """Checks how `_update_model_quant_results` merges params for one tensor.

  Steps covered, all for the tensor 'test_tensor0':
    1. Consumer-only params register the tensor.
    2. Producer params supplied afterwards are attached.
    3. Further consumer params are appended to the existing ones.
    4. A second set of producer params raises RuntimeError.
  """
  params_from_target = qtyping.TensorTransformationParams(
      tensor_name='test_tensor0',
      consumers=[
          qtyping.OpToTensorParams(
              subgraph_op_id=3,
              transformations=[
                  _QuantTransformation.ADD_QUANTIZE,
                  _QuantTransformation.ADD_DEQUANTIZE,
              ],
          )
      ],
  )
  # Test add new tensor from target
  self._params_generator._update_model_quant_results([params_from_target])
  # assertIn rather than assertIsNotNone(<bool>): a membership test is always
  # True/False, never None, so the previous assertion could not fail.
  self.assertIn('test_tensor0', self._params_generator.model_quant_results)
  tensor_params = self._params_generator.model_quant_results['test_tensor0']
  self.assertEqual(
      tensor_params,
      params_from_target,
  )
  # Test update new tensor from source
  params_from_source = qtyping.TensorTransformationParams(
      tensor_name='test_tensor0',
      producer=qtyping.OpToTensorParams(
          subgraph_op_id=3,
          transformations=[
              _QuantTransformation.ADD_DEQUANTIZE,
          ],
      ),
  )
  self._params_generator._update_model_quant_results([params_from_source])
  tensor_params = self._params_generator.model_quant_results['test_tensor0']
  self.assertEqual(
      tensor_params.producer,
      params_from_source.producer,
  )

  # We can have multiple target op params
  params_from_target2 = qtyping.TensorTransformationParams(
      tensor_name='test_tensor0',
      consumers=[
          qtyping.OpToTensorParams(
              subgraph_op_id=3,
              transformations=[
                  _QuantTransformation.NO_QUANTIZE,
              ],
          )
      ],
  )
  self._params_generator._update_model_quant_results([params_from_target2])
  tensor_params = self._params_generator.model_quant_results['test_tensor0']
  self.assertSequenceEqual(
      tensor_params.consumers,
      params_from_target.consumers + params_from_target2.consumers,
  )

  # but only a single source op params
  error_message = (
      'received multiple quantization parameters from the source op'
  )
  with self.assertRaisesWithPredicateMatch(
      RuntimeError, lambda err: error_message in str(err)
  ):
    self._params_generator._update_model_quant_results([params_from_source])
146
+
147
def test_generate_config_global(self):
  """Applies one weight-only FULLY_CONNECTED recipe to every scope.

  Expects params for every tensor of subgraph 0, ADD_DEQUANTIZE on both FC
  weight tensors, and NO_QUANTIZE on all other checked tensors.
  """
  # Quantize all fully_connected.
  fc_weight_only_recipe = [
      {
          'regex': '.*',
          'operation': 'FULLY_CONNECTED',
          'algorithm_key': 'min_max_uniform_quantize',
          'op_config': {
              'weight_tensor_config': {
                  'dtype': _TensorDataType.INT,
                  'num_bits': 8,
                  'symmetric': False,
                  'granularity': _QuantGranularity.CHANNELWISE,
              },
              # Equivalent to WEIGHT_ONLY.
              'compute_precision': _ComputePrecision.FLOAT,
              'explicit_dequantize': True,
          },
      },
  ]
  self._recipe_manager.load_quantization_recipe(fc_weight_only_recipe)
  quant_params = self._params_generator.generate_quantization_parameters(
      self._recipe_manager
  )
  # Every tensor in the model will have their params!
  model = tfl_flatbuffer_utils.read_model(self._test_model_path)
  self.assertLen(quant_params, len(model.subgraphs[0].tensors))

  no_quantize = _QuantTransformation.NO_QUANTIZE
  add_dequantize = _QuantTransformation.ADD_DEQUANTIZE

  # Input tensor
  input_name = 'serving_default_conv2d_input:0'
  self._test_tensor_transformation_params(
      0,
      quant_params,
      input_name,
      [no_quantize],
      is_inbounding_tensor=True,
  )
  # Input tensor is produced from the virtual Input op.
  input_config = quant_params[input_name]
  self.assertIsNotNone(input_config.producer)
  self.assertEqual(input_config.producer.subgraph_op_id, -1)

  avg_pool_name = 'sequential/average_pooling2d/AvgPool'
  first_fc_output_name = (
      'sequential/dense/MatMul;sequential/dense/Relu;sequential/dense/BiasAdd'
  )
  # Extra expectations shared by both FC weight tensors.
  weight_expectations = dict(
      num_bits=8,
      granularity=_QuantGranularity.CHANNELWISE,
      symmetric=False,
  )
  # Rows: (op id, tensor name, transformations, is_inbounding_tensor, extras).
  # Checked in order, one helper call per row.
  expectations = (
      # Intermediate tensor will have no_quantize at the both end
      (1, avg_pool_name, [no_quantize], False, {}),  # output from average pool
      (2, avg_pool_name, [no_quantize], True, {}),  # input to Reshape
      # First FC
      (3, 'sequential/flatten/Reshape', [no_quantize], True, {}),  # input
      (3, 'arith.constant1', [add_dequantize], True, weight_expectations),
      (3, 'arith.constant2', [no_quantize], True, {}),  # bias tensor
      (3, first_fc_output_name, [no_quantize], False, {}),  # output tensor
      # Second FC
      (4, first_fc_output_name, [no_quantize], True, {}),  # input tensor
      (4, 'arith.constant', [add_dequantize], True, weight_expectations),
      (4, 'sequential/dense_1/MatMul', [no_quantize], False, {}),  # output
  )
  for op_id, name, transformations, is_inbounding, extras in expectations:
    self._test_tensor_transformation_params(
        op_id,
        quant_params,
        name,
        transformations,
        is_inbounding_tensor=is_inbounding,
        **extras,
    )

  # Model output tensor
  output_name = 'StatefulPartitionedCall:0'
  self._test_tensor_transformation_params(
      5,
      quant_params,
      output_name,
      [no_quantize],
      is_inbounding_tensor=False,
  )
  # Output tensor is consumed by the virtual Output op.
  output_consumers = quant_params[output_name].consumers
  self.assertLen(output_consumers, 1)
  self.assertEqual(output_consumers[0].subgraph_op_id, -1)
287
+
288
+ # TODO: b/330770656 - expand the test to cover mixed activation precision.
289
def test_generate_config_selective(self):
  """Applies separate recipes to the 'dense' and 'dense_1' scopes.

  Checks that each FC weight tensor picks up the transformation and weight
  settings of the recipe whose regex matches its scope.
  """
  # Choose scope regex using Model Explorer
  dense_recipe = {
      'regex': '.*/dense/.*',
      'operation': 'FULLY_CONNECTED',
      'algorithm_key': 'min_max_uniform_quantize',
      'op_config': {
          'weight_tensor_config': {
              'dtype': _TensorDataType.INT,
              'num_bits': 8,
              'symmetric': True,
              'granularity': _QuantGranularity.CHANNELWISE,
          },
          # Equivalent to DRQ.
          'compute_precision': _ComputePrecision.INTEGER,
          'explicit_dequantize': False,
      },
  }
  dense_1_recipe = {
      'regex': '.*/dense_1/.*',
      'operation': 'FULLY_CONNECTED',
      'algorithm_key': 'min_max_uniform_quantize',
      'op_config': {
          'weight_tensor_config': {
              'dtype': _TensorDataType.INT,
              'num_bits': 4,
              'symmetric': False,
              'granularity': _QuantGranularity.TENSORWISE,
          },
          # Equivalent to WEIGHT_ONLY.
          'compute_precision': _ComputePrecision.FLOAT,
          'explicit_dequantize': True,
      },
  }
  self._recipe_manager.load_quantization_recipe(
      [dense_recipe, dense_1_recipe]
  )
  quant_params = self._params_generator.generate_quantization_parameters(
      self._recipe_manager
  )

  # FC weights for scope "dense"
  self._test_tensor_transformation_params(
      3,
      quant_params,
      'arith.constant1',
      [_QuantTransformation.QUANTIZE_TENSOR],
      is_inbounding_tensor=True,
      num_bits=8,
      granularity=_QuantGranularity.CHANNELWISE,
      symmetric=True,
  )

  # FC weights for scope "dense_1"
  self._test_tensor_transformation_params(
      4,
      quant_params,
      'arith.constant',
      [_QuantTransformation.ADD_DEQUANTIZE],
      is_inbounding_tensor=True,
      num_bits=4,
      granularity=_QuantGranularity.TENSORWISE,
      symmetric=False,
  )
356
+
357
def test_generate_config_edge_cases(self):
  """Covers a full-tensor-name regex and a regex matching no scope.

  The first FC weight tensor is expected to stay unquantized; the second FC
  weight tensor is expected to follow the tensor-name recipe.
  """
  # Use the tensor name as scope directly.
  tensor_name_recipe = {
      'regex': 'sequential/dense_1/MatMul',
      'operation': 'FULLY_CONNECTED',
      'algorithm_key': 'min_max_uniform_quantize',
      'op_config': {
          'weight_tensor_config': {
              'num_bits': 8,
              'symmetric': True,
              'granularity': _QuantGranularity.CHANNELWISE,
          },
          # Equivalent to DRQ.
          'compute_precision': _ComputePrecision.INTEGER,
      },
  }
  # Scope that does not exist in the model.
  missing_scope_recipe = {
      'regex': '.*/dense_3/.*',
      'operation': 'FULLY_CONNECTED',
      'algorithm_key': 'min_max_uniform_quantize',
      'op_config': {
          'weight_tensor_config': {
              'num_bits': 4,
              'symmetric': False,
              'granularity': _QuantGranularity.TENSORWISE,
          },
          # Equivalent to WEIGHT_ONLY.
          'compute_precision': _ComputePrecision.FLOAT,
          'explicit_dequantize': True,
      },
  }
  self._recipe_manager.load_quantization_recipe(
      [tensor_name_recipe, missing_scope_recipe]
  )
  quant_params = self._params_generator.generate_quantization_parameters(
      self._recipe_manager
  )

  # Only the second FC will be quantized
  self._test_tensor_transformation_params(
      3,
      quant_params,
      'arith.constant1',
      [_QuantTransformation.NO_QUANTIZE],
      is_inbounding_tensor=True,
  )
  self._test_tensor_transformation_params(
      4,
      quant_params,
      'arith.constant',
      [_QuantTransformation.QUANTIZE_TENSOR],
      is_inbounding_tensor=True,
      num_bits=8,
      granularity=_QuantGranularity.CHANNELWISE,
      symmetric=True,
  )
417
+
418
@parameterized.parameters(
    (True, _QuantGranularity.CHANNELWISE),
    (True, _QuantGranularity.TENSORWISE),
    (False, _QuantGranularity.CHANNELWISE),
    (False, _QuantGranularity.TENSORWISE),
)
def test_generate_config_int8xint8_single_fc(
    self, act_symmetric, channelwise_weight
):
  """Generates int8 activation / int8 weight params for the single FC model.

  Args:
    act_symmetric: Whether the activation tensor config is symmetric.
    channelwise_weight: The `_QuantGranularity` value used for the weight
      tensor config (not a boolean, despite the name).
  """
  single_fc_model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, 'tests/models/single_fc.tflite'
  )
  op_config = qtyping.OpQuantizationConfig(
      activation_tensor_config=_TensorQuantConfig(
          num_bits=8, symmetric=act_symmetric
      ),
      weight_tensor_config=_TensorQuantConfig(
          num_bits=8, symmetric=True, granularity=channelwise_weight
      ),
      # Equivalent to SRQ.
      compute_precision=_ComputePrecision.INTEGER,
  )
  self._recipe_manager.add_quantization_config(
      regex='.*',
      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
      algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
      op_config=op_config,
  )
  generator = params_generator.ParamsGenerator(single_fc_model_path)

  # Raise error when missing QSVs.
  def _reports_missing_qsvs(err):
    return (
        'Model quantization statistics values (QSVs) are required'
        in str(err)
    )

  with self.assertRaisesWithPredicateMatch(
      RuntimeError, _reports_missing_qsvs
  ):
    generator.generate_quantization_parameters(self._recipe_manager)

  # Calibrate then quantize
  model_calibrator = calibrator.Calibrator(single_fc_model_path)
  model_calibrator.calibrate(
      _get_calibration_data(_single_fc_model_representative_dataset_gen()),
      self._recipe_manager,
  )
  quant_params = generator.generate_quantization_parameters(
      self._recipe_manager,
      model_calibrator.get_model_qsvs(),
  )
  self.assertLen(quant_params, 4)

  input_name = 'serving_default_input_1:0'
  output_name = 'StatefulPartitionedCall:0'
  tensorwise = _QuantGranularity.TENSORWISE
  add_quantize = _QuantTransformation.ADD_QUANTIZE
  add_dequantize = _QuantTransformation.ADD_DEQUANTIZE
  quantize_tensor = _QuantTransformation.QUANTIZE_TENSOR
  # Rows: (op id, tensor name, transformation, num_bits, granularity,
  # symmetric, is_inbounding_tensor). Op id -1 is the virtual input/output op.
  expectations = (
      # Input tensor producer (from the virtual input op).
      (-1, input_name, add_dequantize, 8, tensorwise, act_symmetric, False),
      # Input tensor consumer.
      (0, input_name, add_quantize, 8, tensorwise, act_symmetric, True),
      # output tensor producer.
      (0, output_name, add_dequantize, 8, tensorwise, act_symmetric, False),
      # output tensor consumer (into the virtual output op).
      (-1, output_name, add_quantize, 8, tensorwise, act_symmetric, True),
      # weights
      (
          0,
          'sequential/dense/MatMul',
          quantize_tensor,
          8,
          channelwise_weight,
          True,
          True,
      ),
      # bias
      (
          0,
          'sequential/dense/BiasAdd/ReadVariableOp',
          quantize_tensor,
          32,
          channelwise_weight,
          True,
          True,
      ),
  )
  for (
      op_id,
      tensor_name,
      transformation,
      num_bits,
      granularity,
      symmetric,
      is_inbounding,
  ) in expectations:
    self._test_tensor_transformation_params(
        op_id,
        quant_params,
        tensor_name,
        [transformation],
        num_bits=num_bits,
        granularity=granularity,
        symmetric=symmetric,
        is_inbounding_tensor=is_inbounding,
    )
540
+
541
@parameterized.parameters('weight_only', 'DRQ')
def test_generate_params_buffer_sharing_graphs_succeeds(
    self, the_other_fc_difference
):
  """Generates params for the weight-sharing FC model without error.

  Args:
    the_other_fc_difference: 'weight_only' keeps only the global config;
      'DRQ' additionally registers an integer-compute config for the regex
      'PartitionedCall_1:0'.
  """
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, 'tests/models/weight_sharing_fcs.tflite'
  )
  all_supported = qtyping.TFLOperationName.ALL_SUPPORTED
  # (regex, op_config) pairs, registered in order.
  configs = [(
      '.*',
      qtyping.OpQuantizationConfig(
          weight_tensor_config=_TensorQuantConfig(num_bits=8, symmetric=True),
          # Equivalent to WEIGHT_ONLY.
          compute_precision=_ComputePrecision.FLOAT,
          explicit_dequantize=True,
      ),
  )]
  if the_other_fc_difference == 'DRQ':
    configs.append((
        'PartitionedCall_1:0',
        qtyping.OpQuantizationConfig(
            # Equivalent to DRQ.
            compute_precision=_ComputePrecision.INTEGER,
        ),
    ))
  for regex, op_config in configs:
    self._recipe_manager.add_quantization_config(
        regex=regex,
        operation_name=all_supported,
        op_config=op_config,
    )

  generator = params_generator.ParamsGenerator(model_path)
  quant_params = generator.generate_quantization_parameters(
      self._recipe_manager
  )
  self.assertLen(quant_params, 6)
572
+
573
@parameterized.parameters('no_quant', 'execution_mode', 'num_bits')
def test_generate_params_buffer_sharing_graphs_fails(
    self, the_other_fc_difference
):
  """Expects a RuntimeError when weight-sharing FCs are configured differently.

  Args:
    the_other_fc_difference: Selects how the second FC is configured. Only
      'num_bits' registers a config for it (4-bit weights); every other
      value leaves it without a config.
  """
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, 'tests/models/weight_sharing_fcs.tflite'
  )
  all_supported = qtyping.TFLOperationName.ALL_SUPPORTED
  # Setup the quantization config for the first FC.
  configs = [(
      'PartitionedCall:0',
      qtyping.OpQuantizationConfig(
          weight_tensor_config=_TensorQuantConfig(num_bits=8),
          compute_precision=_ComputePrecision.INTEGER,
      ),
  )]
  # Setup the quantization config for the second FC (weight shared with the
  # first FC).
  # NOTE(review): 'execution_mode' has no dedicated setup here, so it runs
  # exactly the same scenario as 'no_quant'. Confirm whether a distinct
  # config was intended or the parameter should be dropped.
  if the_other_fc_difference == 'num_bits':
    configs.append((
        'PartitionedCall_1:0',
        qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=4),
            compute_precision=_ComputePrecision.INTEGER,
        ),
    ))
  for regex, op_config in configs:
    self._recipe_manager.add_quantization_config(
        regex=regex,
        operation_name=all_supported,
        op_config=op_config,
    )

  generator = params_generator.ParamsGenerator(model_path)

  def _reports_param_mismatch(err):
    return 'do not have the same quantization parameters' in str(err)

  with self.assertRaisesWithPredicateMatch(
      RuntimeError, _reports_param_mismatch
  ):
    generator.generate_quantization_parameters(self._recipe_manager)
610
+
611
+ @parameterized.named_parameters(
612
+ dict(
613
+ testcase_name='producer_incompatible',
614
+ param1=qtyping.TensorTransformationParams(
615
+ tensor_name='tfl.quantize',
616
+ producer=qtyping.OpToTensorParams(
617
+ subgraph_op_id=0,
618
+ transformations=[qtyping.QuantTransformation.ADD_DEQUANTIZE],
619
+ parameters=qtyping.UniformQuantParams(
620
+ 8, None, np.array([1]), np.array([0])
621
+ ),
622
+ ),
623
+ consumers=[
624
+ qtyping.OpToTensorParams(
625
+ subgraph_op_id=1,
626
+ transformations=[
627
+ qtyping.QuantTransformation.ADD_QUANTIZE
628
+ ],
629
+ parameters=qtyping.UniformQuantParams(
630
+ 8, None, np.array([1]), np.array([0])
631
+ ),
632
+ ),
633
+ qtyping.OpToTensorParams(
634
+ subgraph_op_id=2,
635
+ transformations=[
636
+ qtyping.QuantTransformation.ADD_QUANTIZE,
637
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
638
+ ],
639
+ parameters=qtyping.UniformQuantParams(
640
+ 8, None, np.array([1]), np.array([0])
641
+ ),
642
+ ),
643
+ qtyping.OpToTensorParams(
644
+ subgraph_op_id=3,
645
+ transformations=[
646
+ qtyping.QuantTransformation.ADD_QUANTIZE,
647
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
648
+ ],
649
+ parameters=qtyping.UniformQuantParams(
650
+ 8, None, np.array([1]), np.array([0])
651
+ ),
652
+ ),
653
+ qtyping.OpToTensorParams(
654
+ subgraph_op_id=4,
655
+ transformations=[
656
+ qtyping.QuantTransformation.NO_QUANTIZE,
657
+ ],
658
+ parameters=qtyping.UniformQuantParams(
659
+ 8, None, np.array([1]), np.array([0])
660
+ ),
661
+ ),
662
+ ],
663
+ ),
664
+ param2=qtyping.TensorTransformationParams(
665
+ 'tfl.other_quantize',
666
+ qtyping.OpToTensorParams(
667
+ subgraph_op_id=0,
668
+ transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
669
+ parameters=qtyping.UniformQuantParams(
670
+ 8, None, np.array([1]), np.array([0])
671
+ ),
672
+ ),
673
+ [
674
+ qtyping.OpToTensorParams(
675
+ subgraph_op_id=1,
676
+ transformations=[
677
+ qtyping.QuantTransformation.ADD_QUANTIZE
678
+ ],
679
+ parameters=qtyping.UniformQuantParams(
680
+ 8, None, np.array([1]), np.array([0])
681
+ ),
682
+ ),
683
+ qtyping.OpToTensorParams(
684
+ subgraph_op_id=2,
685
+ transformations=[
686
+ qtyping.QuantTransformation.ADD_QUANTIZE,
687
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
688
+ ],
689
+ parameters=qtyping.UniformQuantParams(
690
+ 8, None, np.array([1]), np.array([0])
691
+ ),
692
+ ),
693
+ qtyping.OpToTensorParams(
694
+ subgraph_op_id=3,
695
+ transformations=[
696
+ qtyping.QuantTransformation.ADD_QUANTIZE,
697
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
698
+ ],
699
+ parameters=qtyping.UniformQuantParams(
700
+ 8, None, np.array([1]), np.array([0])
701
+ ),
702
+ ),
703
+ ],
704
+ ),
705
+ expected=False,
706
+ ),
707
+ dict(
708
+ testcase_name='param2_consumer_incompatible',
709
+ param1=qtyping.TensorTransformationParams(
710
+ tensor_name='tfl.quantize',
711
+ producer=qtyping.OpToTensorParams(
712
+ subgraph_op_id=0,
713
+ transformations=[qtyping.QuantTransformation.ADD_QUANTIZE],
714
+ parameters=qtyping.UniformQuantParams(
715
+ 8, None, np.array([1]), np.array([0])
716
+ ),
717
+ ),
718
+ consumers=[
719
+ qtyping.OpToTensorParams(
720
+ subgraph_op_id=1,
721
+ transformations=[
722
+ qtyping.QuantTransformation.ADD_QUANTIZE
723
+ ],
724
+ parameters=qtyping.UniformQuantParams(
725
+ 8, None, np.array([1]), np.array([0])
726
+ ),
727
+ ),
728
+ qtyping.OpToTensorParams(
729
+ subgraph_op_id=2,
730
+ transformations=[
731
+ qtyping.QuantTransformation.ADD_QUANTIZE,
732
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
733
+ ],
734
+ parameters=qtyping.UniformQuantParams(
735
+ 8, None, np.array([1]), np.array([0])
736
+ ),
737
+ ),
738
+ qtyping.OpToTensorParams(
739
+ subgraph_op_id=3,
740
+ transformations=[
741
+ qtyping.QuantTransformation.ADD_QUANTIZE,
742
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
743
+ ],
744
+ parameters=qtyping.UniformQuantParams(
745
+ 8, None, np.array([1]), np.array([0])
746
+ ),
747
+ ),
748
+ ],
749
+ ),
750
+ param2=qtyping.TensorTransformationParams(
751
+ 'tfl.other_quantize',
752
+ qtyping.OpToTensorParams(
753
+ subgraph_op_id=0,
754
+ transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
755
+ parameters=qtyping.UniformQuantParams(
756
+ 8, None, np.array([1]), np.array([0])
757
+ ),
758
+ ),
759
+ [
760
+ qtyping.OpToTensorParams(
761
+ subgraph_op_id=1,
762
+ transformations=[
763
+ qtyping.QuantTransformation.ADD_QUANTIZE
764
+ ],
765
+ parameters=qtyping.UniformQuantParams(
766
+ 8, None, np.array([1]), np.array([0])
767
+ ),
768
+ ),
769
+ qtyping.OpToTensorParams(
770
+ subgraph_op_id=2,
771
+ transformations=[
772
+ qtyping.QuantTransformation.ADD_QUANTIZE,
773
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
774
+ ],
775
+ parameters=qtyping.UniformQuantParams(
776
+ 8, None, np.array([1]), np.array([0])
777
+ ),
778
+ ),
779
+ qtyping.OpToTensorParams(
780
+ subgraph_op_id=3,
781
+ transformations=[
782
+ qtyping.QuantTransformation.ADD_QUANTIZE,
783
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
784
+ ],
785
+ parameters=qtyping.UniformQuantParams(
786
+ 8, None, np.array([1]), np.array([0])
787
+ ),
788
+ ),
789
+ qtyping.OpToTensorParams(
790
+ subgraph_op_id=4,
791
+ transformations=[
792
+ qtyping.QuantTransformation.QUANTIZE_TENSOR,
793
+ ],
794
+ parameters=qtyping.UniformQuantParams(
795
+ 8, None, np.array([1]), np.array([0])
796
+ ),
797
+ ),
798
+ ],
799
+ ),
800
+ expected=False,
801
+ ),
802
+ dict(
803
+ testcase_name='compatible',
804
+ param1=qtyping.TensorTransformationParams(
805
+ tensor_name='tfl.quantize',
806
+ producer=None,
807
+ consumers=[
808
+ qtyping.OpToTensorParams(
809
+ subgraph_op_id=2,
810
+ transformations=[
811
+ qtyping.QuantTransformation.ADD_QUANTIZE,
812
+ ],
813
+ parameters=qtyping.UniformQuantParams(
814
+ 8, None, np.array([1]), np.array([0])
815
+ ),
816
+ ),
817
+ qtyping.OpToTensorParams(
818
+ subgraph_op_id=3,
819
+ transformations=[
820
+ qtyping.QuantTransformation.NO_QUANTIZE,
821
+ qtyping.QuantTransformation.ADD_QUANTIZE,
822
+ ],
823
+ parameters=qtyping.UniformQuantParams(
824
+ 8, None, np.array([1]), np.array([0])
825
+ ),
826
+ ),
827
+ qtyping.OpToTensorParams(
828
+ subgraph_op_id=4,
829
+ transformations=[
830
+ qtyping.QuantTransformation.NO_QUANTIZE,
831
+ ],
832
+ ),
833
+ ],
834
+ ),
835
+ param2=qtyping.TensorTransformationParams(
836
+ 'tfl.other_quantize',
837
+ None,
838
+ [
839
+ qtyping.OpToTensorParams(
840
+ subgraph_op_id=1,
841
+ transformations=[
842
+ qtyping.QuantTransformation.ADD_QUANTIZE
843
+ ],
844
+ parameters=qtyping.UniformQuantParams(
845
+ 8, None, np.array([1]), np.array([0])
846
+ ),
847
+ ),
848
+ qtyping.OpToTensorParams(
849
+ subgraph_op_id=2,
850
+ transformations=[
851
+ qtyping.QuantTransformation.ADD_QUANTIZE,
852
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
853
+ ],
854
+ parameters=qtyping.UniformQuantParams(
855
+ 8, None, np.array([1]), np.array([0])
856
+ ),
857
+ ),
858
+ qtyping.OpToTensorParams(
859
+ subgraph_op_id=3,
860
+ transformations=[
861
+ qtyping.QuantTransformation.ADD_QUANTIZE,
862
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
863
+ ],
864
+ parameters=qtyping.UniformQuantParams(
865
+ 8, None, np.array([1]), np.array([0])
866
+ ),
867
+ ),
868
+ qtyping.OpToTensorParams(
869
+ subgraph_op_id=4,
870
+ transformations=[
871
+ qtyping.QuantTransformation.ADD_QUANTIZE,
872
+ ],
873
+ parameters=qtyping.UniformQuantParams(
874
+ 8, None, np.array([1]), np.array([0])
875
+ ),
876
+ ),
877
+ ],
878
+ ),
879
+ expected=True,
880
+ ),
881
+ )
882
def test_params_compatible(self, param1, param2, expected):
  """Checks the compatibility verdict for a pair of transformation params.

  Args:
    param1: First tensor transformation params object to compare.
    param2: Second tensor transformation params object to compare.
    expected: The boolean the compatibility helper is expected to return.
  """
  # This exercises a private helper directly; the original author added the
  # test to keep production coverage happy.
  is_compatible = params_generator._compatible_tensor_transformation_params(
      param1, param2
  )
  self.assertEqual(is_compatible, expected)
890
+
891
def test_model_with_duplicated_tensor_names_fails(self):
  """Expects ParamsGenerator construction to raise on duplicated tensor names."""
  error_message = 'Tensor name test_same_name is not unique in the model.'
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, 'tests/models/duplicated_tensor_names.tflite'
  )

  def _mentions_duplicate_name(err):
    # Substring match: the raised message only has to contain the text above.
    return error_message in str(err)

  with self.assertRaisesWithPredicateMatch(
      ValueError, _mentions_duplicate_name
  ):
    params_generator.ParamsGenerator(model_path)
900
+
901
def test_quantize_integer_input_output(self):
  """Expects NO_QUANTIZE for every tensor of `single_transpose_int32.tflite`.

  An 8-bit integer-compute recipe is registered for all supported ops, the
  model is calibrated, and quantization parameters are generated. The test
  then checks that exactly three tensors are reported and that each checked
  producer/consumer entry carries only NO_QUANTIZE.
  """
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, 'tests/models/single_transpose_int32.tflite'
  )

  activation_config = _TensorQuantConfig(num_bits=8, symmetric=False)
  weight_config = _TensorQuantConfig(num_bits=8, symmetric=True)
  integer_op_config = qtyping.OpQuantizationConfig(
      activation_tensor_config=activation_config,
      weight_tensor_config=weight_config,
      # Equivalent to SRQ.
      compute_precision=_ComputePrecision.INTEGER,
  )
  self._recipe_manager.add_quantization_config(
      regex='.*',
      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
      algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
      op_config=integer_op_config,
  )
  generator = params_generator.ParamsGenerator(model_path)

  # Calibrate then quantize.
  model_calibrator = calibrator.Calibrator(model_path)
  dataset = _int_transpose_model_representative_dataset_gen()
  model_calibrator.calibrate(
      _get_calibration_data(dataset), self._recipe_manager
  )
  quant_params = generator.generate_quantization_parameters(
      self._recipe_manager,
      model_calibrator.get_model_qsvs(),
  )
  self.assertLen(quant_params, 3)

  # Each row: (subgraph_op_id, tensor_name, is_inbounding_tensor).
  # Rows are checked in order, so the first failing row aborts the test.
  expected_entries = (
      # Input tensor producer: the virtual input op.
      (-1, 'serving_default_input_2:0', False),
      # Input tensor consumer.
      (0, 'serving_default_input_2:0', True),
      # Output tensor producer.
      (0, 'PartitionedCall:0', False),
      # Output tensor consumer (into the virtual output op).
      (-1, 'PartitionedCall:0', True),
      # perm
      (0, 'sequential_1/permute_1/transpose/perm', True),
  )
  for op_id, name, inbounding in expected_entries:
    self._test_tensor_transformation_params(
        op_id,
        quant_params,
        name,
        [_QuantTransformation.NO_QUANTIZE],
        is_inbounding_tensor=inbounding,
    )
974
+
975
def _test_tensor_transformation_params(
    self,
    subgraph_op_id,
    quant_params,
    tensor_name,
    transformations,
    is_inbounding_tensor,
    num_bits=8,
    granularity=_QuantGranularity.TENSORWISE,
    symmetric=True,
    quantized_dimension=0,
):
  """Asserts the transformation params recorded for one tensor/op pairing.

  Args:
    subgraph_op_id: Expected `subgraph_op_id` of the inspected op entry.
    quant_params: Mapping from tensor name to its transformation params.
    tensor_name: Key to look up in `quant_params`.
    transformations: Expected transformation sequence for the op entry.
    is_inbounding_tensor: If true, inspect the tensor's single consumer
      entry; otherwise inspect its producer entry.
    num_bits: Expected bit width of the quantization parameters.
    granularity: When CHANNELWISE, `quantized_dimension` is compared;
      otherwise the recorded quantized dimension must be None.
    symmetric: If true, every zero point must be 0; otherwise scale and
      zero point must have the same length.
    quantized_dimension: Expected dimension for channelwise parameters.
  """
  self.assertIn(tensor_name, quant_params)
  tensor_params = quant_params[tensor_name]
  self.assertEqual(tensor_params.tensor_name, tensor_name)

  # Select the op-side entry that should be inspected.
  if not is_inbounding_tensor:
    op_params = tensor_params.producer
  else:
    self.assertLen(tensor_params.consumers, 1)
    op_params = tensor_params.consumers[0]
  self.assertIsNotNone(op_params)
  self.assertEqual(op_params.subgraph_op_id, subgraph_op_id)
  self.assertSequenceEqual(op_params.transformations, transformations)

  # A tensor that is only NO_QUANTIZE must carry no parameters at all.
  if transformations == [_QuantTransformation.NO_QUANTIZE]:
    self.assertIsNone(op_params.parameters)
    return

  uniform_params = op_params.parameters
  self.assertIsNotNone(uniform_params)
  if granularity is not _QuantGranularity.CHANNELWISE:
    self.assertIsNone(uniform_params.quantized_dimension)
  else:
    self.assertEqual(
        uniform_params.quantized_dimension, quantized_dimension
    )
  self.assertEqual(uniform_params.num_bits, num_bits)

  if not symmetric:
    scale_count = len(uniform_params.scale)
    zero_point_count = len(uniform_params.zero_point)
    self.assertEqual(scale_count, zero_point_count)
    return
  # Symmetric case: the absolute zero points must sum to exactly 0.
  self.assertEqual(np.sum(abs(uniform_params.zero_point)), 0)
1018
+
1019
+
1020
class ParamsGeneratorAlreadyQuantizedModelTest(googletest.TestCase):
  """Tests how ParamsGenerator construction reacts to float vs. quantized models."""

  def test_check_is_float_model_succeeds_when_model_is_float(self):
    """Construction from `conv_fc_mnist.tflite` must not raise."""
    float_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/conv_fc_mnist.tflite'
    )
    # Only construction matters here; the instance itself is not used.
    params_generator.ParamsGenerator(float_model_path)

  def test_check_is_float_model_raises_error_when_model_is_quantized(self):
    """Construction from `mnist_quantized.tflite` must raise ValueError."""
    quantized_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/mnist_quantized.tflite'
    )
    expected_regex = (
        'The input model for quantization parameters generation is not a float'
        ' model.'
    )
    with self.assertRaisesRegex(ValueError, expected_regex):
      params_generator.ParamsGenerator(quantized_model_path)
1038
+
1039
+
1040
if __name__ == '__main__':
  # Hand control to the test runner when this file is executed as a script.
  googletest.main()