ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,354 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import json
17
+ import os
18
+ from absl import flags
19
+ import numpy as np
20
+ from tensorflow.python.platform import googletest
21
+ from ai_edge_quantizer import model_validator
22
+ from ai_edge_quantizer.utils import test_utils
23
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
24
+ from ai_edge_quantizer.utils import validation_utils
25
+
26
# Base directory for the test data; each test joins its 'tests/models/...'
# model paths onto this prefix.
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('.')
27
+
28
+
29
class ComparisonResultTest(googletest.TestCase):
  """Tests model_validator.ComparisonResult using a two-signature model."""

  def setUp(self):
    # TODO: b/358437395 - Remove this line once the bug is fixed.
    flags.FLAGS.mark_as_parsed()
    super().setUp()
    self.test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/two_signatures.tflite'
    )
    self.test_model = tfl_flatbuffer_utils.get_model_buffer(
        self.test_model_path
    )
    self.test_quantized_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH,
        'tests/models/two_signatures_a8w8.tflite',
    )
    self.test_quantized_model = tfl_flatbuffer_utils.get_model_buffer(
        self.test_quantized_model_path
    )
    # Synthetic per-tensor error values, keyed by signature key and then by
    # tensor name. The tests below expect each signature to be classified
    # into exactly one input, one constant and one output tensor.
    self.test_data = {
        'add': {'add_x:0': 1e-3, 'Add/y': 0.25, 'PartitionedCall:0': 1e-3},
        'multiply': {
            'multiply_x:0': 1e-3,
            'Mul/y': 0.32,
            'PartitionedCall_1:0': 1e-2,
        },
    }
    self.test_dir = self.create_tempdir()
    self.comparison_result = model_validator.ComparisonResult(
        self.test_model, self.test_quantized_model
    )

  def _add_all_signature_results(self):
    """Registers every signature in `self.test_data` as an MSE result."""
    for signature_key, test_result in self.test_data.items():
      self.comparison_result.add_new_signature_results(
          'mean_squared_difference',
          test_result,
          signature_key,
      )

  def test_add_new_signature_results_succeeds(self):
    self._add_all_signature_results()
    self.assertLen(
        self.comparison_result.available_signature_keys(), len(self.test_data)
    )

    for signature_key in self.test_data:
      signature_result = self.comparison_result.get_signature_comparison_result(
          signature_key
      )
      self.assertLen(signature_result.input_tensors, 1)
      self.assertLen(signature_result.output_tensors, 1)
      self.assertLen(signature_result.constant_tensors, 1)
      self.assertEmpty(signature_result.intermediate_tensors)

  def test_add_new_signature_results_fails_same_signature_key(self):
    self.comparison_result.add_new_signature_results(
        'mean_squared_difference',
        self.test_data['add'],
        'add',
    )
    error_message = 'add is already in the comparison_results.'
    # Registering the same signature key a second time must be rejected.
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      self.comparison_result.add_new_signature_results(
          'mean_squared_difference',
          self.test_data['add'],
          'add',
      )

  def test_get_signature_comparison_result_fails_with_invalid_signature_key(
      self,
  ):
    # Only 'add' is registered, so looking up 'multiply' must fail.
    self.comparison_result.add_new_signature_results(
        'mean_squared_difference',
        self.test_data['add'],
        'add',
    )
    error_message = 'multiply is not in the comparison_results.'
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      self.comparison_result.get_signature_comparison_result('multiply')

  def test_get_all_tensor_results_succeeds(self):
    self._add_all_signature_results()
    all_tensor_results = self.comparison_result.get_all_tensor_results()
    # Tensor results from both signatures are merged into one mapping.
    self.assertLen(all_tensor_results, 6)
    for tensor_name in (
        'add_x:0',
        'Add/y',
        'PartitionedCall:0',
        'multiply_x:0',
        'Mul/y',
        'PartitionedCall_1:0',
    ):
      self.assertIn(tensor_name, all_tensor_results)

  def test_save_comparison_result_succeeds(self):
    self._add_all_signature_results()
    model_name = 'test_model'
    self.comparison_result.save(self.test_dir.full_path, model_name)
    test_json_path = os.path.join(
        self.test_dir.full_path, model_name + '_comparison_result.json'
    )
    with open(test_json_path, encoding='utf-8') as json_file:
      json_dict = json.load(json_file)

    # Check model size stats.
    expected_reduced_bytes = len(self.test_model) - len(
        self.test_quantized_model
    )
    self.assertIn('reduced_size_bytes', json_dict)
    self.assertEqual(json_dict['reduced_size_bytes'], expected_reduced_bytes)
    self.assertIn('reduced_size_percentage', json_dict)
    # The percentage is a float: compare approximately so the test does not
    # depend on the exact order of floating-point operations inside save().
    self.assertAlmostEqual(
        json_dict['reduced_size_percentage'],
        expected_reduced_bytes / len(self.test_model) * 100,
    )

    for signature_key in self.test_data:
      self.assertIn(signature_key, json_dict)
      signature_result = json_dict[signature_key]
      self.assertIn('error_metric', signature_result)
      self.assertEqual(
          signature_result['error_metric'], 'mean_squared_difference'
      )
      self.assertIn('constant_tensors', signature_result)
      # Constants must be attributed only to the signature that owns them.
      if signature_key == 'add':
        self.assertIn('Add/y', signature_result['constant_tensors'])
        self.assertNotIn('Mul/y', signature_result['constant_tensors'])
      elif signature_key == 'multiply':
        self.assertIn('Mul/y', signature_result['constant_tensors'])
        self.assertNotIn('Add/y', signature_result['constant_tensors'])
175
+
176
+
177
class ModelValidatorCompareTest(googletest.TestCase):
  """Tests model_validator on a single-signature FC model pair."""

  def setUp(self):
    # TODO: b/358437395 - Remove this line once the bug is fixed.
    flags.FLAGS.mark_as_parsed()
    super().setUp()
    self.reference_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/single_fc_bias.tflite'
    )
    self.target_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH,
        'tests/models/single_fc_bias_sub_channel_weight_only_sym_weight.tflite',
    )
    self.reference_model = tfl_flatbuffer_utils.get_model_buffer(
        self.reference_model_path
    )
    self.target_model = tfl_flatbuffer_utils.get_model_buffer(
        self.target_model_path
    )
    self.signature_key = 'serving_default'  # single signature.
    # Input data produced by test_utils for the reference model.
    self.test_data = test_utils.create_random_normal_input_data(
        self.reference_model_path
    )

  def _compare_models(self):
    """Compares reference vs. target model using mean squared difference."""
    return model_validator.compare_model(
        self.reference_model,
        self.target_model,
        self.test_data,
        'mean_squared_difference',
        validation_utils.mean_squared_difference,
    )

  def test_model_validator_compare(self):
    comparison_result = self._compare_models()
    result = comparison_result.get_signature_comparison_result(
        self.signature_key
    )
    self.assertEqual(result.error_metric, 'mean_squared_difference')
    input_tensors = result.input_tensors
    output_tensors = result.output_tensors
    constant_tensors = result.constant_tensors
    intermediate_tensors = result.intermediate_tensors

    self.assertLen(input_tensors, 1)
    self.assertLen(output_tensors, 1)
    self.assertLen(constant_tensors, 2)
    self.assertEmpty(intermediate_tensors)

    self.assertAlmostEqual(input_tensors['serving_default_input_2:0'], 0)
    self.assertAlmostEqual(constant_tensors['arith.constant1'], 0)
    self.assertLess(output_tensors['StatefulPartitionedCall:0'], 1e-5)

  def test_create_json_for_model_explorer(self):
    comparison_result = self._compare_models()
    mv_json = model_validator.create_json_for_model_explorer(
        comparison_result, [0, 1, 2, 3]
    )
    # Use a substring check: assertContainsSubset would compare the two
    # strings as sets of characters and pass for almost any JSON output.
    self.assertIn(
        '"thresholds": [{"value": 0, "bgColor": "rgb(200, 0, 0)"}, {"value":'
        ' 1, "bgColor": "rgb(200, 63, 0)"}, {"value": 2, "bgColor": "rgb(200,'
        ' 126, 0)"}, {"value": 3, "bgColor": "rgb(200, 189, 0)"}]',
        mv_json,
    )

  def test_create_json_for_model_explorer_no_thresholds(self):
    comparison_result = self._compare_models()
    mv_json = model_validator.create_json_for_model_explorer(
        comparison_result, []
    )
    # Substring check (see test_create_json_for_model_explorer).
    self.assertIn('"thresholds": []', mv_json)
261
+
262
+
263
class ModelValidatorMultiSignatureModelTest(googletest.TestCase):
  """Tests model_validator on a model with 'add' and 'multiply' signatures."""

  def setUp(self):
    # TODO: b/358437395 - Remove this line once the bug is fixed.
    flags.FLAGS.mark_as_parsed()
    super().setUp()
    self.reference_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/two_signatures.tflite'
    )
    self.target_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH,
        'tests/models/two_signatures_a8w8.tflite',
    )
    self.reference_model = tfl_flatbuffer_utils.get_model_buffer(
        self.reference_model_path
    )
    self.target_model = tfl_flatbuffer_utils.get_model_buffer(
        self.target_model_path
    )
    # One sample per signature key; each sample maps input name 'x' to a
    # one-element float32 array.
    self.test_data = {
        'add': [{'x': np.array([2.0]).astype(np.float32)}],
        'multiply': [{'x': np.array([1.0]).astype(np.float32)}],
    }

  def _compare_models(self):
    """Compares reference vs. target model using mean squared difference."""
    return model_validator.compare_model(
        self.reference_model,
        self.target_model,
        self.test_data,
        'mean_squared_difference',
        validation_utils.mean_squared_difference,
    )

  def test_model_validator_compare_succeeds(self):
    result = self._compare_models()
    for signature_key in self.test_data:
      signature_result = result.get_signature_comparison_result(signature_key)
      input_tensors = signature_result.input_tensors
      output_tensors = signature_result.output_tensors
      constant_tensors = signature_result.constant_tensors
      intermediate_tensors = signature_result.intermediate_tensors

      self.assertLen(input_tensors, 1)
      self.assertLen(output_tensors, 1)
      self.assertLen(constant_tensors, 1)
      self.assertEmpty(intermediate_tensors)

      if signature_key == 'add':
        self.assertLess(input_tensors['add_x:0'], 1e-3)
        self.assertAlmostEqual(constant_tensors['Add/y'], 0)
        self.assertLess(output_tensors['PartitionedCall:0'], 1e-3)
      elif signature_key == 'multiply':
        self.assertLess(input_tensors['multiply_x:0'], 1e-3)
        self.assertAlmostEqual(constant_tensors['Mul/y'], 0)
        self.assertLess(output_tensors['PartitionedCall_1:0'], 1e-2)

  def test_create_json_for_model_explorer(self):
    comparison_result = self._compare_models()
    thresholds = [0, 1, 2, 3]
    mv_json = model_validator.create_json_for_model_explorer(
        comparison_result, thresholds
    )
    # Use a substring check: assertContainsSubset would compare the two
    # strings as sets of characters and pass for almost any JSON output.
    self.assertIn(
        '"thresholds": [{"value": 0, "bgColor": "rgb(200, 0, 0)"}, {"value":'
        ' 1, "bgColor": "rgb(200, 63, 0)"}, {"value": 2, "bgColor": "rgb(200,'
        ' 126, 0)"}, {"value": 3, "bgColor": "rgb(200, 189, 0)"}]',
        mv_json,
    )

  def test_create_json_for_model_explorer_no_thresholds(self):
    comparison_result = self._compare_models()
    mv_json = model_validator.create_json_for_model_explorer(
        comparison_result, []
    )
    # Substring check (see test_create_json_for_model_explorer).
    self.assertIn('"thresholds": []', mv_json)
351
+
352
+
353
if __name__ == '__main__':
  # Run the tests in this module with the googletest runner.
  googletest.main()