ai-edge-quantizer-nightly 0.1.0.dev20250424__py3-none-any.whl → 0.1.0.dev20250426__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -180,6 +180,38 @@ class TransformationPerformer:
180
180
  )
181
181
  transformation.tensor_id = trans_info.output_tensor_id
182
182
 
183
+ def _get_updated_producer_id(
184
+ self, original_producer_id: int, subgraph_id: int
185
+ ) -> int:
186
+ """Update the producer of a transformation instruction."""
187
+ if original_producer_id is None or original_producer_id < 0:
188
+ producer = -1
189
+ elif original_producer_id < len(self._original_op_id_map[subgraph_id]):
190
+ producer = self._original_op_id_map[subgraph_id][original_producer_id]
191
+ else:
192
+ # If the producer id is not in the original op map, it's an added op,
193
+ # go to the added op map to find the producer.
194
+ producer = self._added_op_id_map[subgraph_id][
195
+ original_producer_id - len(self._original_op_id_map[subgraph_id])
196
+ ]
197
+ return producer
198
+
199
+ def _get_updated_consumer_ids(
200
+ self,
201
+ original_consumer_ids: list[int],
202
+ subgraph_id: int,
203
+ ) -> list[int]:
204
+ """Update the consumers of a transformation instruction."""
205
+ consumers = []
206
+ for original_op_id in original_consumer_ids:
207
+ new_consumer_id = (
208
+ -1
209
+ if original_op_id == -1
210
+ else self._original_op_id_map[subgraph_id][original_op_id]
211
+ )
212
+ consumers.append(new_consumer_id)
213
+ return consumers
214
+
183
215
  def _apply_single_transformation(
184
216
  self,
185
217
  transformation_inst: qtyping.TensorTransformationInsts,
@@ -198,28 +230,12 @@ class TransformationPerformer:
198
230
  None, update the transformation_inst & tflite_model in place
199
231
  """
200
232
  instruction = transformation_inst.instructions[transformation_index]
201
- if not instruction.producer or instruction.producer < 0:
202
- producer = -1
203
- elif instruction.producer < len(
204
- self._original_op_id_map[transformation_inst.subgraph_id]
205
- ):
206
- producer = self._original_op_id_map[transformation_inst.subgraph_id][
207
- instruction.producer
208
- ]
209
- else:
210
- # if the producer id is not in the original op map, it's an added op,
211
- # go the corresponding new maps
212
- producer = self._added_op_id_map[transformation_inst.subgraph_id][
213
- instruction.producer
214
- - len(self._original_op_id_map[transformation_inst.subgraph_id])
215
- ]
216
- consumers = []
217
- for original_op_id in instruction.consumers:
218
- consumers.append(
219
- self._original_op_id_map[transformation_inst.subgraph_id][
220
- original_op_id
221
- ]
222
- )
233
+ producer = self._get_updated_producer_id(
234
+ instruction.producer, transformation_inst.subgraph_id
235
+ )
236
+ consumers = self._get_updated_consumer_ids(
237
+ instruction.consumers, transformation_inst.subgraph_id
238
+ )
223
239
  trans_info = self._transformation_registration[instruction.transformation](
224
240
  transformation_utils.TransformationInput(
225
241
  instruction.tensor_id,
@@ -239,7 +255,12 @@ class TransformationPerformer:
239
255
  )
240
256
  self._update_op_id_map(
241
257
  transformation_inst.subgraph_id,
242
- min(instruction.consumers),
258
+ # The added op must be right before the most immediate consumer, unless
259
+ # the consumer is the graph output (id=-1), in which case it is placed
260
+ # right after the producer (producer index + 1) instead.
261
+ min(instruction.consumers)
262
+ if min(instruction.consumers) >= 0
263
+ else instruction.producer + 1,
243
264
  trans_info.num_ops_added,
244
265
  )
245
266
 
@@ -112,6 +112,32 @@ class TransformationPerformerTest(parameterized.TestCase):
112
112
  for index, op_id in enumerate(op_id_map[0]):
113
113
  self.assertEqual(op_id, index)
114
114
 
115
+ def test_update_op_id_map_not_changing_value_single_op_model(self):
116
+ """Test for _get_updated_producer_id and _get_updated_consumer_ids."""
117
+ model = tfl_flatbuffer_utils.read_model(
118
+ os.path.join(
119
+ TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
120
+ )
121
+ )
122
+ self._transformation_performer._create_op_id_map(model)
123
+ instruction = qtyping.TransformationInst(
124
+ transformation=qtyping.QuantTransformation.QUANTIZE_TENSOR,
125
+ tensor_id=0,
126
+ producer=0,
127
+ consumers=[-1],
128
+ parameters=qtyping.UniformQuantParams(
129
+ 8, None, np.array([1]), np.array([0])
130
+ ),
131
+ )
132
+ producer = self._transformation_performer._get_updated_producer_id(
133
+ instruction.producer, 0
134
+ )
135
+ consumers = self._transformation_performer._get_updated_consumer_ids(
136
+ instruction.consumers, 0
137
+ )
138
+ self.assertEqual(producer, 0)
139
+ self.assertEqual(consumers, [-1])
140
+
115
141
  @parameterized.named_parameters(
116
142
  dict(
117
143
  testcase_name="test_no_update",
@@ -271,7 +297,7 @@ class TransformationPerformerTest(parameterized.TestCase):
271
297
  expected_added_op_id_map,
272
298
  )
273
299
 
274
- def test__update_instructions_updates_tensor_id_after_duplicate_tensor(self):
300
+ def test_update_instructions_updates_tensor_id_after_duplicate_tensor(self):
275
301
  def get_test_instruction(transformation, consumers):
276
302
  return qtyping.TransformationInst(
277
303
  transformation=transformation,
@@ -325,6 +351,8 @@ class TransformationPerformerTest(parameterized.TestCase):
325
351
  tensor_name="sequential/conv2d/Relu;sequential/conv2d/BiasAdd;"
326
352
  + "sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp1",
327
353
  subgraph_id=0,
354
+ # Conv2d: op_id=0, output_tensor_id=7.
355
+ # This should add two sequential dequants after the conv2d.
328
356
  instructions=[
329
357
  qtyping.TransformationInst(
330
358
  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
@@ -349,6 +377,8 @@ class TransformationPerformerTest(parameterized.TestCase):
349
377
  "sequential/average_pooling2d/AvgPool": qtyping.TensorTransformationInsts(
350
378
  tensor_name="sequential/average_pooling2d/AvgPool",
351
379
  subgraph_id=0,
380
+ # Avg_pool: op_id=1, output_tensor_id=8.
381
+ # This should add two sequential dequants after the avg_pool.
352
382
  instructions=[
353
383
  qtyping.TransformationInst(
354
384
  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
@@ -376,19 +406,111 @@ class TransformationPerformerTest(parameterized.TestCase):
376
406
  )
377
407
  self.assertLen(self._test_model.subgraphs, 1)
378
408
  self.assertLen(self._test_model.subgraphs[0].operators, 10)
409
+ # The original model has 13 tensors, each dequant adds 1 tensor.
379
410
  self.assertLen(self._test_model.subgraphs[0].tensors, 17)
411
+ # Check that the dequant opcode is added to the model.
380
412
  self.assertEqual(
381
413
  self._test_model.subgraphs[0].operators[1].opcodeIndex,
382
414
  len(self._test_model.operatorCodes) - 1,
383
415
  )
416
+ # Conv2d, dequant, dequant, avgpool, dequant, dequant, etc.
417
+ expected_builtin_op_order = [3, 6, 6, 1, 6, 6, 22, 9, 9, 25]
418
+ for i, op in enumerate(self._test_model.subgraphs[0].operators):
419
+ op_code = self._test_model.operatorCodes[op.opcodeIndex].builtinCode
420
+ self.assertEqual(op_code, expected_builtin_op_order[i])
421
+ # Check that the first dequant input is connected to the conv2d output.
422
+ self.assertEqual(self._test_model.subgraphs[0].operators[1].inputs[0], 7)
423
+ # Output is a new tensor just added.
424
+ self.assertEqual(self._test_model.subgraphs[0].operators[1].outputs[0], 13)
425
+ # Second dequant has new tensors.
384
426
  self.assertEqual(self._test_model.subgraphs[0].operators[2].inputs[0], 13)
385
427
  self.assertEqual(self._test_model.subgraphs[0].operators[2].outputs[0], 14)
386
- self.assertEqual(
387
- self._test_model.subgraphs[0].operators[2].outputs[0],
388
- self._test_model.subgraphs[0].operators[3].inputs[0],
389
- )
428
+ # Avgpool's input is second dequant's output.
429
+ self.assertEqual(self._test_model.subgraphs[0].operators[3].inputs[0], 14)
430
+ # Avgpool's output remains the same.
390
431
  self.assertEqual(self._test_model.subgraphs[0].operators[3].outputs[0], 8)
432
+ # Third dequant's output is a new tensor.
391
433
  self.assertEqual(self._test_model.subgraphs[0].operators[4].outputs[0], 15)
434
+ # Fourth dequant.
435
+ self.assertEqual(self._test_model.subgraphs[0].operators[5].inputs[0], 15)
436
+ self.assertEqual(self._test_model.subgraphs[0].operators[5].outputs[0], 16)
437
+
438
+ # Avgpool (op_id=1) is bumped by 2 and reshape (op_id=2) by 4, due to the
439
+ # two dequants added after conv2d and the two added after avgpool.
440
+ expected_op_id_map = [0, 3, 6, 7, 8, 9]
441
+ self.assertEqual(
442
+ self._transformation_performer._original_op_id_map[0],
443
+ expected_op_id_map,
444
+ )
445
+ # New dequants are added at these indices.
446
+ expected_added_op_id_map = [1, 2, 4, 5]
447
+ self.assertEqual(
448
+ self._transformation_performer._added_op_id_map[0],
449
+ expected_added_op_id_map,
450
+ )
451
+
452
+ def test_op_insertion_at_input_and_output(self):
453
+ """test for _update_op_id_map."""
454
+ model = tfl_flatbuffer_utils.read_model(
455
+ os.path.join(
456
+ TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
457
+ )
458
+ )
459
+ self._transformation_performer._create_op_id_map(model)
460
+ instructions = {
461
+ # Fully_connected: op_id=0, input_tensor_id=0, output_tensor_id=3.
462
+ # Add a new quantize op to the input of the fully_connected.
463
+ "serving_default_input_2:0": qtyping.TensorTransformationInsts(
464
+ tensor_name="serving_default_input_2:0",
465
+ subgraph_id=0,
466
+ instructions=[
467
+ qtyping.TransformationInst(
468
+ transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
469
+ tensor_id=0,
470
+ producer=-1,
471
+ consumers=[0],
472
+ parameters=qtyping.UniformQuantParams(
473
+ 8, None, np.array([1]), np.array([0])
474
+ ),
475
+ ),
476
+ ],
477
+ ),
478
+ # Add a new dequantize op to the output of the fully_connected.
479
+ "StatefulPartitionedCall:0": qtyping.TensorTransformationInsts(
480
+ tensor_name="StatefulPartitionedCall:0",
481
+ subgraph_id=0,
482
+ instructions=[
483
+ qtyping.TransformationInst(
484
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
485
+ tensor_id=3,
486
+ producer=0,
487
+ consumers=[-1],
488
+ parameters=qtyping.UniformQuantParams(
489
+ 8, None, np.array([1]), np.array([0])
490
+ ),
491
+ ),
492
+ ],
493
+ ),
494
+ }
495
+ self._transformation_performer.transform_graph(instructions, model)
496
+
497
+ # Original fc (op_id=0) should be bumped to op_id=1.
498
+ self.assertEqual(
499
+ self._transformation_performer._original_op_id_map[0],
500
+ [1],
501
+ )
502
+ # New quantize added at op_id=0, dequantize added at op_id=2.
503
+ expected_added_op_id_map = [0, 2]
504
+ self.assertEqual(
505
+ self._transformation_performer._added_op_id_map[0],
506
+ expected_added_op_id_map,
507
+ )
508
+ # Quantize, fully_connected, dequantize.
509
+ expected_builtin_op_order = [114, 9, 6]
510
+ for i, op in enumerate(model.subgraphs[0].operators):
511
+ op_code = model.operatorCodes[op.opcodeIndex].builtinCode
512
+ self.assertEqual(op_code, expected_builtin_op_order[i])
513
+
392
514
 
393
515
  if __name__ == "__main__":
394
516
  googletest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-quantizer-nightly
3
- Version: 0.1.0.dev20250424
3
+ Version: 0.1.0.dev20250426
4
4
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
6
6
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -21,8 +21,8 @@ ai_edge_quantizer/recipe_manager_test.py,sha256=LulVxsYp6TBGFI2PLCUCd4VsFq8ELpC7
21
21
  ai_edge_quantizer/recipe_test.py,sha256=Fg_sfxovI2fRjk5qdu18ghOvXdUvhDR1TxbE0GHDczc,3381
22
22
  ai_edge_quantizer/transformation_instruction_generator.py,sha256=R7A90Qj6iQQROrznXmXLJd-5yXq0PRHbLOdNY51dEu4,27913
23
23
  ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=E0QSDCav6N6izlJ-a1ZJOsb2VEUxuxBmTbt0-EgDdxY,49890
24
- ai_edge_quantizer/transformation_performer.py,sha256=PIrylVhuWZCpnXEl7qSw2BlxRrY7lqj6aQvagJVCVts,11989
25
- ai_edge_quantizer/transformation_performer_test.py,sha256=n9xI6QMqvrj9KUul2LuObIsF7YdLSqgMg4X6d4BkFP8,15219
24
+ ai_edge_quantizer/transformation_performer.py,sha256=zAzrQOb2n2IpB3qopmKV59e5E99HmTOL60QTCn9-7kA,12821
25
+ ai_edge_quantizer/transformation_performer_test.py,sha256=xk6A3LStCyPclN51--9uO7XjSxNfZmpdfvrzOL0maNM,20349
26
26
  ai_edge_quantizer/algorithms/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
27
27
  ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
28
28
  ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py,sha256=Bs9CK7wZAw6jNaZ8xEtbwO2vM34VYXNZSMVWvxJo9nw,9297
@@ -66,8 +66,8 @@ ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=x2xA2CFPpe_2trcV8v5xGaBE
66
66
  ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=Op3JxtOqlrjzmYF18jnnstL1k9xiY9kKJ8S2vklKGkc,11327
67
67
  ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
68
68
  ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
69
- ai_edge_quantizer_nightly-0.1.0.dev20250424.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
70
- ai_edge_quantizer_nightly-0.1.0.dev20250424.dist-info/METADATA,sha256=ymVF3awwYNfrNqBMSN903Tnc_Catt8qj7xvktKDsnoU,1527
71
- ai_edge_quantizer_nightly-0.1.0.dev20250424.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
72
- ai_edge_quantizer_nightly-0.1.0.dev20250424.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
73
- ai_edge_quantizer_nightly-0.1.0.dev20250424.dist-info/RECORD,,
69
+ ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
70
+ ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info/METADATA,sha256=CIV-3K_joKQSZc9qpwHgYbFHYwaAtPCZWE6yEYWgDkc,1527
71
+ ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
72
+ ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
73
+ ai_edge_quantizer_nightly-0.1.0.dev20250426.dist-info/RECORD,,