ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,815 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Tests for recipe_manager.py."""
17
+
18
+ from absl.testing import parameterized
19
+ from tensorflow.python.platform import googletest
20
+ from ai_edge_quantizer import algorithm_manager
21
+ from ai_edge_quantizer import qtyping
22
+ from ai_edge_quantizer import recipe_manager
23
+
24
+ _ComputePrecision = qtyping.ComputePrecision
25
+ _TFLOpName = qtyping.TFLOperationName
26
+ _TensorQuantConfig = qtyping.TensorQuantizationConfig
27
+ _TensorDataType = qtyping.TensorDataType
28
+ _AlgorithmName = recipe_manager.AlgorithmName
29
+ _QuantGranularity = qtyping.QuantGranularity
30
+
31
+
32
+ # Sample functions for test cases.
33
+ def _sample_init_qsvs(*_, **__):
34
+ return 1.0, dict()
35
+
36
+
37
+ def _sample_calibration_func(*_, **__):
38
+ return 2.0, dict()
39
+
40
+
41
+ def _sample_materialize_func(*_, **__):
42
+ return 3.0, dict()
43
+
44
+
45
+ def _sample_check_op_config_func(op_name, op_config, _):
46
+ if (
47
+ op_config.weight_tensor_config is not None
48
+ and op_config.weight_tensor_config.num_bits == 17
49
+ ):
50
+ raise ValueError(f'Unsupported number of bits for op: {op_name}.')
51
+
52
+
53
def _add_default_int8xint8_integer_recipe(recipe_manager_object):
  """Installs a global int8 (asymmetric act) x int8 (symmetric weight) SRQ recipe."""
  activation_config = _TensorQuantConfig(num_bits=8, symmetric=False)
  weight_config = _TensorQuantConfig(num_bits=8)
  srq_op_config = qtyping.OpQuantizationConfig(
      activation_tensor_config=activation_config,
      weight_tensor_config=weight_config,
      compute_precision=_ComputePrecision.INTEGER,  # SRQ.
  )
  recipe_manager_object.add_quantization_config(
      regex='.*',
      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
      algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
      op_config=srq_op_config,
  )
+
67
+
68
# register some currently unsupported ops for testing purposes
def _register_testing_op(algorithm_key, tfl_op):
  """Registers `tfl_op` under `algorithm_key` using the sample stub hooks.

  Installs the sample op-config validator for the algorithm, then registers
  the op with stub init/calibration/materialization functions so recipe
  plumbing can be exercised without real quantization logic.
  """
  algorithm_manager.register_op_quant_config_validation_func(
      algorithm_key, _sample_check_op_config_func
  )
  # Positional order: init QSV, calibration, materialization stubs.
  algorithm_manager.register_quantized_op(
      algorithm_key,
      tfl_op,
      _sample_init_qsvs,
      _sample_calibration_func,
      _sample_materialize_func,
  )
+
81
+
82
+ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
83
+ """Test cases for the recipe manager Configurator."""
84
+
85
+ def setUp(self):
86
+ super().setUp()
87
+ self._recipe_manager = recipe_manager.RecipeManager()
88
+ self._testing_ops = [
89
+ _TFLOpName.BATCH_MATMUL,
90
+ _TFLOpName.FULLY_CONNECTED,
91
+ _TFLOpName.DEPTHWISE_CONV_2D,
92
+ ]
93
+ for op in self._testing_ops:
94
+ _register_testing_op(_AlgorithmName.MIN_MAX_UNIFORM_QUANT, op)
95
+ _register_testing_op('GPTQ', op)
96
+
97
  def test_add_get_quantization_config(self):
    """Adds scoped configs and verifies retrieval and overwrite precedence."""
    # Int8 DRQ all ops under "Dense".
    self._recipe_manager.add_quantization_config(
        regex='.*/Dense/.*',
        operation_name=_TFLOpName.ALL_SUPPORTED,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=8),
            compute_precision=_ComputePrecision.INTEGER,  # DRQ.
        ),
    )

    # Int8 weight-only FullyConnected configuration under "Dense_3".
    self._recipe_manager.add_quantization_config(
        regex='.*/Dense_3/.*',
        operation_name=_TFLOpName.FULLY_CONNECTED,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=8),
            compute_precision=_ComputePrecision.FLOAT,  # WEIGHT_ONLY.
            explicit_dequantize=True,
        ),
    )
    # Int4 DRQ BatchMatmul configuration under "Dense_3".
    self._recipe_manager.add_quantization_config(
        regex='.*/Dense_3/.*',
        operation_name=_TFLOpName.BATCH_MATMUL,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=4),
            compute_precision=_ComputePrecision.INTEGER,  # DRQ.
        ),
    )

    # Return NO_QUANT if not match.
    alg_key, _ = self._recipe_manager.get_quantization_configs(
        _TFLOpName.FULLY_CONNECTED, 'model/Dense_1/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.NO_QUANTIZE)
    alg_key, _ = self._recipe_manager.get_quantization_configs(
        _TFLOpName.DEPTHWISE_CONV_2D, 'model/Dense_3/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.NO_QUANTIZE)

    # Check _TFLOperationKey.ALL
    alg_key, op_config = self._recipe_manager.get_quantization_configs(
        _TFLOpName.DEPTHWISE_CONV_2D, 'model/Dense/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
    # DRQ check.
    self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)

    alg_key, op_config = self._recipe_manager.get_quantization_configs(
        _TFLOpName.BATCH_MATMUL, 'model/Dense/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
    # DRQ check.
    self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)

    # Check conflicts handling.
    # Int8 Weight-only for FC under "Dense", this should only overwrite FC but
    # leave others unchanged.
    self._recipe_manager.add_quantization_config(
        regex='.*/Dense/.*',
        operation_name=_TFLOpName.FULLY_CONNECTED,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=8),
            compute_precision=_ComputePrecision.FLOAT,  # WEIGHT_ONLY.
        ),
    )
    alg_key, op_config = self._recipe_manager.get_quantization_configs(
        _TFLOpName.FULLY_CONNECTED, 'model/Dense/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
    # WEIGHT_ONLY check.
    self.assertEqual(op_config.compute_precision, _ComputePrecision.FLOAT)
    alg_key, _ = self._recipe_manager.get_quantization_configs(
        _TFLOpName.BATCH_MATMUL, 'model/Dense/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)

    # Reset all ops, this time with 4 bits DRQ.
    self._recipe_manager.add_quantization_config(
        regex='.*/Dense/.*',
        operation_name=_TFLOpName.ALL_SUPPORTED,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=4),
            compute_precision=_ComputePrecision.INTEGER,  # DRQ.
        ),
    )
    alg_key, op_config = self._recipe_manager.get_quantization_configs(
        _TFLOpName.FULLY_CONNECTED, 'model/Dense/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
    weight_tensor_config = op_config.weight_tensor_config
    self.assertIsNotNone(weight_tensor_config)
    # DRQ check.
    self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
    self.assertEqual(weight_tensor_config.num_bits, 4)
    alg_key, op_config = self._recipe_manager.get_quantization_configs(
        _TFLOpName.BATCH_MATMUL, 'model/Dense/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
    # DRQ check.
    self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
    # NOTE(review): weight_tensor_config still refers to the FC fetch above,
    # not this BATCH_MATMUL op_config — presumably both are 4-bit here;
    # consider refreshing from op_config before asserting.
    self.assertEqual(weight_tensor_config.num_bits, 4)

    # Overwrite all FC.
    self._recipe_manager.add_quantization_config(
        regex='.*',
        operation_name=_TFLOpName.FULLY_CONNECTED,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=3),
        ),
    )
    # FC config is overridden.
    alg_key, op_config = self._recipe_manager.get_quantization_configs(
        _TFLOpName.FULLY_CONNECTED, 'model/Dense_3/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
    weight_tensor_config = op_config.weight_tensor_config
    self.assertIsNotNone(weight_tensor_config)
    self.assertEqual(weight_tensor_config.num_bits, 3)
    # No overridden for batch matmul.
    alg_key, op_config = self._recipe_manager.get_quantization_configs(
        _TFLOpName.BATCH_MATMUL, 'model/Dense_3/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
    weight_tensor_config = op_config.weight_tensor_config
    self.assertIsNotNone(weight_tensor_config)
    self.assertEqual(weight_tensor_config.num_bits, 4)
+
232
  def test_add_unsupported_quantization_config(self):
    """Unregistered ops and algorithms are rejected with a ValueError."""
    error_message = 'Unsupported operation'
    # Add unregistered operations.
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      self._recipe_manager.add_quantization_config(
          regex='.*/Dense/.*',
          operation_name=_TFLOpName.CUSTOM_OP,
          algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
          op_config=qtyping.OpQuantizationConfig(
              weight_tensor_config=_TensorQuantConfig(num_bits=8),
              compute_precision=_ComputePrecision.INTEGER,  # DRQ.
          ),
      )
    # Add unregistered algorithm
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      self._recipe_manager.add_quantization_config(
          regex='.*/Dense/.*',
          operation_name=_TFLOpName.FULLY_CONNECTED,
          algorithm_key='AWQ',  # Not registered in setUp.
          op_config=qtyping.OpQuantizationConfig(
              weight_tensor_config=_TensorQuantConfig(num_bits=8),
              compute_precision=_ComputePrecision.INTEGER,  # DRQ.
          ),
      )
+
261
  def test_add_unsupported_num_bits_raise_error(self):
    """The sample validator rejects 17-bit weight configs at add time."""
    test_op_name = _TFLOpName.FULLY_CONNECTED
    # Message produced by _sample_check_op_config_func.
    error_message = f'Unsupported number of bits for op: {test_op_name}.'
    # Add unregistered operation
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      self._recipe_manager.add_quantization_config(
          regex='.*/Dense/.*',
          operation_name=test_op_name,
          algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
          op_config=qtyping.OpQuantizationConfig(
              weight_tensor_config=_TensorQuantConfig(num_bits=17),
          ),
      )
+
277
  def test_add_unsupported_skip_successful(self):
    """skip_checks=True bypasses validation, so 17-bit config is accepted."""
    self._recipe_manager.add_quantization_config(
        regex='.*/Dense_3/.*',
        operation_name=_TFLOpName.FULLY_CONNECTED,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=17),
            compute_precision=_ComputePrecision.INTEGER,  # DRQ.
            skip_checks=True,  # Bypass _sample_check_op_config_func.
        ),
    )
    alg_key, op_config = self._recipe_manager.get_quantization_configs(
        _TFLOpName.FULLY_CONNECTED, 'model/Dense_3/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
    self.assertIsNone(op_config.activation_tensor_config)
    weight_tensor_config = op_config.weight_tensor_config
    self.assertIsNotNone(weight_tensor_config)
    self.assertEqual(weight_tensor_config.num_bits, 17)
    # DRQ check.
    self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
+
299
+ def test_set_full_integer_quantization_config(self):
300
+ _add_default_int8xint8_integer_recipe(self._recipe_manager)
301
+ # Full integer setting is global
302
+ alg_key, op_config = self._recipe_manager.get_quantization_configs(
303
+ _TFLOpName.FULLY_CONNECTED, 'model/Dense_3/op'
304
+ )
305
+ self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
306
+ op_act_config = op_config.activation_tensor_config
307
+ self.assertIsNotNone(op_act_config)
308
+ self.assertEqual(op_act_config.num_bits, 8)
309
+ self.assertEqual(op_act_config.symmetric, False)
310
+ self.assertEqual(
311
+ op_act_config.granularity,
312
+ _QuantGranularity.TENSORWISE,
313
+ )
314
+ weight_tensor_config = op_config.weight_tensor_config
315
+ self.assertIsNotNone(weight_tensor_config)
316
+ self.assertEqual(weight_tensor_config.num_bits, 8)
317
+ self.assertEqual(weight_tensor_config.symmetric, True)
318
+ self.assertEqual(
319
+ weight_tensor_config.granularity,
320
+ _QuantGranularity.TENSORWISE,
321
+ )
322
+
323
+ # Change weight settings for Dense_3 FC
324
+ self._recipe_manager.add_quantization_config(
325
+ regex='.*/Dense_3/.*',
326
+ operation_name=_TFLOpName.FULLY_CONNECTED,
327
+ algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
328
+ op_config=qtyping.OpQuantizationConfig(
329
+ weight_tensor_config=_TensorQuantConfig(num_bits=3),
330
+ compute_precision=_ComputePrecision.INTEGER, # DRQ.
331
+ ),
332
+ )
333
+ alg_key, op_config = self._recipe_manager.get_quantization_configs(
334
+ _TFLOpName.FULLY_CONNECTED, 'model/Dense_3/op'
335
+ )
336
+ self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
337
+ self.assertIsNone(op_config.activation_tensor_config)
338
+ weight_tensor_config = op_config.weight_tensor_config
339
+ self.assertIsNotNone(weight_tensor_config)
340
+ self.assertEqual(weight_tensor_config.num_bits, 3)
341
+ # WEIGHT_ONLY check.
342
+ self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
343
+
344
+ # Change the global setting to int16
345
+ self._recipe_manager.add_quantization_config(
346
+ regex='.*',
347
+ operation_name=_TFLOpName.ALL_SUPPORTED,
348
+ algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
349
+ op_config=qtyping.OpQuantizationConfig(
350
+ activation_tensor_config=_TensorQuantConfig(
351
+ num_bits=16, symmetric=True
352
+ ),
353
+ weight_tensor_config=_TensorQuantConfig(num_bits=8, symmetric=True),
354
+ compute_precision=_ComputePrecision.INTEGER, # SRQ.
355
+ ),
356
+ )
357
+ # This does not impact the special dense_3 case
358
+ alg_key, op_config = self._recipe_manager.get_quantization_configs(
359
+ _TFLOpName.FULLY_CONNECTED, 'model/Dense_3/op'
360
+ )
361
+ self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
362
+ self.assertIsNone(op_config.activation_tensor_config)
363
+ self.assertIsNotNone(weight_tensor_config)
364
+ self.assertEqual(weight_tensor_config.num_bits, 3)
365
+ # WEIGHT_ONLY check.
366
+ self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
367
+
368
+ # All the others will be int16
369
+ alg_key, op_config = self._recipe_manager.get_quantization_configs(
370
+ _TFLOpName.CONV_2D, 'model/Dense_31/op'
371
+ )
372
+ self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
373
+ self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
374
+ op_act_config = op_config.activation_tensor_config
375
+ self.assertIsNotNone(op_act_config)
376
+ weight_tensor_config = op_config.weight_tensor_config
377
+ self.assertIsNotNone(weight_tensor_config)
378
+ self.assertEqual(op_act_config.num_bits, 16)
379
+ self.assertEqual(op_act_config.symmetric, True)
380
+ self.assertEqual(
381
+ op_act_config.granularity,
382
+ _QuantGranularity.TENSORWISE,
383
+ )
384
+ self.assertEqual(weight_tensor_config.num_bits, 8)
385
+ self.assertEqual(weight_tensor_config.symmetric, True)
386
+ self.assertEqual(
387
+ weight_tensor_config.granularity,
388
+ _QuantGranularity.TENSORWISE,
389
+ )
390
+
391
  def test_get_full_quantization_config(self):
    """The exported recipe lists every added config in insertion order."""
    # Int8 asymmetric full integer model.
    _add_default_int8xint8_integer_recipe(self._recipe_manager)
    # Default all BMM.
    self._recipe_manager.add_quantization_config(
        regex='.*',
        operation_name=_TFLOpName.BATCH_MATMUL,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=8),
            compute_precision=_ComputePrecision.FLOAT,  # WEIGHT_ONLY.
            explicit_dequantize=True,
        ),
    )

    # Int8 DRQ FULLY_CONNECTED ops under "Dense".
    self._recipe_manager.add_quantization_config(
        regex='.*/Dense/.*',
        operation_name=_TFLOpName.FULLY_CONNECTED,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=8),
            compute_precision=_ComputePrecision.INTEGER,  # DRQ.
        ),
    )

    # Overwrite DRQ ALL ops under "Dense".
    self._recipe_manager.add_quantization_config(
        regex='.*/Dense/.*',
        operation_name=_TFLOpName.ALL_SUPPORTED,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=4),
            compute_precision=_ComputePrecision.FLOAT,  # WEIGHT_ONLY.
            explicit_dequantize=True,
        ),
    )

    # Overwrite "Dense_1" to only quantize FullyConnected.
    self._recipe_manager.add_quantization_config(
        regex='.*/Dense_1/.*',
        operation_name=_TFLOpName.FULLY_CONNECTED,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=6),
            compute_precision=_ComputePrecision.FLOAT,  # WEIGHT_ONLY.
            explicit_dequantize=True,
        ),
    )

    # Add BMM to "Dense_1".
    self._recipe_manager.add_quantization_config(
        regex='.*/Dense_1/.*',
        operation_name=_TFLOpName.BATCH_MATMUL,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TensorQuantConfig(num_bits=3),
            compute_precision=_ComputePrecision.FLOAT,  # WEIGHT_ONLY.
            explicit_dequantize=True,
        ),
    )

    # Note: the overwritten "Dense" FC entry is absent below — superseded
    # entries are dropped from the exported recipe.
    expected_full_quantization_config = [
        {
            'regex': '.*',
            'operation': '*',
            'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
            'op_config': {
                'activation_tensor_config': {
                    'num_bits': 8,
                    'symmetric': False,
                    'granularity': _QuantGranularity.TENSORWISE,
                    'dtype': 'INT',
                    'block_size': 0,
                },
                'weight_tensor_config': {
                    'num_bits': 8,
                    'symmetric': True,
                    'granularity': _QuantGranularity.TENSORWISE,
                    'dtype': 'INT',
                    'block_size': 0,
                },
                # SRQ (activation config present, INTEGER precision).
                'compute_precision': _ComputePrecision.INTEGER,
                'explicit_dequantize': False,
                'skip_checks': False,
                'min_weight_elements': 0,
            },
        },
        {
            'regex': '.*',
            'operation': 'BATCH_MATMUL',
            'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
            'op_config': {
                'weight_tensor_config': {
                    'dtype': 'INT',
                    'num_bits': 8,
                    'symmetric': True,
                    'granularity': _QuantGranularity.TENSORWISE,
                    'block_size': 0,
                },
                # WEIGHT_ONLY.
                'compute_precision': _ComputePrecision.FLOAT,
                'explicit_dequantize': True,
                'skip_checks': False,
                'min_weight_elements': 0,
            },
        },
        {
            'regex': '.*/Dense/.*',
            'operation': '*',
            'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
            'op_config': {
                'weight_tensor_config': {
                    'dtype': 'INT',
                    'num_bits': 4,
                    'symmetric': True,
                    'granularity': _QuantGranularity.TENSORWISE,
                    'block_size': 0,
                },
                # WEIGHT_ONLY.
                'compute_precision': _ComputePrecision.FLOAT,
                'explicit_dequantize': True,
                'skip_checks': False,
                'min_weight_elements': 0,
            },
        },
        {
            'regex': '.*/Dense_1/.*',
            'operation': 'FULLY_CONNECTED',
            'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
            'op_config': {
                'weight_tensor_config': {
                    'dtype': 'INT',
                    'num_bits': 6,
                    'symmetric': True,
                    'granularity': _QuantGranularity.TENSORWISE,
                    'block_size': 0,
                },
                # WEIGHT_ONLY.
                'compute_precision': _ComputePrecision.FLOAT,
                'explicit_dequantize': True,
                'skip_checks': False,
                'min_weight_elements': 0,
            },
        },
        {
            'regex': '.*/Dense_1/.*',
            'operation': 'BATCH_MATMUL',
            'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
            'op_config': {
                'weight_tensor_config': {
                    'dtype': 'INT',
                    'num_bits': 3,
                    'symmetric': True,
                    'granularity': _QuantGranularity.TENSORWISE,
                    'block_size': 0,
                },
                # WEIGHT_ONLY.
                'compute_precision': _ComputePrecision.FLOAT,
                'explicit_dequantize': True,
                'skip_checks': False,
                'min_weight_elements': 0,
            },
        },
    ]
    self.assertEqual(
        expected_full_quantization_config,
        self._recipe_manager.get_quantization_recipe(),
    )
+
562
+ def test_get_quantization_configs_with_no_quantize_overwrite(self):
563
+ self._recipe_manager.add_quantization_config(
564
+ regex='.*/Dense/.*',
565
+ operation_name=_TFLOpName.ALL_SUPPORTED,
566
+ algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
567
+ op_config=qtyping.OpQuantizationConfig(
568
+ weight_tensor_config=_TensorQuantConfig(num_bits=8),
569
+ ),
570
+ )
571
+ self._recipe_manager.add_quantization_config(
572
+ regex='.*/Dense/.*',
573
+ operation_name=_TFLOpName.FULLY_CONNECTED,
574
+ algorithm_key=_AlgorithmName.NO_QUANTIZE,
575
+ )
576
+
577
+ # Fully connected will be overwritten to no quantization.
578
+ alg_key, _ = self._recipe_manager.get_quantization_configs(
579
+ _TFLOpName.FULLY_CONNECTED, 'model/Dense/op'
580
+ )
581
+ self.assertEqual(alg_key, _AlgorithmName.NO_QUANTIZE)
582
+ # Other ops will be quantized.
583
+ alg_key, op_config = self._recipe_manager.get_quantization_configs(
584
+ _TFLOpName.CONV_2D, 'model/Dense/op'
585
+ )
586
+ self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
587
+ weight_tensor_config = op_config.weight_tensor_config
588
+ self.assertIsNotNone(weight_tensor_config)
589
+ # WEIGHT_ONLY check.
590
+ self.assertEqual(op_config.compute_precision, _ComputePrecision.FLOAT)
591
+ self.assertEqual(weight_tensor_config.num_bits, 8)
592
+
593
  def test_load_from_full_quantization_config(self):
    """A recipe loaded from its dict form round-trips into lookups."""
    full_quantization_config = [
        {
            'regex': '.*',
            'operation': 'BATCH_MATMUL',
            'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
            'op_config': {
                'weight_tensor_config': {
                    'dtype': 'INT',
                    'num_bits': 8,
                    'symmetric': True,
                    'granularity': _QuantGranularity.CHANNELWISE,
                },
                # WEIGHT_ONLY.
                'compute_precision': _ComputePrecision.FLOAT,
                'explicit_dequantize': False,
            },
        },
        {
            'regex': '.*/Dense/.*',
            'operation': '*',
            'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
            'op_config': {
                'weight_tensor_config': {
                    'dtype': 'INT',
                    'num_bits': 4,
                    'symmetric': False,
                    'granularity': _QuantGranularity.CHANNELWISE,
                },
                # DRQ.
                'compute_precision': _ComputePrecision.INTEGER,
            },
        },
    ]
    self._recipe_manager.load_quantization_recipe(full_quantization_config)

    alg_key, op_config = self._recipe_manager.get_quantization_configs(
        _TFLOpName.BATCH_MATMUL, 'model/Dense10/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
    weight_tensor_config = op_config.weight_tensor_config
    self.assertIsNotNone(weight_tensor_config)
    # WEIGHT_ONLY check.
    self.assertEqual(op_config.compute_precision, _ComputePrecision.FLOAT)
    self.assertEqual(weight_tensor_config.num_bits, 8)

    # Dense will be overwritten by the last setting
    alg_key, op_config = self._recipe_manager.get_quantization_configs(
        _TFLOpName.FULLY_CONNECTED, 'model/Dense/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
    weight_tensor_config = op_config.weight_tensor_config
    self.assertIsNotNone(weight_tensor_config)
    # DRQ check.
    self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
    self.assertEqual(weight_tensor_config.num_bits, 4)
+
650
+ def test_get_unsupported_op_fall_back_to_default(self):
651
+ self._recipe_manager.add_quantization_config(
652
+ regex='.*/Dense/.*',
653
+ operation_name=_TFLOpName.ALL_SUPPORTED,
654
+ algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
655
+ op_config=qtyping.OpQuantizationConfig(
656
+ weight_tensor_config=_TensorQuantConfig(num_bits=17),
657
+ ),
658
+ )
659
+ alg_key, _ = self._recipe_manager.get_quantization_configs(
660
+ _TFLOpName.BATCH_MATMUL, 'model/Dense10/op'
661
+ )
662
+ # int17 is not supported, fall back to float.
663
+ self.assertEqual(alg_key, _AlgorithmName.NO_QUANTIZE)
664
+
665
  def test_load_from_full_quantization_config_full_integer(self):
    """Loading a full-integer recipe honors global default plus overrides."""
    full_quantization_config = [
        {
            'regex': '.*',
            'operation': '*',
            'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
            'op_config': {
                'activation_tensor_config': {
                    'num_bits': 8,
                    'symmetric': False,
                    'granularity': _QuantGranularity.TENSORWISE,
                    'dtype': 'INT',
                },
                'weight_tensor_config': {
                    'num_bits': 8,
                    'symmetric': True,
                    'granularity': _QuantGranularity.TENSORWISE,
                    'dtype': 'INT',
                },
                # SRQ.
                'compute_precision': _ComputePrecision.INTEGER,
            },
        },
        {
            'regex': '.*',
            'operation': 'BATCH_MATMUL',
            'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
            'op_config': {
                'weight_tensor_config': {
                    'dtype': 'INT',
                    'num_bits': 8,
                    'symmetric': True,
                    'granularity': _QuantGranularity.CHANNELWISE,
                },
                # WEIGHT_ONLY.
                'compute_precision': _ComputePrecision.FLOAT,
                'explicit_dequantize': True,
            },
        },
        {
            'regex': '.*/Dense/.*',
            'operation': '*',
            'algorithm_key': _AlgorithmName.MIN_MAX_UNIFORM_QUANT,
            'op_config': {
                'weight_tensor_config': {
                    'dtype': 'INT',
                    'num_bits': 4,
                    'symmetric': False,
                    'granularity': _QuantGranularity.CHANNELWISE,
                },
                # DRQ.
                'compute_precision': _ComputePrecision.INTEGER,
            },
        },
    ]
    self._recipe_manager.load_quantization_recipe(full_quantization_config)

    # BMMs will be overridden to weight-only
    alg_key, op_config = self._recipe_manager.get_quantization_configs(
        _TFLOpName.BATCH_MATMUL, 'model/Dense10/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
    self.assertIsNone(op_config.activation_tensor_config)
    weight_tensor_config = op_config.weight_tensor_config
    self.assertIsNotNone(weight_tensor_config)
    # WEIGHT_ONLY check.
    self.assertEqual(op_config.compute_precision, _ComputePrecision.FLOAT)
    self.assertEqual(weight_tensor_config.num_bits, 8)

    # Dense will be overwritten by the last setting
    alg_key, op_config = self._recipe_manager.get_quantization_configs(
        _TFLOpName.FULLY_CONNECTED, 'model/Dense/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
    self.assertIsNone(op_config.activation_tensor_config)
    weight_tensor_config = op_config.weight_tensor_config
    self.assertIsNotNone(weight_tensor_config)
    # DRQ check.
    self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
    self.assertEqual(weight_tensor_config.num_bits, 4)

    # Other ops will have default quantization settings
    alg_key, op_config = self._recipe_manager.get_quantization_configs(
        _TFLOpName.CONV_2D, 'model/Dense11/op'
    )
    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
    op_act_config = op_config.activation_tensor_config
    self.assertIsNotNone(op_act_config)
    self.assertEqual(op_act_config.num_bits, 8)
    weight_tensor_config = op_config.weight_tensor_config
    self.assertIsNotNone(weight_tensor_config)
    # SRQ check.
    self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
    self.assertEqual(weight_tensor_config.num_bits, 8)
+
760
+ def test_need_calibration_false(self):
761
+ self._recipe_manager.add_quantization_config(
762
+ regex='.*/Dense_1/.*',
763
+ operation_name=_TFLOpName.FULLY_CONNECTED,
764
+ algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
765
+ op_config=qtyping.OpQuantizationConfig(
766
+ weight_tensor_config=_TensorQuantConfig(num_bits=8),
767
+ compute_precision=_ComputePrecision.INTEGER, # DRQ.
768
+ ),
769
+ )
770
+ self._recipe_manager.add_quantization_config(
771
+ regex='.*/Dense_2/.*',
772
+ operation_name=_TFLOpName.CONV_2D,
773
+ algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
774
+ op_config=qtyping.OpQuantizationConfig(
775
+ weight_tensor_config=_TensorQuantConfig(num_bits=8),
776
+ compute_precision=_ComputePrecision.FLOAT, # WEIGHT_ONLY.
777
+ explicit_dequantize=True,
778
+ ),
779
+ )
780
+ self.assertFalse(self._recipe_manager.need_calibration())
781
+
782
+ def test_need_calibration_true(self):
783
+ self._recipe_manager.add_quantization_config(
784
+ regex='.*/Dense_1/.*',
785
+ operation_name=_TFLOpName.FULLY_CONNECTED,
786
+ algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
787
+ op_config=qtyping.OpQuantizationConfig(
788
+ weight_tensor_config=_TensorQuantConfig(num_bits=8),
789
+ compute_precision=_ComputePrecision.INTEGER, # DRQ.
790
+ ),
791
+ )
792
+ self._recipe_manager.add_quantization_config(
793
+ regex='.*/Dense_2/.*',
794
+ operation_name=_TFLOpName.CONV_2D,
795
+ algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
796
+ op_config=qtyping.OpQuantizationConfig(
797
+ weight_tensor_config=_TensorQuantConfig(num_bits=8),
798
+ compute_precision=_ComputePrecision.FLOAT, # WEIGHT_ONLY.
799
+ ),
800
+ )
801
+ self._recipe_manager.add_quantization_config(
802
+ regex='.*/Dense_3/.*',
803
+ operation_name=_TFLOpName.BATCH_MATMUL,
804
+ algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
805
+ op_config=qtyping.OpQuantizationConfig(
806
+ weight_tensor_config=_TensorQuantConfig(num_bits=8),
807
+ activation_tensor_config=_TensorQuantConfig(num_bits=8),
808
+ compute_precision=_ComputePrecision.INTEGER, # SRQ.
809
+ ),
810
+ )
811
+ self.assertTrue(self._recipe_manager.need_calibration())
812
+
813
+
814
# Run the test suite when executed as a script.
if __name__ == '__main__':
  googletest.main()