ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63):
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,532 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import json
17
+ import os
18
+
19
+ from absl.testing import parameterized
20
+ import numpy as np
21
+
22
+ from tensorflow.python.platform import googletest
23
+ from ai_edge_quantizer import qtyping
24
+ from ai_edge_quantizer import quantizer
25
+ from ai_edge_quantizer.utils import test_utils
26
+ from ai_edge_quantizer.utils import tfl_interpreter_utils
27
+
28
# Short aliases for frequently used qtyping / quantizer names.
_ComputePrecision = qtyping.ComputePrecision
_TFLOpName = qtyping.TFLOperationName
_TensorQuantConfig = qtyping.TensorQuantizationConfig
_TensorDataType = qtyping.TensorDataType
_AlgorithmName = quantizer.AlgorithmName

# Root of the test data directory; tests join relative model/recipe paths to it.
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')

# One float32 sample (feed name 'x') for each signature of the two-signature
# test model: 'add' gets 2.0 and 'multiply' gets 1.0.
_MULTI_SIGNATURE_CALIBRATION_DATASET = {
    signature_key: [{'x': np.array([sample_value], dtype=np.float32)}]
    for signature_key, sample_value in (('add', 2.0), ('multiply', 1.0))
}

# Fixed seed so randomly generated test inputs are deterministic across runs.
_RNG = np.random.default_rng(seed=66)
40
+
41
+
42
def _get_calibration_data(num_samples: int = 16):
  """Builds random calibration inputs for the conv/fc MNIST test model.

  Args:
    num_samples: How many samples to draw from the module-level `_RNG`.

  Returns:
    A dict mapping `tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY` to a list of
    `num_samples` feed dicts. Each feed dict holds a float32 `conv2d_input`
    array of shape (1, 28, 28, 1) drawn with `_RNG.uniform` (default [0, 1)).
  """
  samples = [
      {'conv2d_input': _RNG.uniform(size=(1, 28, 28, 1)).astype(np.float32)}
      for _ in range(num_samples)
  ]
  return {tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: samples}
52
+
53
+
54
class QuantizerTest(parameterized.TestCase):
  """End-to-end tests of `quantizer.Quantizer` on the conv/fc MNIST model."""

  @staticmethod
  def _data_file(relative_path):
    """Returns `relative_path` resolved against the test data root."""
    return os.path.join(TEST_DATA_PREFIX_PATH, relative_path)

  @staticmethod
  def _read_json(path):
    """Loads and returns the JSON document stored at `path`."""
    with open(path) as json_file:
      return json.load(json_file)

  def _assert_raises_with_message(self, error_type, message, action):
    """Asserts that `action()` raises `error_type` whose text has `message`."""
    with self.assertRaisesWithPredicateMatch(
        error_type, lambda err: message in str(err)
    ):
      action()

  def setUp(self):
    super().setUp()
    self._tmp_save_path = self.create_tempdir().full_path
    self._test_model_path = self._data_file('tests/models/conv_fc_mnist.tflite')
    self._test_recipe_path = self._data_file(
        'recipes/default_af32w8float_recipe.json'
    )
    self._test_recipe = self._read_json(self._test_recipe_path)
    self._quantizer = quantizer.Quantizer(
        self._test_model_path, self._test_recipe_path
    )

  def test_update_quantization_recipe_succeeds(self):
    self._quantizer.load_quantization_recipe(self._test_recipe_path)
    dense_regex = '.*/Dense/.*'
    int4_drq_config = qtyping.OpQuantizationConfig(
        weight_tensor_config=_TensorQuantConfig(num_bits=4, symmetric=True),
        compute_precision=_ComputePrecision.INTEGER,  # DRQ.
    )
    self._quantizer.update_quantization_recipe(
        regex=dense_regex,
        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=int4_drq_config,
    )
    recipe_entries = self._quantizer.get_quantization_recipe()
    self.assertLen(recipe_entries, 2)

    # The update is expected to show up as the last recipe entry.
    last_entry = recipe_entries[-1]
    self.assertEqual(last_entry['regex'], dense_regex)
    self.assertEqual(
        last_entry['op_config']['compute_precision'],
        int4_drq_config.compute_precision,
    )

  def test_load_quantization_recipe_succeeds(self):
    qt = quantizer.Quantizer(self._test_model_path, None)
    qt.load_quantization_recipe(self._test_recipe_path)
    self.assertEqual(qt.get_quantization_recipe(), self._test_recipe)

    # Loading a second recipe file must replace the current recipe.
    other_recipe_path = self._data_file('recipes/dynamic_wi8_afp32_recipe.json')
    other_recipe = self._read_json(other_recipe_path)
    qt.load_quantization_recipe(other_recipe_path)
    self.assertEqual(qt.get_quantization_recipe(), other_recipe)

  @parameterized.parameters(
      'recipes/default_a8w8_recipe.json',
      'recipes/default_a16w8_recipe.json',
  )
  def test_calibrate_required_recipe_succeeds(self, recipe_path):
    self._quantizer.load_quantization_recipe(self._data_file(recipe_path))
    self.assertTrue(self._quantizer.need_calibration)
    # Calibrate starting from an empty state.
    calibration_result = self._quantizer.calibrate(_get_calibration_data())
    self.assertLen(calibration_result, 13)

  @parameterized.parameters(
      'recipes/default_a8w8_recipe.json',
      'recipes/default_a16w8_recipe.json',
  )
  def test_reloaded_calibration_succeeds(self, recipe_path):
    self._quantizer.load_quantization_recipe(self._data_file(recipe_path))
    calib_data = _get_calibration_data()
    first_result = self._quantizer.calibrate(calib_data)
    # Calibrate again, seeded with the previous calibration result.
    second_result = self._quantizer.calibrate(
        calib_data, previous_calibration_result=first_result
    )
    self.assertLen(second_result, 13)
    output_tensor = 'StatefulPartitionedCall:0'
    self.assertNotEqual(first_result[output_tensor], second_result[output_tensor])

  @parameterized.parameters(
      'recipes/dynamic_legacy_wi8_afp32_recipe.json',
      'recipes/dynamic_wi8_afp32_recipe.json',
      'recipes/default_af32w8float_recipe.json',
  )
  def test_calibrate_nonrequired_recipe_succeeds(self, recipe_path):
    self._quantizer.load_quantization_recipe(self._data_file(recipe_path))
    self.assertFalse(self._quantizer.need_calibration)
    # Recipes that need no calibration yield an empty calibration result.
    self.assertEmpty(self._quantizer.calibrate(_get_calibration_data()))

  def test_quantize_no_calibration_succeeds(self):
    self._quantizer.load_quantization_recipe(self._test_recipe_path)
    self.assertIsNone(self._quantizer._result.quantized_model)
    quant_result = self._quantizer.quantize()
    self.assertEqual(quant_result.recipe, self._test_recipe)
    self.assertIsNotNone(quant_result.quantized_model)

  @parameterized.parameters(
      'recipes/default_a8w8_recipe.json',
      'recipes/default_a16w8_recipe.json',
  )
  def test_quantize_calibration_needed_succeeds(self, recipe_path):
    full_recipe_path = self._data_file(recipe_path)
    expected_recipe = self._read_json(full_recipe_path)

    self._quantizer.load_quantization_recipe(full_recipe_path)
    self.assertTrue(self._quantizer.need_calibration)
    calibration_result = self._quantizer.calibrate(_get_calibration_data())

    self.assertIsNone(self._quantizer._result.quantized_model)
    quant_result = self._quantizer.quantize(calibration_result)
    self.assertEqual(quant_result.recipe, expected_recipe)
    self.assertIsNotNone(quant_result.quantized_model)

  @parameterized.parameters(
      'recipes/default_a8w8_recipe.json',
      'recipes/default_a16w8_recipe.json',
  )
  def test_quantize_calibration_needed_raise_error(self, recipe_path):
    self._quantizer.load_quantization_recipe(self._data_file(recipe_path))
    self.assertTrue(self._quantizer.need_calibration)
    # Quantizing without QSVs must fail for a calibration-requiring recipe.
    self._assert_raises_with_message(
        RuntimeError,
        'Model quantization statistics values (QSVs) are required for the input'
        ' recipe.',
        self._quantizer.quantize,
    )

  def test_quantize_no_recipe_raise_error(self):
    qt = quantizer.Quantizer(self._test_model_path, None)
    self._assert_raises_with_message(
        RuntimeError,
        'Can not quantize without a quantization recipe.',
        qt.quantize,
    )

  def test_save_succeeds(self):
    model_name = 'test_model'
    self._quantizer.load_quantization_recipe(self._test_recipe_path)
    self._quantizer.quantize().save(self._tmp_save_path, model_name)
    # The recipe is saved next to the model as '<model_name>_recipe.json'.
    saved_recipe = self._read_json(
        os.path.join(self._tmp_save_path, model_name + '_recipe.json')
    )
    self.assertEqual(saved_recipe, self._test_recipe)

  def test_save_no_quantize_raise_error(self):
    self._assert_raises_with_message(
        RuntimeError,
        'No quantized model to save.',
        lambda: self._quantizer._result.save(self._tmp_save_path, 'test_model'),
    )

  def test_export_model_succeeds(self):
    model_name = 'exported_model'
    self._quantizer.load_quantization_recipe(self._test_recipe_path)
    result = self._quantizer.quantize()

    export_path = os.path.join(self._tmp_save_path, model_name + '.tflite')
    # The file must be absent before export and present afterwards.
    self.assertFalse(os.path.exists(export_path))
    result.export_model(export_path)
    self.assertTrue(os.path.exists(export_path))

  def test_compare_succeeds(self):
    self._quantizer.quantize()
    comparison = self._quantizer.validate().get_signature_comparison_result()
    self.assertIsNotNone(comparison)
    self.assertIn('sequential/dense_1/MatMul', comparison.intermediate_tensors)

  def test_load_custom_policies_succeeds(self):
    dense_regex = '.*/Dense/.*'
    fully_connected = qtyping.TFLOperationName.FULLY_CONNECTED
    int4_config = qtyping.OpQuantizationConfig(
        weight_tensor_config=_TensorQuantConfig(num_bits=4, symmetric=True),
        compute_precision=_ComputePrecision.INTEGER,
    )

    # Accepted under the default policy (no explicit algorithm_key given).
    self._quantizer.update_quantization_recipe(
        regex=dense_regex,
        operation_name=fully_connected,
        op_config=int4_config,
    )

    # Rejected once the dummy policy is loaded.
    self._quantizer.load_config_policy(
        test_utils.get_path_to_datafile('policies/dummy_config_policy.json')
    )
    with self.assertRaisesRegex(
        ValueError, 'Unsupported op for .*FULLY_CONNECTED'
    ):
      self._quantizer.update_quantization_recipe(
          regex=dense_regex,
          operation_name=fully_connected,
          algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
          op_config=int4_config,
      )

    # Accepted again under the example policy loaded from a *.json file.
    self._quantizer.load_config_policy(
        test_utils.get_path_to_datafile('policies/example_config_policy.json')
    )
    self._quantizer.update_quantization_recipe(
        regex=dense_regex,
        operation_name=fully_connected,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=int4_config,
    )
286
+
287
+
288
class QuantizerBytearrayInputs(googletest.TestCase):
  """Tests a `quantizer.Quantizer` built from in-memory model bytes and recipe."""

  def setUp(self):
    super().setUp()
    self._test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/conv_fc_mnist.tflite'
    )
    self._test_recipe_path = os.path.join(
        TEST_DATA_PREFIX_PATH,
        'recipes/default_af32w8float_recipe.json',
    )
    # Pass the model as a bytearray and the recipe as a parsed object rather
    # than as file paths.
    with open(self._test_model_path, 'rb') as model_file:
      model_bytes = bytearray(model_file.read())
    with open(self._test_recipe_path, 'r') as recipe_file:
      self._test_recipe = json.load(recipe_file)
    self._quantizer = quantizer.Quantizer(model_bytes, self._test_recipe)

  def _validate_and_check_dense_layer(self):
    """Validates and checks the dense MatMul is a compared intermediate."""
    comparison = self._quantizer.validate().get_signature_comparison_result()
    self.assertIsNotNone(comparison)
    self.assertIn('sequential/dense_1/MatMul', comparison.intermediate_tensors)

  def test_quantize_compare_succeeds(self):
    quant_result = self._quantizer.quantize()
    self.assertEqual(quant_result.recipe, self._test_recipe)
    self.assertIsNotNone(quant_result.quantized_model)
    self._validate_and_check_dense_layer()

  def test_compare_succeeds(self):
    self._quantizer.quantize()
    self._validate_and_check_dense_layer()
324
+
325
+
326
# TODO: b/364974841 - Add more tests after multiple signatures are supported
# for calibrate and quantize.
class QuantizerMultiSignatureModelTest(parameterized.TestCase):
  """Tests `quantizer.Quantizer` on a model with 'add' and 'multiply' signatures."""

  def setUp(self):
    super().setUp()
    self._tmp_save_path = self.create_tempdir().full_path
    self._test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/two_signatures.tflite'
    )
    self._test_recipe_path = os.path.join(
        TEST_DATA_PREFIX_PATH,
        'recipes/default_a8w8_recipe.json',
    )
    with open(self._test_recipe_path) as recipe_file:
      self._test_recipe = json.load(recipe_file)
    # Hand-written min/max QSVs for each signature's input and output tensor.
    self._calibration_result = {
        'add_x:0': {'min': -2.0, 'max': 2.0},
        'PartitionedCall:0': {'min': -8.0, 'max': 12.0},
        'multiply_x:0': {'min': -2.0, 'max': 2.0},
        'PartitionedCall_1:0': {'min': -20.0, 'max': 20.0},
    }
    self._quantizer = quantizer.Quantizer(
        self._test_model_path, self._test_recipe_path
    )

  def _check_add_signature(self, validation_result):
    """Checks the comparison result reported for the 'add' signature."""
    add_result = validation_result.get_signature_comparison_result('add')
    self.assertEqual('mse', add_result.error_metric)
    self.assertIn('add_x:0', add_result.input_tensors)
    self.assertIn('PartitionedCall:0', add_result.output_tensors)
    self.assertIn('Add/y', add_result.constant_tensors)
    self.assertEmpty(add_result.intermediate_tensors)

  def _check_multiply_signature(self, validation_result):
    """Checks the comparison result reported for the 'multiply' signature."""
    mul_result = validation_result.get_signature_comparison_result('multiply')
    self.assertEqual('mse', mul_result.error_metric)
    self.assertIn('multiply_x:0', mul_result.input_tensors)
    self.assertIn('PartitionedCall_1:0', mul_result.output_tensors)
    self.assertIn('Mul/y', mul_result.constant_tensors)
    self.assertEmpty(mul_result.intermediate_tensors)

  @parameterized.named_parameters(
      ('default_random_data', None),
      ('specific_data', _MULTI_SIGNATURE_CALIBRATION_DATASET),
  )
  def test_validate_multiple_signatures_succeeds(self, test_data):
    self._quantizer.quantize(self._calibration_result)
    validation_result = self._quantizer.validate(test_data)
    self.assertLen(validation_result.available_signature_keys(), 2)
    self._check_add_signature(validation_result)
    self._check_multiply_signature(validation_result)

  def test_validate_add_signature_succeeds(self):
    test_data = {'add': [{'x': np.array([2.0]).astype(np.float32)}]}
    self._quantizer.quantize(self._calibration_result)
    validation_result = self._quantizer.validate(test_data)
    signature_keys = validation_result.available_signature_keys()
    # Only the signature present in the test data is validated.
    self.assertLen(signature_keys, 1)
    self.assertIn('add', signature_keys)
    self._check_add_signature(validation_result)

  def test_validate_multiply_signature_succeeds(self):
    test_data = {'multiply': [{'x': np.array([1.0]).astype(np.float32)}]}
    self._quantizer.quantize(self._calibration_result)
    validation_result = self._quantizer.validate(test_data)
    signature_keys = validation_result.available_signature_keys()
    # Only the signature present in the test data is validated.
    self.assertLen(signature_keys, 1)
    self.assertIn('multiply', signature_keys)
    self._check_multiply_signature(validation_result)

  def test_validate_quantize_after_calibration_succeeds(self):
    calib_result = self._quantizer.calibrate(
        _MULTI_SIGNATURE_CALIBRATION_DATASET
    )
    self._quantizer.quantize(calib_result)
    validation_result = self._quantizer.validate(
        _MULTI_SIGNATURE_CALIBRATION_DATASET
    )
    self.assertLen(validation_result.available_signature_keys(), 2)

  def test_recipe_conflict_raises_error(self):
    # A single a8w8 entry for every ADD op; the test expects 'Add/y' and
    # 'Mul/y' to end up with conflicting quantization parameters.
    recipe = [{
        'regex': '.*',
        'operation': 'ADD',
        'algorithm_key': 'min_max_uniform_quantize',
        'op_config': {
            'activation_tensor_config': {
                'num_bits': 8,
                'symmetric': False,
                'granularity': 'TENSORWISE',
                'dtype': 'INT',
                'block_size': 0,
            },
            'weight_tensor_config': {
                'num_bits': 8,
                'symmetric': True,
                'granularity': 'CHANNELWISE',
                'dtype': 'INT',
                'block_size': 0,
            },
            'compute_precision': 'INTEGER',
            'explicit_dequantize': False,
            'skip_checks': False,
        },
    }]

    qt = quantizer.Quantizer(self._test_model_path, recipe)
    calib_result = qt.calibrate(_MULTI_SIGNATURE_CALIBRATION_DATASET)

    expected_message = (
        "The tensors b'Add/y' and b'Mul/y' do not have the same quantization"
    )
    with self.assertRaisesWithPredicateMatch(
        RuntimeError, lambda err: expected_message in str(err)
    ):
      qt.quantize(calib_result)

  def test_quantization_with_insufficient_calibration(self):
    # Calibrate only the 'add' signature.
    scarce_calibration_dataset = {
        'add': [{'x': np.array([2.0], dtype=np.float32)}],
    }
    calib_result = self._quantizer.calibrate(scarce_calibration_dataset)

    # Quantization must report the missing 'multiply' QSVs.
    expected_message = (
        'Missing QSVs (min/max) for tensor multiply_x:0 in Signature'
        " 'multiply'."
    )
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: expected_message in str(err)
    ):
      self._quantizer.quantize(calib_result)
470
+
471
+
472
class QuantizerToyGemma2Test(parameterized.TestCase):
  """Quantizes a toy multi-signature model that takes KV-cache inputs."""

  @staticmethod
  def _random_signature_inputs():
    """Draws one random feed dict for a signature of the toy model.

    Values come from the module-level `_RNG`, in the order cache_0, cache_1,
    positions, tokens.
    """

    def random_cache():
      return _RNG.random(size=(1, 100, 4, 4), dtype=np.float32)

    def random_ids():
      return _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)

    return {
        'cache_0': random_cache(),
        'cache_1': random_cache(),
        'positions': random_ids(),
        'tokens': random_ids(),
    }

  def setUp(self):
    super().setUp()
    self._tmp_save_path = self.create_tempdir().full_path
    self._test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH,
        'tests/models/toy_model_with_kv_cache_multi_signature.tflite',
    )

    # One random sample per signature; signature_1 is drawn first.
    self._toy_gemma2_calibration_dataset = {
        signature_key: [self._random_signature_inputs()]
        for signature_key in ('signature_1', 'signature_2')
    }

    self._test_recipe_path = os.path.join(
        TEST_DATA_PREFIX_PATH,
        'recipes/default_a8w8_recipe.json',
    )
    with open(self._test_recipe_path) as recipe_file:
      self._test_recipe = json.load(recipe_file)

    self._quantizer = quantizer.Quantizer(
        self._test_model_path, self._test_recipe_path
    )

    # Assign the NO_QUANTIZE algorithm to FULLY_CONNECTED ops whose names
    # match 'StatefulPartitionedCall'.
    self._quantizer.update_quantization_recipe(
        regex='StatefulPartitionedCall',
        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
        algorithm_key=_AlgorithmName.NO_QUANTIZE,
    )

  def test_toy_gemma2_quantization_succeeds(self):
    calib_result = self._quantizer.calibrate(
        self._toy_gemma2_calibration_dataset
    )
    self.assertIsNotNone(calib_result)
    self._quantizer.quantize(calib_result)
    self.assertIsNotNone(self._quantizer._result.quantized_model)
529
+
530
+
531
if __name__ == '__main__':
  # Run this module's test classes via TensorFlow's googletest runner.
  googletest.main()
@@ -0,0 +1,67 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Quantization recipe module."""
17
+
18
+
19
def dynamic_wi8_afp32():
  """Builds the dynamic recipe: int8 weights, float32 activations.

  The single entry uses regex '.*' and operation '*' (presumably match-all)
  with the 'min_max_uniform_quantize' algorithm. It requests symmetric,
  channelwise, 8-bit integer weight quantization with INTEGER compute
  precision.

  Returns:
    A new single-element list holding the recipe entry dict. It is built
    fresh on every call.
  """
  weight_tensor_config = {
      'num_bits': 8,
      'symmetric': True,
      'granularity': 'CHANNELWISE',
      'dtype': 'INT',
      'block_size': 0,
  }
  op_config = {
      'weight_tensor_config': weight_tensor_config,
      'compute_precision': 'INTEGER',
      'explicit_dequantize': False,
      'skip_checks': False,
  }
  recipe_entry = {
      'regex': '.*',
      'operation': '*',
      'algorithm_key': 'min_max_uniform_quantize',
      'op_config': op_config,
  }
  return [recipe_entry]
40
+
41
+
42
def dynamic_legacy_wi8_afp32():
  """Builds the legacy dynamic recipe: int8 weights, float32 activations.

  This recipe matches `dynamic_wi8_afp32` except that its op config also sets
  'min_weight_elements' to 1024, to match the old quantizer behavior.

  Returns:
    A new single-element list holding the recipe entry dict. It is built
    fresh on every call.
  """
  weight_tensor_config = dict(
      num_bits=8,
      symmetric=True,
      granularity='CHANNELWISE',
      dtype='INT',
      block_size=0,
  )
  op_config = dict(
      weight_tensor_config=weight_tensor_config,
      compute_precision='INTEGER',
      explicit_dequantize=False,
      skip_checks=False,
      min_weight_elements=1024,
  )
  return [
      dict(
          regex='.*',
          operation='*',
          algorithm_key='min_max_uniform_quantize',
          op_config=op_config,
      )
  ]