ai-edge-quantizer-nightly 0.4.0.dev20250829__py3-none-any.whl → 0.4.0.dev20250831__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -23,10 +23,13 @@ from collections.abc import Iterator
23
23
  import dataclasses
24
24
  from typing import Optional
25
25
  from ai_edge_quantizer import qtyping
26
+ from ai_edge_quantizer.algorithms.utils import common_utils
27
+ from ai_edge_quantizer.utils import constrained_ops_utils
26
28
  from ai_edge_quantizer.utils import tfl_flatbuffer_utils
27
29
  from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
28
30
 
29
31
 
32
+ _OpQuantConstraint = common_utils.OpQuantConstraint
30
33
  _QuantTransformation = qtyping.QuantTransformation
31
34
 
32
35
 
@@ -165,6 +168,16 @@ class TransformationInstructionsGenerator:
165
168
  else:
166
169
  self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
167
170
  self._create_tensor_name_to_graph_info_map()
171
+ self._same_as_input_scale_ops = (
172
+ constrained_ops_utils.get_constrained_op_list(
173
+ _OpQuantConstraint.SAME_AS_INPUT_SCALE
174
+ )
175
+ )
176
+ self._same_as_output_scale_ops = (
177
+ constrained_ops_utils.get_constrained_op_list(
178
+ _OpQuantConstraint.SAME_AS_OUTPUT_SCALE
179
+ )
180
+ )
168
181
 
169
182
  @dataclasses.dataclass(frozen=True)
170
183
  class TensorGraphInfo:
@@ -506,6 +519,89 @@ class TransformationInstructionsGenerator:
506
519
  ):
507
520
  instructions.pop(i)
508
521
 
522
def _is_valid_quantize_requantize_pair(
    self,
    instr_0: qtyping.TransformationInst,
    instr_1: qtyping.TransformationInst,
) -> bool:
  """Tells whether two instructions are a quantize/requantize pair.

  Args:
    instr_0: First instruction of the candidate pair.
    instr_1: Second instruction of the candidate pair.

  Returns:
    True only when `instr_0` is a QUANTIZE_TENSOR, `instr_1` is an
    ADD_QUANTIZE, and both instructions list the same consumers.
  """
  if instr_0.transformation != _QuantTransformation.QUANTIZE_TENSOR:
    return False
  if instr_1.transformation != _QuantTransformation.ADD_QUANTIZE:
    return False
  # Both instructions must feed exactly the same set of consumer ops.
  return instr_0.consumers == instr_1.consumers
533
+
534
def _is_op_constrained(self, subgraph_id: int, op_index: int) -> bool:
  """Reports whether an op has a same-as-input/output scale constraint.

  Args:
    subgraph_id: Index of the subgraph that holds the op.
    op_index: Index of the op within that subgraph.

  Returns:
    True if the op's name appears in `self._same_as_input_scale_ops` or in
    `self._same_as_output_scale_ops`.
  """
  name = tfl_flatbuffer_utils.get_op_name_by_index(
      self.flatbuffer_model, subgraph_id, op_index
  )
  if name in self._same_as_input_scale_ops:
    return True
  return name in self._same_as_output_scale_ops
545
+
546
def _are_quant_params_compatible(
    self,
    params_0: qtyping.UniformQuantParams,
    params_1: qtyping.UniformQuantParams,
) -> bool:
  """Checks if quant params are the same except for the scale and zero point.

  Every dataclass field of `qtyping.UniformQuantParams` other than `scale` and
  `zero_point` is compared between the two objects.

  Args:
    params_0: First set of quantization parameters.
    params_1: Second set of quantization parameters.

  Returns:
    True if all compared fields are equal, False at the first mismatch.
  """
  import numpy as np  # pylint: disable=g-import-not-at-top

  ignore_set = {"scale", "zero_point"}
  for field_info in dataclasses.fields(qtyping.UniformQuantParams):
    field_name = field_info.name
    if field_name in ignore_set:
      continue
    value_0 = getattr(params_0, field_name)
    value_1 = getattr(params_1, field_name)
    if isinstance(value_0, np.ndarray) or isinstance(value_1, np.ndarray):
      # `!=` on NumPy arrays is element-wise and its result cannot be used as
      # a condition for multi-element arrays (ValueError), so array-valued
      # fields are compared as a whole instead.
      # NOTE(review): presumably other fields of UniformQuantParams can hold
      # arrays as well as scale/zero_point - confirm against its definition.
      if not np.array_equal(value_0, value_1):
        return False
    elif value_0 != value_1:
      return False
  return True
560
+
561
def _eliminate_requantization_for_nonconstrained_provider(
    self, tensor_trans_insts: qtyping.TensorTransformationInsts
) -> None:
  """Folds a redundant requantize into the preceding quantize instruction.

  When a tensor's instructions are exactly [QUANTIZE_TENSOR, ADD_QUANTIZE] and
  its provider op has no same-as-input/output scale constraint, the scale and
  zero point of the ADD_QUANTIZE instruction are copied into the
  QUANTIZE_TENSOR instruction and the ADD_QUANTIZE instruction is removed.
  The instruction list is modified in place.

  Args:
    tensor_trans_insts: Transformation instructions for a tensor.
  """
  insts = tensor_trans_insts.instructions
  if insts is None or len(insts) != 2:
    return

  quantize_inst, requantize_inst = insts
  quantize_params = quantize_inst.parameters
  requantize_params = requantize_inst.parameters
  producer = quantize_inst.producer

  # Only uniform quantization parameters can be fused.
  if not isinstance(quantize_params, qtyping.UniformQuantParams):
    return
  if not isinstance(requantize_params, qtyping.UniformQuantParams):
    return
  if not self._is_valid_quantize_requantize_pair(
      quantize_inst, requantize_inst
  ):
    return
  if not self._are_quant_params_compatible(
      quantize_params, requantize_params
  ):
    return
  # To avoid fusion when subgraph inputs connected to the main subgraph
  # (e.g. while_body), we skip all tensors with no producer.
  if producer == -1:
    return
  # Can't apply fusion to tensors with a constrained producer since that
  # will break the constraint.
  if self._is_op_constrained(tensor_trans_insts.subgraph_id, producer):
    return

  # Fuse: the quantize instruction takes over the requantize scale and zero
  # point, after which the requantize instruction is no longer needed.
  quantize_inst.parameters = dataclasses.replace(
      quantize_params,
      scale=requantize_params.scale,
      zero_point=requantize_params.zero_point,
  )
  del insts[1]
604
+
509
605
  def _quant_params_to_transformation_insts(
510
606
  self,
511
607
  param: qtyping.TensorTransformationParams,
@@ -578,6 +674,12 @@ class TransformationInstructionsGenerator:
578
674
  # will raise an error if the instructions are not valid.
579
675
  self._check_tensor_transformation_instructions_valid(tensor_trans_insts)
580
676
 
677
+ # Remove unnecessary [QUANTIZE_TENSOR, ADD_QUANTIZE] pairs for tensors with
678
+ # providers without same as input/output scale constraints.
679
+ self._eliminate_requantization_for_nonconstrained_provider(
680
+ tensor_trans_insts
681
+ )
682
+
581
683
  return tensor_trans_insts
582
684
 
583
685
  def _split_instructions_by_tensor_duplication(
@@ -15,7 +15,9 @@
15
15
 
16
16
  """Tests for instruction_generator."""
17
17
 
18
+ from collections.abc import Sequence
18
19
  import os
20
+ from typing import Optional
19
21
 
20
22
  import numpy as np
21
23
 
@@ -1337,5 +1339,166 @@ class InstructionGeneratorTest(parameterized.TestCase):
1337
1339
  )
1338
1340
 
1339
1341
 
1342
class EliminateUnnecessaryRequantizationTest(parameterized.TestCase):
  """Tests for `_eliminate_requantization_for_nonconstrained_provider`."""

  # Producer op index for which fusion is allowed (the same producer that
  # `test_fusion_succeeds` fuses on). Negative tests use it so that the only
  # thing blocking fusion is the condition the test is named after; with the
  # helper's default producer of -1 the "no producer" guard would make those
  # tests pass regardless of the check under test.
  _UNCONSTRAINED_PRODUCER = 0

  def setUp(self):
    super().setUp()
    self.ins_gen = instruction_generator.TransformationInstructionsGenerator(
        os.path.join(TEST_DATA_PREFIX_PATH, "tests/models/conv_fc_mnist.tflite")
    )

  def _get_test_instruction(
      self,
      transformation: qtyping.QuantTransformation,
      producer: int = -1,
      consumers: Optional[Sequence[int]] = None,
      qparams: Optional[qtyping.UniformQuantParams] = None,
  ) -> qtyping.TransformationInst:
    """Builds an instruction, filling in 8-bit default params if omitted."""
    if consumers is None:
      consumers = []
    if qparams is None:
      qparams = qtyping.UniformQuantParams(
          num_bits=8,
          quantized_dimension=None,
          scale=np.array([1]),
          zero_point=np.array([0]),
      )
    return qtyping.TransformationInst(
        transformation=transformation,
        producer=producer,
        consumers=consumers,
        parameters=qparams,
        # Dummy values below.
        tensor_id=0,
    )

  def _create_test_insts(
      self, instructions: list[qtyping.TransformationInst]
  ) -> qtyping.TensorTransformationInsts:
    """Wraps instructions into per-tensor instructions for subgraph 0."""
    return qtyping.TensorTransformationInsts(
        tensor_name="test_tensor", subgraph_id=0, instructions=instructions
    )

  def test_no_fusion_when_too_few_instructions(self):
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 1)

  def test_no_fusion_when_too_many_instructions(self):
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
        self._get_test_instruction(_QTransf.ADD_QUANTIZE),
        self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 3)

  def test_no_fusion_when_invalid_transformation_pair(self):
    producer = self._UNCONSTRAINED_PRODUCER
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(_QTransf.ADD_DEQUANTIZE, producer),
        self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 2)

  def test_no_fusion_when_consumers_mismatch(self):
    producer = self._UNCONSTRAINED_PRODUCER
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(
            _QTransf.QUANTIZE_TENSOR, producer, consumers=[0]
        ),
        self._get_test_instruction(
            _QTransf.ADD_QUANTIZE, producer, consumers=[1]
        ),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 2)

  def test_no_fusion_when_no_producer(self):
    producer = -1
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer),
        self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 2)

  def test_no_fusion_when_quant_params_are_incompatible(self):
    producer = self._UNCONSTRAINED_PRODUCER
    params_8_bits = qtyping.UniformQuantParams(
        8, None, np.array([1]), np.array([0])
    )
    params_16_bits = qtyping.UniformQuantParams(
        16, None, np.array([1]), np.array([0])
    )
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(
            _QTransf.QUANTIZE_TENSOR, producer, qparams=params_8_bits
        ),
        self._get_test_instruction(
            _QTransf.ADD_QUANTIZE, producer, qparams=params_16_bits
        ),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 2)

  def test_no_fusion_when_producer_constrained(self):
    # Reshape op (op index 2) has same as input scale constraint.
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer=2),
        self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer=2),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 2)

  def test_fusion_succeeds(self):
    producer = self._UNCONSTRAINED_PRODUCER
    consumers = [1]
    params_0 = qtyping.UniformQuantParams(
        num_bits=8,
        quantized_dimension=None,
        scale=np.array([1]),
        zero_point=np.array([0]),
    )
    params_1 = qtyping.UniformQuantParams(
        num_bits=8,
        quantized_dimension=None,
        scale=np.array([2]),
        zero_point=np.array([1]),
    )
    inst_0 = self._get_test_instruction(
        _QTransf.QUANTIZE_TENSOR, producer, consumers, params_0
    )
    inst_1 = self._get_test_instruction(
        _QTransf.ADD_QUANTIZE, producer, consumers, params_1
    )
    tensor_insts = self._create_test_insts([inst_0, inst_1])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )

    self.assertLen(tensor_insts.instructions, 1)
    result_inst = tensor_insts.instructions[0]
    self.assertEqual(result_inst.transformation, _QTransf.QUANTIZE_TENSOR)

    result_params = result_inst.parameters
    # Explicitly narrow the type for pytype.
    if not isinstance(result_params, qtyping.UniformQuantParams):
      self.fail("Fused instruction parameters are not UniformQuantParams")

    # Array-aware comparison: `assertEqual` relies on the truth value of an
    # element-wise `==`, which only works for single-element arrays.
    np.testing.assert_array_equal(result_params.scale, params_1.scale)
    np.testing.assert_array_equal(
        result_params.zero_point, params_1.zero_point
    )
1501
+
1502
+
1340
1503
  if __name__ == "__main__":
1341
1504
  googletest.main()
@@ -342,3 +342,12 @@ def get_op_side_effect_subgraphs(
342
342
  return [opts.decompositionSubgraphIndex]
343
343
  # Can add other nested ops here (control flow ops, etc).
344
344
  return []
345
+
346
+
347
def get_op_name_by_index(
    flatbuffer_model: Any, subgraph_id: int, op_index: int
) -> str:
  """Looks up the name of an op in the flatbuffer model.

  Args:
    flatbuffer_model: The flatbuffer model that holds the op.
    subgraph_id: Index into `flatbuffer_model.subgraphs`.
    op_index: Index into the selected subgraph's `operators`.

  Returns:
    The `TFL_OP_CODE_TO_NAME` entry for the op's builtin code.
  """
  subgraph = flatbuffer_model.subgraphs[subgraph_id]
  opcode_index = subgraph.operators[op_index].opcodeIndex
  operator_code = flatbuffer_model.operatorCodes[opcode_index]
  return TFL_OP_CODE_TO_NAME[operator_code.builtinCode]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-quantizer-nightly
3
- Version: 0.4.0.dev20250829
3
+ Version: 0.4.0.dev20250831
4
4
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
6
6
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -19,8 +19,8 @@ ai_edge_quantizer/recipe.py,sha256=MEkfQ2Sg3KAE9LAORHWcbjYNPg06EUbwc1d-VspQA2U,6
19
19
  ai_edge_quantizer/recipe_manager.py,sha256=6dgbE-IZfEetzXH3p3Qm_9eQutNDOpZnMpiaLTbP-ZQ,14744
20
20
  ai_edge_quantizer/recipe_manager_test.py,sha256=H-B75vwPN5ND-nUa3pOXizeHTv4mufPiC5cL_OlDIYU,34040
21
21
  ai_edge_quantizer/recipe_test.py,sha256=GKuo6N65wKLS2xwSpjd-BWWeVRpF1zc7Yt7phSMYSxA,5905
22
- ai_edge_quantizer/transformation_instruction_generator.py,sha256=iMGXy7_ufqgQRzu4drAfO31VGdze35peEFh1BMZlVHk,27714
23
- ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=Zw3EOSnvzjuB4NWeo129eJZxK_EHno9oF9OtEQ-0dnM,48905
22
+ ai_edge_quantizer/transformation_instruction_generator.py,sha256=O0U2aZcB8aXQgOV8r9g1rGNzDUiuI5Ta53XnxZbVffE,31576
23
+ ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=KW5-WoTTo9IqLEVnWxVC8ut8eWLi_91xfKgGqVQ9QDk,54635
24
24
  ai_edge_quantizer/transformation_performer.py,sha256=o4J6OUbI0dLoobVYjkOFw5Po3yH0gZJXrfuTIYais4o,13029
25
25
  ai_edge_quantizer/transformation_performer_test.py,sha256=xk6A3LStCyPclN51--9uO7XjSxNfZmpdfvrzOL0maNM,20349
26
26
  ai_edge_quantizer/algorithms/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
@@ -64,14 +64,14 @@ ai_edge_quantizer/utils/calibration_utils_test.py,sha256=4BlksXl7b4yptL8xPR67hmJ
64
64
  ai_edge_quantizer/utils/constrained_ops_utils.py,sha256=EAITCf7Ku_PFZcw3K-wd-8hGbyuRd5W5UtNdGvalwAE,4478
65
65
  ai_edge_quantizer/utils/constrained_ops_utils_test.py,sha256=6k_AqfB-NmiLkW5WwEV5NSuswFWky2sL0xBGmV6Fdwk,1756
66
66
  ai_edge_quantizer/utils/test_utils.py,sha256=a4Nk-wbeB09dFjTDZiA0K67d26j5DD0UDH_GIVmVG_4,8685
67
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=RL6oq6FzZj-xV0Zgh0UBn7-fOQaRXSxZ-PPG_LmtyUY,11384
67
+ ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=aNtL4dpWH5uGGGlaygnMDkh5llTstbgs5ZxO0JkH5VQ,11718
68
68
  ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
69
69
  ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=EoVjI_hplX_Rml3hfRsGmQOihexmizeJqt4SQcET9aA,14925
70
70
  ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
71
71
  ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
72
72
  ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
73
- ai_edge_quantizer_nightly-0.4.0.dev20250829.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
74
- ai_edge_quantizer_nightly-0.4.0.dev20250829.dist-info/METADATA,sha256=fg5k0J7zQJc0ufSBvuidEZKz57iydiIhRI4teV-7AZI,1535
75
- ai_edge_quantizer_nightly-0.4.0.dev20250829.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
76
- ai_edge_quantizer_nightly-0.4.0.dev20250829.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
77
- ai_edge_quantizer_nightly-0.4.0.dev20250829.dist-info/RECORD,,
73
+ ai_edge_quantizer_nightly-0.4.0.dev20250831.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
74
+ ai_edge_quantizer_nightly-0.4.0.dev20250831.dist-info/METADATA,sha256=TwazFRbRa2j0kWXJB38Tz5tH0ZCeujk2wCBKsnSdk9I,1535
75
+ ai_edge_quantizer_nightly-0.4.0.dev20250831.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
76
+ ai_edge_quantizer_nightly-0.4.0.dev20250831.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
77
+ ai_edge_quantizer_nightly-0.4.0.dev20250831.dist-info/RECORD,,