ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1082 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Tests for instruction_generator."""
17
+
18
+ import os
19
+
20
+ import numpy as np
21
+
22
+ from tensorflow.python.platform import googletest
23
+ from absl.testing import parameterized
24
+ from ai_edge_quantizer import qtyping
25
+ from ai_edge_quantizer import transformation_instruction_generator as instruction_generator
26
+ from ai_edge_quantizer.utils import test_utils
27
+
28
+ TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(".")
29
+
30
+
31
+ class InstructionGeneratorTest(parameterized.TestCase):
32
+
33
+ @parameterized.named_parameters(
34
+ dict(
35
+ testcase_name="second_index_test",
36
+ param1=qtyping.OpToTensorParams(
37
+ 0,
38
+ [
39
+ qtyping.QuantTransformation.ADD_QUANTIZE,
40
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
41
+ ],
42
+ qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
43
+ ),
44
+ param2=qtyping.OpToTensorParams(
45
+ 2,
46
+ [
47
+ qtyping.QuantTransformation.ADD_QUANTIZE,
48
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
49
+ ],
50
+ qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
51
+ ),
52
+ index=1,
53
+ expected=True,
54
+ ),
55
+ dict(
56
+ testcase_name="different_trans_length_test",
57
+ param1=qtyping.OpToTensorParams(
58
+ 0,
59
+ [
60
+ qtyping.QuantTransformation.ADD_QUANTIZE,
61
+ ],
62
+ qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
63
+ ),
64
+ param2=qtyping.OpToTensorParams(
65
+ 2,
66
+ [
67
+ qtyping.QuantTransformation.ADD_QUANTIZE,
68
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
69
+ ],
70
+ qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
71
+ ),
72
+ index=1,
73
+ expected=False,
74
+ ),
75
+ dict(
76
+ testcase_name="different_trans_length_test2",
77
+ param1=qtyping.OpToTensorParams(
78
+ 0,
79
+ [
80
+ qtyping.QuantTransformation.ADD_QUANTIZE,
81
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
82
+ ],
83
+ qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
84
+ ),
85
+ param2=qtyping.OpToTensorParams(
86
+ 2,
87
+ [
88
+ qtyping.QuantTransformation.ADD_QUANTIZE,
89
+ ],
90
+ qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
91
+ ),
92
+ index=1,
93
+ expected=False,
94
+ ),
95
+ dict(
96
+ testcase_name="test_unmatched_transforamtions",
97
+ param1=qtyping.OpToTensorParams(
98
+ 0,
99
+ [
100
+ qtyping.QuantTransformation.ADD_QUANTIZE,
101
+ ],
102
+ qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
103
+ ),
104
+ param2=qtyping.OpToTensorParams(
105
+ 2,
106
+ [
107
+ qtyping.QuantTransformation.ADD_QUANTIZE,
108
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
109
+ ],
110
+ qtyping.UniformQuantParams(
111
+ 16, None, np.array([1]), np.array([0])
112
+ ),
113
+ ),
114
+ index=0,
115
+ expected=False,
116
+ ),
117
+ )
118
def test_check_horizontal_optimization(self, param1, param2, index, expected):
  """Compares check_horizontal_optimization's result with the expected one.

  Args:
    param1: First OpToTensorParams passed to the checker.
    param2: Second OpToTensorParams passed to the checker.
    index: Value forwarded as the checker's `index` keyword argument.
    expected: Value the checker is expected to return for this case.
  """
  call_kwargs = dict(param1=param1, param2=param2, index=index)
  actual = instruction_generator.check_horizontal_optimization(**call_kwargs)
  self.assertEqual(expected, actual)
123
+
124
+ @parameterized.named_parameters(
125
+ dict(
126
+ testcase_name="test_success",
127
+ producer_inst=qtyping.TransformationInst(
128
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
129
+ tensor_id=1,
130
+ producer=0,
131
+ consumers=[2],
132
+ parameters=qtyping.UniformQuantParams(
133
+ 8, None, np.array([1]), np.array([0])
134
+ ),
135
+ ),
136
+ consumer_inst=qtyping.TransformationInst(
137
+ transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
138
+ tensor_id=1,
139
+ producer=0,
140
+ consumers=[2],
141
+ parameters=qtyping.UniformQuantParams(
142
+ 8, None, np.array([1]), np.array([0])
143
+ ),
144
+ ),
145
+ expected=True,
146
+ ),
147
+ dict(
148
+ testcase_name="test_wrong_transformation",
149
+ producer_inst=qtyping.TransformationInst(
150
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
151
+ tensor_id=1,
152
+ producer=0,
153
+ consumers=[2],
154
+ parameters=qtyping.UniformQuantParams(
155
+ 8, None, np.array([1]), np.array([0])
156
+ ),
157
+ ),
158
+ consumer_inst=qtyping.TransformationInst(
159
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
160
+ tensor_id=1,
161
+ producer=0,
162
+ consumers=[2],
163
+ parameters=qtyping.UniformQuantParams(
164
+ 8, None, np.array([1]), np.array([0])
165
+ ),
166
+ ),
167
+ expected=False,
168
+ ),
169
+ dict(
170
+ testcase_name="test_wrong_parameters",
171
+ producer_inst=qtyping.TransformationInst(
172
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
173
+ tensor_id=1,
174
+ producer=0,
175
+ consumers=[2],
176
+ parameters=qtyping.UniformQuantParams(
177
+ 8, None, np.array([1]), np.array([0])
178
+ ),
179
+ ),
180
+ consumer_inst=qtyping.TransformationInst(
181
+ transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
182
+ tensor_id=1,
183
+ producer=0,
184
+ consumers=[2],
185
+ parameters=qtyping.UniformQuantParams(
186
+ 16, None, np.array([1]), np.array([0])
187
+ ),
188
+ ),
189
+ expected=False,
190
+ ),
191
+ )
192
def test_check_dq_q_elimination(self, producer_inst, consumer_inst, expected):
  """Compares check_dq_q_elimination's result with the expected one.

  Args:
    producer_inst: TransformationInst passed as the producer instruction.
    consumer_inst: TransformationInst passed as the consumer instruction.
    expected: Value the checker is expected to return for this case.
  """
  checker = instruction_generator.check_dq_q_elimination
  actual = checker(producer_inst=producer_inst, consumer_inst=consumer_inst)
  self.assertEqual(expected, actual)
197
+
198
+ @parameterized.named_parameters(
199
+ dict(
200
+ testcase_name="test_success",
201
+ producer_inst=qtyping.TransformationInst(
202
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
203
+ tensor_id=1,
204
+ producer=0,
205
+ consumers=[2],
206
+ parameters=qtyping.UniformQuantParams(
207
+ 8, None, np.array([1]), np.array([0])
208
+ ),
209
+ ),
210
+ consumer_inst=qtyping.TransformationInst(
211
+ transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
212
+ tensor_id=1,
213
+ producer=0,
214
+ consumers=[2],
215
+ parameters=qtyping.UniformQuantParams(
216
+ 16, None, np.array([1]), np.array([0])
217
+ ),
218
+ ),
219
+ expected=True,
220
+ ),
221
+ dict(
222
+ testcase_name="test_wrong_transformation",
223
+ producer_inst=qtyping.TransformationInst(
224
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
225
+ tensor_id=1,
226
+ producer=0,
227
+ consumers=[2],
228
+ parameters=qtyping.UniformQuantParams(
229
+ 8, None, np.array([1]), np.array([0])
230
+ ),
231
+ ),
232
+ consumer_inst=qtyping.TransformationInst(
233
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
234
+ tensor_id=1,
235
+ producer=0,
236
+ consumers=[2],
237
+ parameters=qtyping.UniformQuantParams(
238
+ 8, None, np.array([1]), np.array([0])
239
+ ),
240
+ ),
241
+ expected=False,
242
+ ),
243
+ dict(
244
+ testcase_name="test_wrong_parameters",
245
+ producer_inst=qtyping.TransformationInst(
246
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
247
+ tensor_id=1,
248
+ producer=0,
249
+ consumers=[2],
250
+ parameters=qtyping.UniformQuantParams(
251
+ 8, None, np.array([1]), np.array([0])
252
+ ),
253
+ ),
254
+ consumer_inst=qtyping.TransformationInst(
255
+ transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
256
+ tensor_id=1,
257
+ producer=0,
258
+ consumers=[2],
259
+ parameters=qtyping.UniformQuantParams(
260
+ 8, None, np.array([1]), np.array([0])
261
+ ),
262
+ ),
263
+ expected=False,
264
+ ),
265
+ )
266
def test_check_replace_dq_q_with_rq(
    self, producer_inst, consumer_inst, expected
):
  """Compares check_replace_dq_q_with_rq's result with the expected one.

  Args:
    producer_inst: TransformationInst passed as the producer instruction.
    consumer_inst: TransformationInst passed as the consumer instruction.
    expected: Value the checker is expected to return for this case.
  """
  instruction_pair = {
      "producer_inst": producer_inst,
      "consumer_inst": consumer_inst,
  }
  actual = instruction_generator.check_replace_dq_q_with_rq(**instruction_pair)
  self.assertEqual(expected, actual)
273
+
274
+ @parameterized.named_parameters(
275
+ dict(
276
+ testcase_name="test_elimination_success",
277
+ producer_inst=qtyping.TransformationInst(
278
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
279
+ tensor_id=1,
280
+ producer=0,
281
+ consumers=[2],
282
+ parameters=qtyping.UniformQuantParams(
283
+ 8, None, np.array([1]), np.array([0])
284
+ ),
285
+ ),
286
+ consumer_inst=qtyping.TransformationInst(
287
+ transformation=qtyping.QuantTransformation.NO_QUANTIZE,
288
+ tensor_id=1,
289
+ producer=0,
290
+ consumers=[2],
291
+ parameters=qtyping.UniformQuantParams(
292
+ 8, None, np.array([1]), np.array([0])
293
+ ),
294
+ ),
295
+ expected=True,
296
+ ),
297
+ dict(
298
+ testcase_name="test_wrong_transformation1",
299
+ producer_inst=qtyping.TransformationInst(
300
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
301
+ tensor_id=1,
302
+ producer=0,
303
+ consumers=[2],
304
+ parameters=qtyping.UniformQuantParams(
305
+ 8, None, np.array([1]), np.array([0])
306
+ ),
307
+ ),
308
+ consumer_inst=qtyping.TransformationInst(
309
+ transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
310
+ tensor_id=1,
311
+ producer=0,
312
+ consumers=[2],
313
+ parameters=qtyping.UniformQuantParams(
314
+ 8, None, np.array([1]), np.array([0])
315
+ ),
316
+ ),
317
+ expected=False,
318
+ ),
319
+ dict(
320
+ testcase_name="test_wrong_transformation2",
321
+ producer_inst=qtyping.TransformationInst(
322
+ transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
323
+ tensor_id=1,
324
+ producer=0,
325
+ consumers=[2],
326
+ parameters=qtyping.UniformQuantParams(
327
+ 8, None, np.array([1]), np.array([0])
328
+ ),
329
+ ),
330
+ consumer_inst=qtyping.TransformationInst(
331
+ transformation=qtyping.QuantTransformation.NO_QUANTIZE,
332
+ tensor_id=1,
333
+ producer=0,
334
+ consumers=[2],
335
+ parameters=qtyping.UniformQuantParams(
336
+ 8, None, np.array([1]), np.array([0])
337
+ ),
338
+ ),
339
+ expected=False,
340
+ ),
341
+ )
342
def test_check_dq_no_quant_elimination(
    self, producer_inst, consumer_inst, expected
):
  """Compares check_dq_no_quant_elimination's result with the expected one.

  Args:
    producer_inst: TransformationInst passed first (positionally).
    consumer_inst: TransformationInst passed second (positionally).
    expected: Value the checker is expected to return for this case.
  """
  checker = instruction_generator.check_dq_no_quant_elimination
  # The arguments stay positional, exactly as the checker is invoked here.
  actual = checker(producer_inst, consumer_inst)
  self.assertEqual(expected, actual)
349
+
350
+ @parameterized.named_parameters(
351
+ dict(
352
+ testcase_name="test_empty_consumer",
353
+ producer_trans_rule=qtyping.TransformationInst(
354
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
355
+ tensor_id=1,
356
+ producer=0,
357
+ consumers=[2],
358
+ parameters=qtyping.UniformQuantParams(
359
+ 8, None, np.array([1]), np.array([0])
360
+ ),
361
+ ),
362
+ consumer_trans_rule=[],
363
+ expected=[
364
+ qtyping.TransformationInst(
365
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
366
+ tensor_id=1,
367
+ producer=0,
368
+ consumers=[2],
369
+ parameters=qtyping.UniformQuantParams(
370
+ 8, None, np.array([1]), np.array([0])
371
+ ),
372
+ )
373
+ ],
374
+ ),
375
+ dict(
376
+ testcase_name="test_no_vertical_trans",
377
+ producer_trans_rule=qtyping.TransformationInst(
378
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
379
+ tensor_id=1,
380
+ producer=0,
381
+ consumers=[0, np.array([1]), 2],
382
+ parameters=qtyping.UniformQuantParams(
383
+ 8, None, np.array([1]), np.array([0])
384
+ ),
385
+ ),
386
+ consumer_trans_rule=[
387
+ qtyping.TransformationInst(
388
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
389
+ tensor_id=1,
390
+ producer=0,
391
+ consumers=[0],
392
+ parameters=qtyping.UniformQuantParams(
393
+ 8, None, np.array([1]), np.array([0])
394
+ ),
395
+ ),
396
+ ],
397
+ expected=[
398
+ qtyping.TransformationInst(
399
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
400
+ tensor_id=1,
401
+ producer=0,
402
+ consumers=[0, np.array([1]), 2],
403
+ parameters=qtyping.UniformQuantParams(
404
+ 8, None, np.array([1]), np.array([0])
405
+ ),
406
+ ),
407
+ qtyping.TransformationInst(
408
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
409
+ tensor_id=1,
410
+ producer=0,
411
+ consumers=[0],
412
+ parameters=qtyping.UniformQuantParams(
413
+ 8, None, np.array([1]), np.array([0])
414
+ ),
415
+ ),
416
+ ],
417
+ ),
418
+ dict(
419
+ testcase_name="test_vertical_trans_with_mix_output",
420
+ producer_trans_rule=qtyping.TransformationInst(
421
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
422
+ tensor_id=1,
423
+ producer=0,
424
+ consumers=[0, np.array([1]), 2],
425
+ parameters=qtyping.UniformQuantParams(
426
+ 8, None, np.array([1]), np.array([0])
427
+ ),
428
+ ),
429
+ consumer_trans_rule=[
430
+ qtyping.TransformationInst(
431
+ transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
432
+ tensor_id=1,
433
+ producer=0,
434
+ consumers=[0],
435
+ parameters=qtyping.UniformQuantParams(
436
+ 8, None, np.array([1]), np.array([0])
437
+ ),
438
+ ),
439
+ qtyping.TransformationInst(
440
+ transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
441
+ tensor_id=1,
442
+ producer=0,
443
+ consumers=[1],
444
+ parameters=qtyping.UniformQuantParams(
445
+ 16, None, np.array([1]), np.array([0])
446
+ ),
447
+ ),
448
+ ],
449
+ expected=[
450
+ qtyping.TransformationInst(
451
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
452
+ tensor_id=1,
453
+ producer=0,
454
+ consumers=[2],
455
+ parameters=qtyping.UniformQuantParams(
456
+ 8, None, np.array([1]), np.array([0])
457
+ ),
458
+ ),
459
+ qtyping.TransformationInst(
460
+ transformation=qtyping.QuantTransformation.QUANTIZE_TENSOR,
461
+ tensor_id=1,
462
+ producer=0,
463
+ consumers=[0],
464
+ parameters=qtyping.UniformQuantParams(
465
+ 8, None, np.array([1]), np.array([0])
466
+ ),
467
+ ),
468
+ qtyping.TransformationInst(
469
+ transformation=qtyping.QuantTransformation.QUANTIZE_TENSOR,
470
+ tensor_id=1,
471
+ producer=0,
472
+ consumers=[1],
473
+ parameters=qtyping.UniformQuantParams(
474
+ 8, None, np.array([1]), np.array([0])
475
+ ),
476
+ ),
477
+ qtyping.TransformationInst(
478
+ transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
479
+ tensor_id=1,
480
+ producer=0,
481
+ consumers=[1],
482
+ parameters=qtyping.UniformQuantParams(
483
+ 16, None, np.array([1]), np.array([0])
484
+ ),
485
+ ),
486
+ ],
487
+ ),
488
+ dict(
489
+ testcase_name="test_multi_match",
490
+ producer_trans_rule=qtyping.TransformationInst(
491
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
492
+ tensor_id=1,
493
+ producer=0,
494
+ consumers=[0, 1, 2],
495
+ parameters=qtyping.UniformQuantParams(
496
+ 8, None, np.array([1]), np.array([0])
497
+ ),
498
+ ),
499
+ consumer_trans_rule=[
500
+ qtyping.TransformationInst(
501
+ transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
502
+ tensor_id=1,
503
+ producer=0,
504
+ consumers=[0, 1],
505
+ parameters=qtyping.UniformQuantParams(
506
+ 8, None, np.array([1]), np.array([0])
507
+ ),
508
+ ),
509
+ ],
510
+ expected=[
511
+ qtyping.TransformationInst(
512
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
513
+ tensor_id=1,
514
+ producer=0,
515
+ consumers=[2],
516
+ parameters=qtyping.UniformQuantParams(
517
+ 8, None, np.array([1]), np.array([0])
518
+ ),
519
+ ),
520
+ qtyping.TransformationInst(
521
+ transformation=qtyping.QuantTransformation.QUANTIZE_TENSOR,
522
+ tensor_id=1,
523
+ producer=0,
524
+ consumers=[0, 1],
525
+ parameters=qtyping.UniformQuantParams(
526
+ 8, None, np.array([1]), np.array([0])
527
+ ),
528
+ ),
529
+ ],
530
+ ),
531
+ dict(
532
+ testcase_name="test_dequant_no_quant_elimination_succeeds",
533
+ producer_trans_rule=qtyping.TransformationInst(
534
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
535
+ tensor_id=1,
536
+ producer=0,
537
+ consumers=[0, 1, 2],
538
+ parameters=qtyping.UniformQuantParams(
539
+ 8, None, np.array([1]), np.array([0])
540
+ ),
541
+ ),
542
+ consumer_trans_rule=[
543
+ qtyping.TransformationInst(
544
+ transformation=qtyping.QuantTransformation.NO_QUANTIZE,
545
+ tensor_id=1,
546
+ producer=0,
547
+ consumers=[0, 1, 2],
548
+ parameters=qtyping.UniformQuantParams(
549
+ 8, None, np.array([1]), np.array([0])
550
+ ),
551
+ ),
552
+ ],
553
+ expected=[
554
+ qtyping.TransformationInst(
555
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
556
+ tensor_id=1,
557
+ producer=0,
558
+ consumers=[0, 1, 2],
559
+ parameters=qtyping.UniformQuantParams(
560
+ 8, None, np.array([1]), np.array([0])
561
+ ),
562
+ ),
563
+ ],
564
+ ),
565
+ )
566
def test_apply_vertical_optimization(
    self, producer_trans_rule, consumer_trans_rule, expected
):
  """Runs _apply_vertical_optimization and compares the produced list.

  Args:
    producer_trans_rule: TransformationInst passed as the first argument.
    consumer_trans_rule: List of TransformationInst passed as the second
      argument (may be empty).
    expected: List of TransformationInst the call is expected to return.
  """
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
  )
  generator = instruction_generator.TransformationInstructionsGenerator(
      model_path
  )
  actual = generator._apply_vertical_optimization(
      producer_trans_rule, consumer_trans_rule
  )
  self.assertEqual(expected, actual)
578
+
579
+ @parameterized.named_parameters(
580
+ dict(testcase_name="test_empty_consumer", param={}, expected=[]),
581
+ dict(
582
+ testcase_name="test_multi_level_grouping",
583
+ param=qtyping.TensorTransformationParams(
584
+ "tfl.quantize",
585
+ qtyping.OpToTensorParams(
586
+ subgraph_op_id=0,
587
+ transformations=[qtyping.QuantTransformation.ADD_DEQUANTIZE],
588
+ parameters=qtyping.UniformQuantParams(
589
+ 8, None, np.array([1]), np.array([0])
590
+ ),
591
+ ),
592
+ [
593
+ qtyping.OpToTensorParams(
594
+ subgraph_op_id=1,
595
+ transformations=[
596
+ qtyping.QuantTransformation.ADD_QUANTIZE
597
+ ],
598
+ parameters=qtyping.UniformQuantParams(
599
+ 8, None, np.array([1]), np.array([0])
600
+ ),
601
+ ),
602
+ qtyping.OpToTensorParams(
603
+ subgraph_op_id=2,
604
+ transformations=[
605
+ qtyping.QuantTransformation.ADD_QUANTIZE,
606
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
607
+ ],
608
+ parameters=qtyping.UniformQuantParams(
609
+ 8, None, np.array([1]), np.array([0])
610
+ ),
611
+ ),
612
+ qtyping.OpToTensorParams(
613
+ subgraph_op_id=3,
614
+ transformations=[
615
+ qtyping.QuantTransformation.ADD_QUANTIZE,
616
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
617
+ ],
618
+ parameters=qtyping.UniformQuantParams(
619
+ 8, None, np.array([1]), np.array([0])
620
+ ),
621
+ ),
622
+ qtyping.OpToTensorParams(
623
+ subgraph_op_id=4,
624
+ transformations=[
625
+ qtyping.QuantTransformation.NO_QUANTIZE,
626
+ ],
627
+ parameters=qtyping.UniformQuantParams(
628
+ 8, None, np.array([1]), np.array([0])
629
+ ),
630
+ ),
631
+ ],
632
+ ),
633
+ expected=[
634
+ [{0, 1, 2, 3}],
635
+ [{0, 1, 2}, {3}],
636
+ [{1, 2}],
637
+ ],
638
+ ),
639
+ )
640
def test_group_consumer_transformations(self, param, expected):
  """Runs _group_consumer_transformations and compares the grouping.

  Args:
    param: Argument forwarded to the method under test; the cases supply
      either an empty dict or a TensorTransformationParams.
    expected: Nested list of sets the call is expected to return.
  """
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
  )
  generator = instruction_generator.TransformationInstructionsGenerator(
      model_path
  )
  actual = generator._group_consumer_transformations(param)
  self.assertEqual(expected, actual)
648
+
649
+ @parameterized.named_parameters(
650
+ dict(
651
+ testcase_name="test_empty_input",
652
+ consumer_group=[],
653
+ param=qtyping.TensorTransformationParams(
654
+ "arg0",
655
+ None,
656
+ [],
657
+ ),
658
+ expected=[],
659
+ ),
660
+ dict(
661
+ testcase_name="test_multi_level_grouping",
662
+ consumer_group=[
663
+ [{0, 1, 2, 3}],
664
+ [{0, 1, 2}, {3}],
665
+ [{1, 2}],
666
+ ],
667
+ param=qtyping.TensorTransformationParams(
668
+ "tfl.quantize",
669
+ qtyping.OpToTensorParams(
670
+ subgraph_op_id=0,
671
+ transformations=[qtyping.QuantTransformation.ADD_DEQUANTIZE],
672
+ parameters=qtyping.UniformQuantParams(
673
+ 8, None, np.array([1]), np.array([0])
674
+ ),
675
+ ),
676
+ [
677
+ qtyping.OpToTensorParams(
678
+ subgraph_op_id=1,
679
+ transformations=[
680
+ qtyping.QuantTransformation.ADD_QUANTIZE
681
+ ],
682
+ parameters=qtyping.UniformQuantParams(
683
+ 8, None, np.array([1]), np.array([0])
684
+ ),
685
+ ),
686
+ qtyping.OpToTensorParams(
687
+ subgraph_op_id=2,
688
+ transformations=[
689
+ qtyping.QuantTransformation.ADD_QUANTIZE,
690
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
691
+ ],
692
+ parameters=qtyping.UniformQuantParams(
693
+ 8, None, np.array([1]), np.array([0])
694
+ ),
695
+ ),
696
+ qtyping.OpToTensorParams(
697
+ subgraph_op_id=3,
698
+ transformations=[
699
+ qtyping.QuantTransformation.ADD_QUANTIZE,
700
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
701
+ ],
702
+ parameters=qtyping.UniformQuantParams(
703
+ 8, None, np.array([1]), np.array([0])
704
+ ),
705
+ ),
706
+ qtyping.OpToTensorParams(
707
+ subgraph_op_id=4,
708
+ transformations=[
709
+ qtyping.QuantTransformation.NO_QUANTIZE,
710
+ ],
711
+ parameters=qtyping.UniformQuantParams(
712
+ 8, None, np.array([1]), np.array([0])
713
+ ),
714
+ ),
715
+ ],
716
+ ),
717
+ expected=[
718
+ qtyping.TransformationInst(
719
+ transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
720
+ tensor_id=1,
721
+ producer=0,
722
+ consumers=[1, 2, 3],
723
+ parameters=qtyping.UniformQuantParams(
724
+ 8, None, np.array([1]), np.array([0])
725
+ ),
726
+ ),
727
+ qtyping.TransformationInst(
728
+ transformation=qtyping.QuantTransformation.NO_QUANTIZE,
729
+ tensor_id=1,
730
+ producer=0,
731
+ consumers=[4],
732
+ parameters=qtyping.UniformQuantParams(
733
+ 8, None, np.array([1]), np.array([0])
734
+ ),
735
+ ),
736
+ ],
737
+ ),
738
+ )
739
def test_produce_transformation_for_vertical_opt(
    self, consumer_group, param, expected
):
  """Runs _produce_transformation_for_vertical_opt and compares the result.

  Args:
    consumer_group: Nested list of sets passed as the first argument.
    param: TensorTransformationParams passed as the second argument.
    expected: List of TransformationInst the call is expected to return.
  """
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
  )
  generator = instruction_generator.TransformationInstructionsGenerator(
      model_path
  )
  actual = generator._produce_transformation_for_vertical_opt(
      consumer_group, param
  )
  self.assertEqual(expected, actual)
751
+
752
+ @parameterized.named_parameters(
753
+ dict(
754
+ testcase_name="test_empty_input",
755
+ consumer_group=[],
756
+ param=qtyping.TensorTransformationParams(
757
+ "arg0",
758
+ None,
759
+ [],
760
+ ),
761
+ expected=[],
762
+ ),
763
+ dict(
764
+ testcase_name="test_multi_level_grouping",
765
+ consumer_group=[
766
+ [{0, 1, 2, 3}],
767
+ [{0, 1, 2}, {3}],
768
+ [{1, 2}],
769
+ ],
770
+ param=qtyping.TensorTransformationParams(
771
+ "tfl.quantize",
772
+ qtyping.OpToTensorParams(
773
+ subgraph_op_id=0,
774
+ transformations=[qtyping.QuantTransformation.ADD_DEQUANTIZE],
775
+ parameters=qtyping.UniformQuantParams(
776
+ 8, None, np.array([1]), np.array([0])
777
+ ),
778
+ ),
779
+ [
780
+ qtyping.OpToTensorParams(
781
+ subgraph_op_id=1,
782
+ transformations=[
783
+ qtyping.QuantTransformation.ADD_QUANTIZE
784
+ ],
785
+ parameters=qtyping.UniformQuantParams(
786
+ 8, None, np.array([1]), np.array([0])
787
+ ),
788
+ ),
789
+ qtyping.OpToTensorParams(
790
+ subgraph_op_id=2,
791
+ transformations=[
792
+ qtyping.QuantTransformation.ADD_QUANTIZE,
793
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
794
+ ],
795
+ parameters=qtyping.UniformQuantParams(
796
+ 8, None, np.array([1]), np.array([0])
797
+ ),
798
+ ),
799
+ qtyping.OpToTensorParams(
800
+ subgraph_op_id=3,
801
+ transformations=[
802
+ qtyping.QuantTransformation.ADD_QUANTIZE,
803
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
804
+ ],
805
+ parameters=qtyping.UniformQuantParams(
806
+ 8, None, np.array([1]), np.array([0])
807
+ ),
808
+ ),
809
+ qtyping.OpToTensorParams(
810
+ subgraph_op_id=4,
811
+ transformations=[
812
+ qtyping.QuantTransformation.NO_QUANTIZE,
813
+ ],
814
+ parameters=qtyping.UniformQuantParams(
815
+ 8, None, np.array([1]), np.array([0])
816
+ ),
817
+ ),
818
+ ],
819
+ ),
820
+ expected=[
821
+ qtyping.TransformationInst(
822
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
823
+ tensor_id=1,
824
+ producer=0,
825
+ consumers=[2, 3],
826
+ parameters=qtyping.UniformQuantParams(
827
+ 8, None, np.array([1]), np.array([0])
828
+ ),
829
+ ),
830
+ ],
831
+ ),
832
+ )
833
def test_produce_customer_transformations_unavailable_for_vertical_opt(
    self, consumer_group, param, expected
):
  """Checks _produce_consumer_transformations_unavailable_for_vertical_opt.

  NOTE(review): "customer" in this test's name is presumably a typo for
  "consumer", matching the method under test; the name is kept as is.

  Args:
    consumer_group: Nested list of sets passed as the first argument.
    param: TensorTransformationParams passed as the second argument.
    expected: List of TransformationInst the call is expected to return.
  """
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
  )
  generator = instruction_generator.TransformationInstructionsGenerator(
      model_path
  )
  method_under_test = (
      generator._produce_consumer_transformations_unavailable_for_vertical_opt
  )
  actual = method_under_test(consumer_group, param)
  self.assertEqual(expected, actual)
847
+
848
def test_empty_param(self):
  """Verifies that an empty quant-params dict yields no instructions."""
  generator = instruction_generator.TransformationInstructionsGenerator(
      os.path.join(TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite")
  )
  no_params = {}
  self.assertEmpty(generator.quant_params_to_transformation_insts(no_params))
861
+
862
def test_generate_instruction_for_single_fc_bias(self):
  """Generates instructions for two tensors of the single_fc_bias model.

  Quantization params are supplied for the "serving_default_input_2:0" and
  "StatefulPartitionedCall:0" tensors, and the generated per-tensor
  instructions are compared with hand-built TensorTransformationInsts.
  """
  input_name = "serving_default_input_2:0"
  output_name = "StatefulPartitionedCall:0"
  add_quantize = qtyping.QuantTransformation.ADD_QUANTIZE
  add_dequantize = qtyping.QuantTransformation.ADD_DEQUANTIZE

  def fresh_quant_params():
    # Build a new object on every call so no two entries share arrays.
    return qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0]))

  def output_inst(transformation):
    # Both expected output instructions differ only in the transformation.
    return qtyping.TransformationInst(
        transformation=transformation,
        tensor_id=3,
        producer=0,
        consumers=[-1],
        parameters=fresh_quant_params(),
    )

  quant_parameters = {
      input_name: qtyping.TensorTransformationParams(
          input_name,
          None,
          [qtyping.OpToTensorParams(0, [add_quantize], fresh_quant_params())],
      ),
      output_name: qtyping.TensorTransformationParams(
          output_name,
          qtyping.OpToTensorParams(
              0, [add_dequantize, add_quantize], fresh_quant_params()
          ),
          [],
      ),
  }

  generator = instruction_generator.TransformationInstructionsGenerator(
      os.path.join(TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite")
  )
  instructions = generator.quant_params_to_transformation_insts(
      quant_parameters
  )

  expected_input = qtyping.TensorTransformationInsts(
      tensor_name=input_name,
      subgraph_id=0,
      instructions=[
          qtyping.TransformationInst(
              transformation=add_quantize,
              tensor_id=0,
              producer=-1,  # input tensor is the subgraph input
              consumers=[0],  # consumed by node 0
              parameters=fresh_quant_params(),
          )
      ],
  )
  expected_output = qtyping.TensorTransformationInsts(
      tensor_name=output_name,
      subgraph_id=0,
      instructions=[output_inst(add_dequantize), output_inst(add_quantize)],
  )

  self.assertLen(instructions, 2)
  self.assertEqual(instructions[input_name], expected_input)
  self.assertEqual(instructions[output_name], expected_output)
953
+
954
def test_raise_error_on_op_replacement_transformation_is_not_unique(self):
  """Checks that ADD_DEQUANTIZE combined with EMULATED_SUBCHANNEL is rejected.

  One `OpToTensorParams` for op 0 lists both transformations for the
  "tfl.quantize" tensor. Generating instructions from it is expected to raise
  a `ValueError` whose message matches
  "op replacement transformation can not be combined".
  """
  tensor_name = "tfl.quantize"
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
  )
  combined_transformations = [
      qtyping.QuantTransformation.ADD_DEQUANTIZE,
      qtyping.QuantTransformation.EMULATED_SUBCHANNEL,
  ]
  op_params = qtyping.OpToTensorParams(
      subgraph_op_id=0,
      transformations=combined_transformations,
      parameters=qtyping.UniformQuantParams(
          8, None, np.array([1]), np.array([0])
      ),
  )
  # The op params fill the second positional slot; the third slot is an
  # empty list.
  quant_parameters = {
      tensor_name: qtyping.TensorTransformationParams(
          tensor_name, op_params, []
      ),
  }
  generator = instruction_generator.TransformationInstructionsGenerator(
      model_path
  )
  expected_message = "op replacement transformation can not be combined"
  with self.assertRaisesRegex(ValueError, expected_message):
    generator.quant_params_to_transformation_insts(quant_parameters)
980
+
981
def test_raise_error_on_no_quant_conflict(self):
  """Checks that QUANTIZE_TENSOR and NO_QUANTIZE on one tensor are rejected.

  The "tfl.quantize" tensor gets `None` in the second positional slot and two
  `OpToTensorParams` in the list: op 1 requests QUANTIZE_TENSOR and op 2
  requests NO_QUANTIZE. Generating instructions is expected to raise a
  `ValueError` matching "can not be both quantized and unquantized".
  """
  tensor_name = "tfl.quantize"
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
  )
  quantizing_op = qtyping.OpToTensorParams(
      subgraph_op_id=1,
      transformations=[qtyping.QuantTransformation.QUANTIZE_TENSOR],
      parameters=qtyping.UniformQuantParams(
          8, None, np.array([1]), np.array([0])
      ),
  )
  non_quantizing_op = qtyping.OpToTensorParams(
      subgraph_op_id=2,
      transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
      parameters=None,
  )
  quant_parameters = {
      tensor_name: qtyping.TensorTransformationParams(
          tensor_name,
          None,
          [quantizing_op, non_quantizing_op],
      ),
  }
  generator = instruction_generator.TransformationInstructionsGenerator(
      model_path
  )
  expected_message = "can not be both quantized and unquantized"
  with self.assertRaisesRegex(ValueError, expected_message):
    generator.quant_params_to_transformation_insts(quant_parameters)
1011
+
1012
def test_generate_instruction_for_branching(self):
  """Tests horizontal and vertical optimization on a multi-branch graph.

  Input for the "tfl.quantize" tensor: op 0 (second positional slot) asks for
  ADD_DEQUANTIZE, while the list holds op 1 asking for ADD_QUANTIZE and op 2
  asking for ADD_QUANTIZE followed by ADD_DEQUANTIZE.

  Expected output: a single entry for tensor id 1 produced by op 0, holding
  one QUANTIZE_TENSOR instruction for consumers [1, 2] and one ADD_DEQUANTIZE
  instruction for consumer [2].
  """

  def fresh_quant_params():
    # Build a new object on every call so no instance is shared between the
    # input parameters and the expected instructions.
    return qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0]))

  tensor_name = "tfl.quantize"
  model_path = os.path.join(
      TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
  )

  op0_params = qtyping.OpToTensorParams(
      subgraph_op_id=0,
      transformations=[qtyping.QuantTransformation.ADD_DEQUANTIZE],
      parameters=fresh_quant_params(),
  )
  op1_params = qtyping.OpToTensorParams(
      subgraph_op_id=1,
      transformations=[qtyping.QuantTransformation.ADD_QUANTIZE],
      parameters=fresh_quant_params(),
  )
  op2_params = qtyping.OpToTensorParams(
      subgraph_op_id=2,
      transformations=[
          qtyping.QuantTransformation.ADD_QUANTIZE,
          qtyping.QuantTransformation.ADD_DEQUANTIZE,
      ],
      parameters=fresh_quant_params(),
  )
  quant_parameters = {
      tensor_name: qtyping.TensorTransformationParams(
          tensor_name,
          op0_params,
          [op1_params, op2_params],
      ),
  }

  generator = instruction_generator.TransformationInstructionsGenerator(
      model_path
  )
  instructions = generator.quant_params_to_transformation_insts(
      quant_parameters
  )

  quantize_tensor_inst = qtyping.TransformationInst(
      transformation=qtyping.QuantTransformation.QUANTIZE_TENSOR,
      tensor_id=1,
      producer=0,
      consumers=[1, 2],
      parameters=fresh_quant_params(),
  )
  add_dequantize_inst = qtyping.TransformationInst(
      transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
      tensor_id=1,
      producer=0,
      consumers=[2],
      parameters=fresh_quant_params(),
  )
  expected_instructions = qtyping.TensorTransformationInsts(
      tensor_name=tensor_name,
      subgraph_id=0,
      instructions=[quantize_tensor_inst, add_dequantize_inst],
  )

  self.assertLen(instructions, 1)
  self.assertEqual(instructions[tensor_name], expected_instructions)
1079
+
1080
+
1081
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
  googletest.main()