ai-edge-quantizer-nightly 0.1.0.dev20250424__py3-none-any.whl → 0.1.0.dev20250426__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/transformation_performer.py +44 -23
- ai_edge_quantizer/transformation_performer_test.py +127 -5
- {ai_edge_quantizer_nightly-0.1.0.dev20250424.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.1.0.dev20250424.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info}/RECORD +7 -7
- {ai_edge_quantizer_nightly-0.1.0.dev20250424.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250424.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250424.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info}/top_level.txt +0 -0
@@ -180,6 +180,38 @@ class TransformationPerformer:
|
|
180
180
|
)
|
181
181
|
transformation.tensor_id = trans_info.output_tensor_id
|
182
182
|
|
183
|
+
def _get_updated_producer_id(
|
184
|
+
self, original_producer_id: int, subgraph_id: int
|
185
|
+
) -> int:
|
186
|
+
"""Update the producer of a transformation instruction."""
|
187
|
+
if original_producer_id is None or original_producer_id < 0:
|
188
|
+
producer = -1
|
189
|
+
elif original_producer_id < len(self._original_op_id_map[subgraph_id]):
|
190
|
+
producer = self._original_op_id_map[subgraph_id][original_producer_id]
|
191
|
+
else:
|
192
|
+
# If the producer id is not in the original op map, it's an added op,
|
193
|
+
# go the added op map to find the producer.
|
194
|
+
producer = self._added_op_id_map[subgraph_id][
|
195
|
+
original_producer_id - len(self._original_op_id_map[subgraph_id])
|
196
|
+
]
|
197
|
+
return producer
|
198
|
+
|
199
|
+
def _get_updated_consumer_ids(
|
200
|
+
self,
|
201
|
+
original_consumer_ids: list[int],
|
202
|
+
subgraph_id: int,
|
203
|
+
) -> list[int]:
|
204
|
+
"""Update the consumers of a transformation instruction."""
|
205
|
+
consumers = []
|
206
|
+
for original_op_id in original_consumer_ids:
|
207
|
+
new_consumer_id = (
|
208
|
+
-1
|
209
|
+
if original_op_id == -1
|
210
|
+
else self._original_op_id_map[subgraph_id][original_op_id]
|
211
|
+
)
|
212
|
+
consumers.append(new_consumer_id)
|
213
|
+
return consumers
|
214
|
+
|
183
215
|
def _apply_single_transformation(
|
184
216
|
self,
|
185
217
|
transformation_inst: qtyping.TensorTransformationInsts,
|
@@ -198,28 +230,12 @@ class TransformationPerformer:
|
|
198
230
|
None, update the transformation_inst & tflite_model in place
|
199
231
|
"""
|
200
232
|
instruction = transformation_inst.instructions[transformation_index]
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
instruction.producer
|
208
|
-
]
|
209
|
-
else:
|
210
|
-
# if the producer id is not in the original op map, it's an added op,
|
211
|
-
# go the corresponding new maps
|
212
|
-
producer = self._added_op_id_map[transformation_inst.subgraph_id][
|
213
|
-
instruction.producer
|
214
|
-
- len(self._original_op_id_map[transformation_inst.subgraph_id])
|
215
|
-
]
|
216
|
-
consumers = []
|
217
|
-
for original_op_id in instruction.consumers:
|
218
|
-
consumers.append(
|
219
|
-
self._original_op_id_map[transformation_inst.subgraph_id][
|
220
|
-
original_op_id
|
221
|
-
]
|
222
|
-
)
|
233
|
+
producer = self._get_updated_producer_id(
|
234
|
+
instruction.producer, transformation_inst.subgraph_id
|
235
|
+
)
|
236
|
+
consumers = self._get_updated_consumer_ids(
|
237
|
+
instruction.consumers, transformation_inst.subgraph_id
|
238
|
+
)
|
223
239
|
trans_info = self._transformation_registration[instruction.transformation](
|
224
240
|
transformation_utils.TransformationInput(
|
225
241
|
instruction.tensor_id,
|
@@ -239,7 +255,12 @@ class TransformationPerformer:
|
|
239
255
|
)
|
240
256
|
self._update_op_id_map(
|
241
257
|
transformation_inst.subgraph_id,
|
242
|
-
|
258
|
+
# The added op must be right before the most immediate consumer, unless
|
259
|
+
# the consumer is the graph output (id=-1), then use the producer's
|
260
|
+
# index instead.
|
261
|
+
min(instruction.consumers)
|
262
|
+
if min(instruction.consumers) >= 0
|
263
|
+
else instruction.producer + 1,
|
243
264
|
trans_info.num_ops_added,
|
244
265
|
)
|
245
266
|
|
@@ -112,6 +112,32 @@ class TransformationPerformerTest(parameterized.TestCase):
|
|
112
112
|
for index, op_id in enumerate(op_id_map[0]):
|
113
113
|
self.assertEqual(op_id, index)
|
114
114
|
|
115
|
+
def test_update_op_id_map_not_changing_value_single_op_model(self):
|
116
|
+
"""test for _update_op_id_map."""
|
117
|
+
model = tfl_flatbuffer_utils.read_model(
|
118
|
+
os.path.join(
|
119
|
+
TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
|
120
|
+
)
|
121
|
+
)
|
122
|
+
self._transformation_performer._create_op_id_map(model)
|
123
|
+
instruction = qtyping.TransformationInst(
|
124
|
+
transformation=qtyping.QuantTransformation.QUANTIZE_TENSOR,
|
125
|
+
tensor_id=0,
|
126
|
+
producer=0,
|
127
|
+
consumers=[-1],
|
128
|
+
parameters=qtyping.UniformQuantParams(
|
129
|
+
8, None, np.array([1]), np.array([0])
|
130
|
+
),
|
131
|
+
)
|
132
|
+
producer = self._transformation_performer._get_updated_producer_id(
|
133
|
+
instruction.producer, 0
|
134
|
+
)
|
135
|
+
consumers = self._transformation_performer._get_updated_consumer_ids(
|
136
|
+
instruction.consumers, 0
|
137
|
+
)
|
138
|
+
self.assertEqual(producer, 0)
|
139
|
+
self.assertEqual(consumers, [-1])
|
140
|
+
|
115
141
|
@parameterized.named_parameters(
|
116
142
|
dict(
|
117
143
|
testcase_name="test_no_update",
|
@@ -271,7 +297,7 @@ class TransformationPerformerTest(parameterized.TestCase):
|
|
271
297
|
expected_added_op_id_map,
|
272
298
|
)
|
273
299
|
|
274
|
-
def
|
300
|
+
def test_update_instructions_updates_tensor_id_after_duplicate_tensor(self):
|
275
301
|
def get_test_instruction(transformation, consumers):
|
276
302
|
return qtyping.TransformationInst(
|
277
303
|
transformation=transformation,
|
@@ -325,6 +351,8 @@ class TransformationPerformerTest(parameterized.TestCase):
|
|
325
351
|
tensor_name="sequential/conv2d/Relu;sequential/conv2d/BiasAdd;"
|
326
352
|
+ "sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp1",
|
327
353
|
subgraph_id=0,
|
354
|
+
# Conv2d: op_id=0, output_tensor_id=7.
|
355
|
+
# This should add two sequential dequants after the conv2d.
|
328
356
|
instructions=[
|
329
357
|
qtyping.TransformationInst(
|
330
358
|
transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
@@ -349,6 +377,8 @@ class TransformationPerformerTest(parameterized.TestCase):
|
|
349
377
|
"sequential/average_pooling2d/AvgPool": qtyping.TensorTransformationInsts(
|
350
378
|
tensor_name="sequential/average_pooling2d/AvgPool",
|
351
379
|
subgraph_id=0,
|
380
|
+
# Avg_pool: op_id=1, output_tensor_id=8.
|
381
|
+
# This should add two sequential dequants after the avg_pool.
|
352
382
|
instructions=[
|
353
383
|
qtyping.TransformationInst(
|
354
384
|
transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
@@ -376,19 +406,111 @@ class TransformationPerformerTest(parameterized.TestCase):
|
|
376
406
|
)
|
377
407
|
self.assertLen(self._test_model.subgraphs, 1)
|
378
408
|
self.assertLen(self._test_model.subgraphs[0].operators, 10)
|
409
|
+
# The original model has 13 tensors, each dequant adds 1 tensor.
|
379
410
|
self.assertLen(self._test_model.subgraphs[0].tensors, 17)
|
411
|
+
# Check that the dequant opcode is added to the model.
|
380
412
|
self.assertEqual(
|
381
413
|
self._test_model.subgraphs[0].operators[1].opcodeIndex,
|
382
414
|
len(self._test_model.operatorCodes) - 1,
|
383
415
|
)
|
416
|
+
# Conv2d, dequant, dequant, avgpool, dequant, dequant, etc.
|
417
|
+
expected_builtin_op_order = [3, 6, 6, 1, 6, 6, 22, 9, 9, 25]
|
418
|
+
for i, op in enumerate(self._test_model.subgraphs[0].operators):
|
419
|
+
op_code = self._test_model.operatorCodes[op.opcodeIndex].builtinCode
|
420
|
+
self.assertEqual(op_code, expected_builtin_op_order[i])
|
421
|
+
# Check that the first dequant input is connected to the conv2d output.
|
422
|
+
self.assertEqual(self._test_model.subgraphs[0].operators[1].inputs[0], 7)
|
423
|
+
# Output is a new tensor just added.
|
424
|
+
self.assertEqual(self._test_model.subgraphs[0].operators[1].outputs[0], 13)
|
425
|
+
# Second dequant has new tensors.
|
384
426
|
self.assertEqual(self._test_model.subgraphs[0].operators[2].inputs[0], 13)
|
385
427
|
self.assertEqual(self._test_model.subgraphs[0].operators[2].outputs[0], 14)
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
)
|
428
|
+
# Avgpool's input is second dequant's output.
|
429
|
+
self.assertEqual(self._test_model.subgraphs[0].operators[3].inputs[0], 14)
|
430
|
+
# Avgpool's output remains the same.
|
390
431
|
self.assertEqual(self._test_model.subgraphs[0].operators[3].outputs[0], 8)
|
432
|
+
# Third dequant's output is a new tensor.
|
391
433
|
self.assertEqual(self._test_model.subgraphs[0].operators[4].outputs[0], 15)
|
434
|
+
# Fourth dequant.
|
435
|
+
self.assertEqual(self._test_model.subgraphs[0].operators[5].inputs[0], 15)
|
436
|
+
self.assertEqual(self._test_model.subgraphs[0].operators[5].outputs[0], 16)
|
437
|
+
|
438
|
+
# Avgpool (op_id=1) and reshape (op_id=2) are bumped by 2 due to the two
|
439
|
+
# dequants added after it.
|
440
|
+
expected_op_id_map = [0, 3, 6, 7, 8, 9]
|
441
|
+
self.assertEqual(
|
442
|
+
self._transformation_performer._original_op_id_map[0],
|
443
|
+
expected_op_id_map,
|
444
|
+
)
|
445
|
+
# New dequants are added at these indices.
|
446
|
+
expected_added_op_id_map = [1, 2, 4, 5]
|
447
|
+
self.assertEqual(
|
448
|
+
self._transformation_performer._added_op_id_map[0],
|
449
|
+
expected_added_op_id_map,
|
450
|
+
)
|
451
|
+
|
452
|
+
def test_op_insertion_at_input_and_output(self):
|
453
|
+
"""test for _update_op_id_map."""
|
454
|
+
model = tfl_flatbuffer_utils.read_model(
|
455
|
+
os.path.join(
|
456
|
+
TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
|
457
|
+
)
|
458
|
+
)
|
459
|
+
self._transformation_performer._create_op_id_map(model)
|
460
|
+
instructions = {
|
461
|
+
# Fully_connected: op_id=0, input_tensor_id=0, output_tensor_id=3.
|
462
|
+
# Add a new quantize op to the input of the fully_connected.
|
463
|
+
"serving_default_input_2:0": qtyping.TensorTransformationInsts(
|
464
|
+
tensor_name="serving_default_input_2:0",
|
465
|
+
subgraph_id=0,
|
466
|
+
instructions=[
|
467
|
+
qtyping.TransformationInst(
|
468
|
+
transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
|
469
|
+
tensor_id=0,
|
470
|
+
producer=-1,
|
471
|
+
consumers=[0],
|
472
|
+
parameters=qtyping.UniformQuantParams(
|
473
|
+
8, None, np.array([1]), np.array([0])
|
474
|
+
),
|
475
|
+
),
|
476
|
+
],
|
477
|
+
),
|
478
|
+
# Add a new dequantize op to the output of the fully_connected.
|
479
|
+
"StatefulPartitionedCall:0": qtyping.TensorTransformationInsts(
|
480
|
+
tensor_name="StatefulPartitionedCall:0",
|
481
|
+
subgraph_id=0,
|
482
|
+
instructions=[
|
483
|
+
qtyping.TransformationInst(
|
484
|
+
transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
485
|
+
tensor_id=3,
|
486
|
+
producer=0,
|
487
|
+
consumers=[-1],
|
488
|
+
parameters=qtyping.UniformQuantParams(
|
489
|
+
8, None, np.array([1]), np.array([0])
|
490
|
+
),
|
491
|
+
),
|
492
|
+
],
|
493
|
+
),
|
494
|
+
}
|
495
|
+
self._transformation_performer.transform_graph(instructions, model)
|
496
|
+
|
497
|
+
# Original fc (op_id=0) should be bumped to op_id=1.
|
498
|
+
self.assertEqual(
|
499
|
+
self._transformation_performer._original_op_id_map[0],
|
500
|
+
[1],
|
501
|
+
)
|
502
|
+
# New quantize added at op_id=0, dequantize added at op_id=1.
|
503
|
+
expected_added_op_id_map = [0, 2]
|
504
|
+
self.assertEqual(
|
505
|
+
self._transformation_performer._added_op_id_map[0],
|
506
|
+
expected_added_op_id_map,
|
507
|
+
)
|
508
|
+
# Quantize, fully_connected, dequantize.
|
509
|
+
expected_builtin_op_order = [114, 9, 6]
|
510
|
+
for i, op in enumerate(model.subgraphs[0].operators):
|
511
|
+
op_code = model.operatorCodes[op.opcodeIndex].builtinCode
|
512
|
+
self.assertEqual(op_code, expected_builtin_op_order[i])
|
513
|
+
|
392
514
|
|
393
515
|
if __name__ == "__main__":
|
394
516
|
googletest.main()
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-quantizer-nightly
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.dev20250426
|
4
4
|
Summary: A quantizer for advanced developers to quantize converted AI Edge models.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
|
@@ -21,8 +21,8 @@ ai_edge_quantizer/recipe_manager_test.py,sha256=LulVxsYp6TBGFI2PLCUCd4VsFq8ELpC7
|
|
21
21
|
ai_edge_quantizer/recipe_test.py,sha256=Fg_sfxovI2fRjk5qdu18ghOvXdUvhDR1TxbE0GHDczc,3381
|
22
22
|
ai_edge_quantizer/transformation_instruction_generator.py,sha256=R7A90Qj6iQQROrznXmXLJd-5yXq0PRHbLOdNY51dEu4,27913
|
23
23
|
ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=E0QSDCav6N6izlJ-a1ZJOsb2VEUxuxBmTbt0-EgDdxY,49890
|
24
|
-
ai_edge_quantizer/transformation_performer.py,sha256=
|
25
|
-
ai_edge_quantizer/transformation_performer_test.py,sha256=
|
24
|
+
ai_edge_quantizer/transformation_performer.py,sha256=zAzrQOb2n2IpB3qopmKV59e5E99HmTOL60QTCn9-7kA,12821
|
25
|
+
ai_edge_quantizer/transformation_performer_test.py,sha256=xk6A3LStCyPclN51--9uO7XjSxNfZmpdfvrzOL0maNM,20349
|
26
26
|
ai_edge_quantizer/algorithms/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
|
27
27
|
ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
|
28
28
|
ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py,sha256=Bs9CK7wZAw6jNaZ8xEtbwO2vM34VYXNZSMVWvxJo9nw,9297
|
@@ -66,8 +66,8 @@ ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=x2xA2CFPpe_2trcV8v5xGaBE
|
|
66
66
|
ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=Op3JxtOqlrjzmYF18jnnstL1k9xiY9kKJ8S2vklKGkc,11327
|
67
67
|
ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
|
68
68
|
ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
|
69
|
-
ai_edge_quantizer_nightly-0.1.0.
|
70
|
-
ai_edge_quantizer_nightly-0.1.0.
|
71
|
-
ai_edge_quantizer_nightly-0.1.0.
|
72
|
-
ai_edge_quantizer_nightly-0.1.0.
|
73
|
-
ai_edge_quantizer_nightly-0.1.0.
|
69
|
+
ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
70
|
+
ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info/METADATA,sha256=CIV-3K_joKQSZc9qpwHgYbFHYwaAtPCZWE6yEYWgDkc,1527
|
71
|
+
ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
72
|
+
ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
|
73
|
+
ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info/RECORD,,
|
File without changes
|
File without changes
|