ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,354 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import json
|
17
|
+
import os
|
18
|
+
from absl import flags
|
19
|
+
import numpy as np
|
20
|
+
from tensorflow.python.platform import googletest
|
21
|
+
from ai_edge_quantizer import model_validator
|
22
|
+
from ai_edge_quantizer.utils import test_utils
|
23
|
+
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
24
|
+
from ai_edge_quantizer.utils import validation_utils
|
25
|
+
|
26
|
+
# Base directory that the relative 'tests/models/*.tflite' paths used by the
# test cases in this file are joined onto (presumably the package's test-data
# root as resolved by test_utils — confirm there).
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('.')
|
27
|
+
|
28
|
+
|
29
|
+
class ComparisonResultTest(googletest.TestCase):
  """Tests for `model_validator.ComparisonResult` bookkeeping and saving."""

  def setUp(self):
    # TODO: b/358437395 - Remove this line once the bug is fixed.
    flags.FLAGS.mark_as_parsed()
    super().setUp()
    self.test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/two_signatures.tflite'
    )
    self.test_model = tfl_flatbuffer_utils.get_model_buffer(
        self.test_model_path
    )
    self.test_quantized_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH,
        'tests/models/two_signatures_a8w8.tflite',
    )
    self.test_quantized_model = tfl_flatbuffer_utils.get_model_buffer(
        self.test_quantized_model_path
    )
    # Hand-written error values keyed by signature key, then by tensor name.
    # They are fed straight into ComparisonResult; no model is executed.
    self.test_data = {
        'add': {'add_x:0': 1e-3, 'Add/y': 0.25, 'PartitionedCall:0': 1e-3},
        'multiply': {
            'multiply_x:0': 1e-3,
            'Mul/y': 0.32,
            'PartitionedCall_1:0': 1e-2,
        },
    }
    self.test_dir = self.create_tempdir()
    self.comparison_result = model_validator.ComparisonResult(
        self.test_model, self.test_quantized_model
    )

  def _add_all_signature_results(self):
    """Adds every signature result in `self.test_data`."""
    for signature_key, test_result in self.test_data.items():
      self.comparison_result.add_new_signature_results(
          'mean_squared_difference',
          test_result,
          signature_key,
      )

  def test_add_new_signature_results_succeeds(self):
    self._add_all_signature_results()
    self.assertLen(
        self.comparison_result.available_signature_keys(), len(self.test_data)
    )

    for signature_key in self.test_data:
      signature_result = self.comparison_result.get_signature_comparison_result(
          signature_key
      )
      # Each signature's three tensors are expected to be classified as one
      # input, one output and one constant, with no intermediates.
      self.assertLen(signature_result.input_tensors, 1)
      self.assertLen(signature_result.output_tensors, 1)
      self.assertLen(signature_result.constant_tensors, 1)
      self.assertEmpty(signature_result.intermediate_tensors)

  def test_add_new_signature_results_fails_same_signature_key(self):
    self.comparison_result.add_new_signature_results(
        'mean_squared_difference',
        self.test_data['add'],
        'add',
    )
    error_message = 'add is already in the comparison_results.'
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      self.comparison_result.add_new_signature_results(
          'mean_squared_difference',
          self.test_data['add'],
          'add',
      )

  def test_get_signature_comparison_result_fails_with_invalid_signature_key(
      self,
  ):
    self.comparison_result.add_new_signature_results(
        'mean_squared_difference',
        self.test_data['add'],
        'add',
    )
    error_message = 'multiply is not in the comparison_results.'
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      self.comparison_result.get_signature_comparison_result('multiply')

  def test_get_all_tensor_results_succeeds(self):
    self._add_all_signature_results()
    all_tensor_results = self.comparison_result.get_all_tensor_results()
    self.assertLen(all_tensor_results, 6)
    for tensor_name in (
        'add_x:0',
        'Add/y',
        'PartitionedCall:0',
        'multiply_x:0',
        'Mul/y',
        'PartitionedCall_1:0',
    ):
      self.assertIn(tensor_name, all_tensor_results)

  def test_save_comparison_result_succeeds(self):
    self._add_all_signature_results()
    model_name = 'test_model'
    self.comparison_result.save(self.test_dir.full_path, model_name)
    test_json_path = os.path.join(
        self.test_dir.full_path, model_name + '_comparison_result.json'
    )
    with open(test_json_path, encoding='utf-8') as json_file:
      json_dict = json.load(json_file)

    # Check model size stats.
    reduced_size_bytes = len(self.test_model) - len(self.test_quantized_model)
    self.assertIn('reduced_size_bytes', json_dict)
    self.assertEqual(json_dict['reduced_size_bytes'], reduced_size_bytes)
    self.assertIn('reduced_size_percentage', json_dict)
    # Approximate comparison: exact float equality would couple this test to
    # the order of arithmetic operations used by the implementation.
    self.assertAlmostEqual(
        json_dict['reduced_size_percentage'],
        reduced_size_bytes / len(self.test_model) * 100,
    )

    for signature_key in self.test_data:
      self.assertIn(signature_key, json_dict)
      signature_result = json_dict[signature_key]
      self.assertIn('error_metric', signature_result)
      self.assertEqual(
          signature_result['error_metric'], 'mean_squared_difference'
      )
      self.assertIn('constant_tensors', signature_result)
      constant_tensors = signature_result['constant_tensors']
      # Constants must be attributed only to the signature that owns them.
      if signature_key == 'add':
        self.assertIn('Add/y', constant_tensors)
        self.assertNotIn('Mul/y', constant_tensors)
      elif signature_key == 'multiply':
        self.assertIn('Mul/y', constant_tensors)
        self.assertNotIn('Add/y', constant_tensors)
175
|
+
|
176
|
+
|
177
|
+
class ModelValidatorCompareTest(googletest.TestCase):
  """Tests `model_validator.compare_model` on a single-signature model."""

  # Expected "thresholds" fragment of the Model Explorer JSON generated for
  # thresholds [0, 1, 2, 3].
  _EXPECTED_THRESHOLDS_JSON = (
      '"thresholds": [{"value": 0, "bgColor": "rgb(200, 0, 0)"}, {"value":'
      ' 1, "bgColor": "rgb(200, 63, 0)"}, {"value": 2, "bgColor": "rgb(200,'
      ' 126, 0)"}, {"value": 3, "bgColor": "rgb(200, 189, 0)"}]'
  )

  def setUp(self):
    # TODO: b/358437395 - Remove this line once the bug is fixed.
    flags.FLAGS.mark_as_parsed()
    super().setUp()
    self.reference_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/single_fc_bias.tflite'
    )
    self.target_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH,
        'tests/models/single_fc_bias_sub_channel_weight_only_sym_weight.tflite',
    )
    self.reference_model = tfl_flatbuffer_utils.get_model_buffer(
        self.reference_model_path
    )
    self.target_model = tfl_flatbuffer_utils.get_model_buffer(
        self.target_model_path
    )
    self.signature_key = 'serving_default'  # single signature.
    self.test_data = test_utils.create_random_normal_input_data(
        self.reference_model_path
    )
    self.test_dir = self.create_tempdir()

  def _compare_models(self):
    """Runs compare_model on the fixture models using mean squared diff."""
    return model_validator.compare_model(
        self.reference_model,
        self.target_model,
        self.test_data,
        'mean_squared_difference',
        validation_utils.mean_squared_difference,
    )

  def test_model_validator_compare(self):
    comparison_result = self._compare_models()
    result = comparison_result.get_signature_comparison_result(
        self.signature_key
    )
    self.assertEqual(result.error_metric, 'mean_squared_difference')
    input_tensors = result.input_tensors
    output_tensors = result.output_tensors
    constant_tensors = result.constant_tensors

    self.assertLen(input_tensors, 1)
    self.assertLen(output_tensors, 1)
    self.assertLen(constant_tensors, 2)
    self.assertEmpty(result.intermediate_tensors)

    # Both models receive the same test input, so its error should be ~0.
    self.assertAlmostEqual(input_tensors['serving_default_input_2:0'], 0)
    self.assertAlmostEqual(constant_tensors['arith.constant1'], 0)
    self.assertLess(output_tensors['StatefulPartitionedCall:0'], 1e-5)

  def test_create_json_for_model_explorer(self):
    comparison_result = self._compare_models()
    mv_json = model_validator.create_json_for_model_explorer(
        comparison_result, [0, 1, 2, 3]
    )
    # Substring check; assertContainsSubset on two strings would only compare
    # their character sets and pass for almost any output.
    self.assertIn(self._EXPECTED_THRESHOLDS_JSON, mv_json)

  def test_create_json_for_model_explorer_no_thresholds(self):
    comparison_result = self._compare_models()
    mv_json = model_validator.create_json_for_model_explorer(
        comparison_result, []
    )
    self.assertIn('"thresholds": []', mv_json)
|
261
|
+
|
262
|
+
|
263
|
+
class ModelValidatorMultiSignatureModelTest(googletest.TestCase):
  """Tests `model_validator.compare_model` on a two-signature model."""

  # Expected "thresholds" fragment of the Model Explorer JSON generated for
  # thresholds [0, 1, 2, 3].
  _EXPECTED_THRESHOLDS_JSON = (
      '"thresholds": [{"value": 0, "bgColor": "rgb(200, 0, 0)"}, {"value":'
      ' 1, "bgColor": "rgb(200, 63, 0)"}, {"value": 2, "bgColor": "rgb(200,'
      ' 126, 0)"}, {"value": 3, "bgColor": "rgb(200, 189, 0)"}]'
  )

  def setUp(self):
    # TODO: b/358437395 - Remove this line once the bug is fixed.
    flags.FLAGS.mark_as_parsed()
    super().setUp()
    self.reference_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/two_signatures.tflite'
    )
    self.target_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH,
        'tests/models/two_signatures_a8w8.tflite',
    )
    self.reference_model = tfl_flatbuffer_utils.get_model_buffer(
        self.reference_model_path
    )
    self.target_model = tfl_flatbuffer_utils.get_model_buffer(
        self.target_model_path
    )
    # One list of input dicts per signature key; each signature is fed a
    # single float32 sample for its input named 'x'.
    self.test_data = {
        'add': [{'x': np.array([2.0]).astype(np.float32)}],
        'multiply': [{'x': np.array([1.0]).astype(np.float32)}],
    }
    self.test_dir = self.create_tempdir()

  def _compare_models(self):
    """Runs compare_model on the fixture models using mean squared diff."""
    return model_validator.compare_model(
        self.reference_model,
        self.target_model,
        self.test_data,
        'mean_squared_difference',
        validation_utils.mean_squared_difference,
    )

  def test_model_validator_compare_succeeds(self):
    result = self._compare_models()
    for signature_key in self.test_data:
      signature_result = result.get_signature_comparison_result(signature_key)
      input_tensors = signature_result.input_tensors
      output_tensors = signature_result.output_tensors
      constant_tensors = signature_result.constant_tensors

      self.assertLen(input_tensors, 1)
      self.assertLen(output_tensors, 1)
      self.assertLen(constant_tensors, 1)
      self.assertEmpty(signature_result.intermediate_tensors)

      if signature_key == 'add':
        self.assertLess(input_tensors['add_x:0'], 1e-3)
        self.assertAlmostEqual(constant_tensors['Add/y'], 0)
        self.assertLess(output_tensors['PartitionedCall:0'], 1e-3)
      elif signature_key == 'multiply':
        self.assertLess(input_tensors['multiply_x:0'], 1e-3)
        self.assertAlmostEqual(constant_tensors['Mul/y'], 0)
        self.assertLess(output_tensors['PartitionedCall_1:0'], 1e-2)

  def test_create_json_for_model_explorer(self):
    comparison_result = self._compare_models()
    thresholds = [0, 1, 2, 3]
    mv_json = model_validator.create_json_for_model_explorer(
        comparison_result, thresholds
    )
    # Substring check; assertContainsSubset on two strings would only compare
    # their character sets and pass for almost any output.
    self.assertIn(self._EXPECTED_THRESHOLDS_JSON, mv_json)

  def test_create_json_for_model_explorer_no_thresholds(self):
    comparison_result = self._compare_models()
    mv_json = model_validator.create_json_for_model_explorer(
        comparison_result, []
    )
    self.assertIn('"thresholds": []', mv_json)
|
351
|
+
|
352
|
+
|
353
|
+
if __name__ == '__main__':
  # Script entry point: hand control to the googletest runner so the test
  # cases in this module run when the file is executed directly.
  googletest.main()
|