ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -23,10 +23,13 @@ from collections.abc import Iterator
23
23
  import dataclasses
24
24
  from typing import Optional
25
25
  from ai_edge_quantizer import qtyping
26
+ from ai_edge_quantizer.algorithms.utils import common_utils
27
+ from ai_edge_quantizer.utils import constrained_ops_utils
26
28
  from ai_edge_quantizer.utils import tfl_flatbuffer_utils
27
29
  from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
28
30
 
29
31
 
32
+ _OpQuantConstraint = common_utils.OpQuantConstraint
30
33
  _QuantTransformation = qtyping.QuantTransformation
31
34
 
32
35
 
@@ -51,6 +54,15 @@ def check_horizontal_optimization(
51
54
  Returns:
52
55
  True if the two transformations can be merged, False otherwise
53
56
  """
57
+ if (
58
+ isinstance(param1.parameters, qtyping.UniformQuantParams)
59
+ and param1.parameters.hadamard is not None
60
+ ):
61
+ if (
62
+ isinstance(param2.parameters, qtyping.UniformQuantParams)
63
+ and param2.parameters.hadamard is not None
64
+ ):
65
+ return True
54
66
  return (
55
67
  param1.parameters == param2.parameters
56
68
  and len(param1.transformations) > index
@@ -165,6 +177,16 @@ class TransformationInstructionsGenerator:
165
177
  else:
166
178
  self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
167
179
  self._create_tensor_name_to_graph_info_map()
180
+ self._same_as_input_scale_ops = (
181
+ constrained_ops_utils.get_constrained_op_list(
182
+ _OpQuantConstraint.SAME_AS_INPUT_SCALE
183
+ )
184
+ )
185
+ self._same_as_output_scale_ops = (
186
+ constrained_ops_utils.get_constrained_op_list(
187
+ _OpQuantConstraint.SAME_AS_OUTPUT_SCALE
188
+ )
189
+ )
168
190
 
169
191
  @dataclasses.dataclass(frozen=True)
170
192
  class TensorGraphInfo:
@@ -186,11 +208,13 @@ class TransformationInstructionsGenerator:
186
208
  A tuple of tensor_name and TensorGraphInfo.
187
209
  """
188
210
  for tensor_id, tensor in enumerate(subgraph.tensors):
189
- consumers = [
190
- op_id
191
- for (op_id, op) in enumerate(subgraph.operators)
192
- if tensor_id in op.inputs
193
- ]
211
+ consumers = []
212
+ for op_id, op in enumerate(subgraph.operators):
213
+ # Some ops may use the same input tensor multiple times,
214
+ # and we should handle each time independently.
215
+ for op_input in op.inputs:
216
+ if op_input == tensor_id:
217
+ consumers.append(op_id)
194
218
  producer = -1
195
219
  for op_id, op in enumerate(subgraph.operators):
196
220
  if tensor_id in op.outputs:
@@ -504,6 +528,89 @@ class TransformationInstructionsGenerator:
504
528
  ):
505
529
  instructions.pop(i)
506
530
 
531
+ def _is_valid_quantize_requantize_pair(
532
+ self,
533
+ instr_0: qtyping.TransformationInst,
534
+ instr_1: qtyping.TransformationInst,
535
+ ) -> bool:
536
+ """Checks if the two instructions form a valid quantize and requantize pair."""
537
+ return (
538
+ instr_0.transformation == _QuantTransformation.QUANTIZE_TENSOR
539
+ and instr_1.transformation == _QuantTransformation.ADD_QUANTIZE
540
+ and instr_0.consumers == instr_1.consumers
541
+ )
542
+
543
+ def _is_op_constrained(
544
+ self, subgraph_id: int, op_index: int
545
+ ) -> bool:
546
+ """Checks if the op has same as input or output scale constraints."""
547
+ op_name = tfl_flatbuffer_utils.get_op_name_by_index(
548
+ self.flatbuffer_model, subgraph_id, op_index
549
+ )
550
+ return (
551
+ op_name in self._same_as_input_scale_ops
552
+ or op_name in self._same_as_output_scale_ops
553
+ )
554
+
555
+ def _are_quant_params_compatible(
556
+ self,
557
+ params_0: qtyping.UniformQuantParams,
558
+ params_1: qtyping.UniformQuantParams,
559
+ ) -> bool:
560
+ """Checks if quant params are the same except for the scale and zero point."""
561
+ ignore_set = {"scale", "zero_point"}
562
+ for field_info in dataclasses.fields(qtyping.UniformQuantParams):
563
+ field_name = field_info.name
564
+ if field_name in ignore_set:
565
+ continue
566
+ if getattr(params_0, field_name) != getattr(params_1, field_name):
567
+ return False
568
+ return True
569
+
570
+ def _eliminate_requantization_for_nonconstrained_provider(
571
+ self, tensor_trans_insts: qtyping.TensorTransformationInsts
572
+ ) -> None:
573
+ """Removes requantization for tensors with a non-constrained provider.
574
+
575
+ Fuses [QUANTIZE_TENSOR, ADD_QUANTIZE] instructions when a tensor has a
576
+ provider op without same as input/ouput scale constrains. Quant params from
577
+ the second instruction are copied to the first one and ADD_QUANTIZE is
578
+ removed.
579
+
580
+ Args:
581
+ tensor_trans_insts: Transformation instructions for a tensor.
582
+ """
583
+ instructions = tensor_trans_insts.instructions
584
+ if instructions is None or len(instructions) != 2:
585
+ return
586
+
587
+ instr_0, instr_1 = instructions
588
+ params_0 = instr_0.parameters
589
+ params_1 = instr_1.parameters
590
+ producer_op_index = instr_0.producer
591
+ if (
592
+ not isinstance(params_0, qtyping.UniformQuantParams)
593
+ or not isinstance(params_1, qtyping.UniformQuantParams)
594
+ or not self._is_valid_quantize_requantize_pair(instr_0, instr_1)
595
+ or not self._are_quant_params_compatible(params_0, params_1)
596
+ # To avoid fusion when subgraph inputs connected to the main subgraph
597
+ # (e.g. while_body), we skip all tensors with no producer.
598
+ or producer_op_index == -1
599
+ # Can't apply fusion to tensors with a constrained producer since that
600
+ # will break the constraint.
601
+ or self._is_op_constrained(
602
+ tensor_trans_insts.subgraph_id, producer_op_index
603
+ )
604
+ ):
605
+ return
606
+
607
+ # Fuse the quantize and requantize.
608
+ instr_0.parameters = dataclasses.replace(
609
+ params_0, scale=params_1.scale, zero_point=params_1.zero_point
610
+ )
611
+ # Remove the requantize instruction.
612
+ instructions.pop(1)
613
+
507
614
  def _quant_params_to_transformation_insts(
508
615
  self,
509
616
  param: qtyping.TensorTransformationParams,
@@ -576,6 +683,12 @@ class TransformationInstructionsGenerator:
576
683
  # will raise an error if the instructions are not valid.
577
684
  self._check_tensor_transformation_instructions_valid(tensor_trans_insts)
578
685
 
686
+ # Remove unnecessary [QUANTIZE_TENSOR, ADD_QUANTIZE] pairs for tensors with
687
+ # providers without same as input/output scale constraints.
688
+ self._eliminate_requantization_for_nonconstrained_provider(
689
+ tensor_trans_insts
690
+ )
691
+
579
692
  return tensor_trans_insts
580
693
 
581
694
  def _split_instructions_by_tensor_duplication(
@@ -671,7 +784,6 @@ class TransformationInstructionsGenerator:
671
784
  """
672
785
  is_tensor_unquantized = False
673
786
  is_tensor_quantized = False
674
- is_operator_emulated = False
675
787
  for instruction in instructions:
676
788
  transform_type = instruction.transformation
677
789
  if transform_type == qtyping.QuantTransformation.NO_QUANTIZE:
@@ -681,17 +793,10 @@ class TransformationInstructionsGenerator:
681
793
  or transform_type == qtyping.QuantTransformation.ADD_DEQUANTIZE
682
794
  ):
683
795
  is_tensor_quantized = True
684
- elif transform_type == qtyping.QuantTransformation.EMULATED_SUBCHANNEL:
685
- is_operator_emulated = True
686
796
  if is_tensor_unquantized and is_tensor_quantized:
687
797
  raise ValueError(
688
798
  "Tensor %s can not be both quantized and unquantized" % tensor_name
689
799
  )
690
- if is_operator_emulated and len(instructions) > 1:
691
- raise ValueError(
692
- "Tensor %s : op replacement transformation can not be combined with"
693
- " other transformations." % tensor_name
694
- )
695
800
 
696
801
  def _check_tensor_transformation_instructions_valid(
697
802
  self,
@@ -15,7 +15,9 @@
15
15
 
16
16
  """Tests for instruction_generator."""
17
17
 
18
+ from collections.abc import Sequence
18
19
  import os
20
+ from typing import Optional
19
21
 
20
22
  import numpy as np
21
23
 
@@ -953,33 +955,6 @@ class InstructionGeneratorTest(parameterized.TestCase):
953
955
  instructions["StatefulPartitionedCall:0"], output_transformation
954
956
  )
955
957
 
956
- def test_raise_error_on_op_replacement_transformation_is_not_unique(self):
957
- test_model_path = os.path.join(
958
- TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
959
- )
960
- quant_parameters = {}
961
- quant_parameters["tfl.quantize"] = qtyping.TensorTransformationParams(
962
- "tfl.quantize",
963
- qtyping.OpToTensorParams(
964
- subgraph_op_id=0,
965
- transformations=[
966
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
967
- qtyping.QuantTransformation.EMULATED_SUBCHANNEL,
968
- ],
969
- parameters=qtyping.UniformQuantParams(
970
- 8, None, np.array([1]), np.array([0])
971
- ),
972
- ),
973
- [],
974
- )
975
- ins_gen = instruction_generator.TransformationInstructionsGenerator(
976
- test_model_path
977
- )
978
- with self.assertRaisesRegex(
979
- ValueError, "op replacement transformation can not be combined"
980
- ):
981
- ins_gen.quant_params_to_transformation_insts(quant_parameters)
982
-
983
958
  def test_raise_error_on_no_quant_conflict(self):
984
959
  test_model_path = os.path.join(
985
960
  TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
@@ -1364,5 +1339,166 @@ class InstructionGeneratorTest(parameterized.TestCase):
1364
1339
  )
1365
1340
 
1366
1341
 
1342
+ class EliminateUnnecessaryRequantizationTest(parameterized.TestCase):
1343
+
1344
+ def setUp(self):
1345
+ super().setUp()
1346
+ self.ins_gen = instruction_generator.TransformationInstructionsGenerator(
1347
+ os.path.join(TEST_DATA_PREFIX_PATH, "tests/models/conv_fc_mnist.tflite")
1348
+ )
1349
+
1350
+ def _get_test_instruction(
1351
+ self,
1352
+ transformation: qtyping.QuantTransformation,
1353
+ producer: int = -1,
1354
+ consumers: Optional[Sequence[int]] = None,
1355
+ qparams: Optional[qtyping.UniformQuantParams] = None,
1356
+ ) -> qtyping.TransformationInst:
1357
+ if consumers is None:
1358
+ consumers = []
1359
+ if qparams is None:
1360
+ qparams = qtyping.UniformQuantParams(
1361
+ num_bits=8,
1362
+ quantized_dimension=None,
1363
+ scale=np.array([1]),
1364
+ zero_point=np.array([0]),
1365
+ )
1366
+ return qtyping.TransformationInst(
1367
+ transformation=transformation,
1368
+ producer=producer,
1369
+ consumers=consumers,
1370
+ parameters=qparams,
1371
+ # Dummy values below.
1372
+ tensor_id=0,
1373
+ )
1374
+
1375
+ def _create_test_insts(
1376
+ self, instructions: list[qtyping.TransformationInst]
1377
+ ) -> qtyping.TensorTransformationInsts:
1378
+ return qtyping.TensorTransformationInsts(
1379
+ tensor_name="test_tensor", subgraph_id=0, instructions=instructions
1380
+ )
1381
+
1382
+ def test_no_fusion_when_too_few_instructions(self):
1383
+ tensor_insts = self._create_test_insts([
1384
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
1385
+ ])
1386
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1387
+ tensor_insts
1388
+ )
1389
+ self.assertLen(tensor_insts.instructions, 1)
1390
+
1391
+ def test_no_fusion_when_too_many_instructions(self):
1392
+ tensor_insts = self._create_test_insts([
1393
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
1394
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE),
1395
+ self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
1396
+ ])
1397
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1398
+ tensor_insts
1399
+ )
1400
+ self.assertLen(tensor_insts.instructions, 3)
1401
+
1402
+ def test_no_fusion_when_invalid_transformation_pair(self):
1403
+ tensor_insts = self._create_test_insts([
1404
+ self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
1405
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE),
1406
+ ])
1407
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1408
+ tensor_insts
1409
+ )
1410
+ self.assertLen(tensor_insts.instructions, 2)
1411
+
1412
+ def test_no_fusion_when_consumers_mismatch(self):
1413
+ tensor_insts = self._create_test_insts([
1414
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, consumers=[0]),
1415
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[1]),
1416
+ ])
1417
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1418
+ tensor_insts
1419
+ )
1420
+ self.assertLen(tensor_insts.instructions, 2)
1421
+
1422
+ def test_no_fusion_when_no_producer(self):
1423
+ producer = -1
1424
+ tensor_insts = self._create_test_insts([
1425
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer),
1426
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer),
1427
+ ])
1428
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1429
+ tensor_insts
1430
+ )
1431
+ self.assertLen(tensor_insts.instructions, 2)
1432
+
1433
+ def test_no_fusion_when_quant_params_are_incompatible(self):
1434
+ params_8_bits = qtyping.UniformQuantParams(
1435
+ 8, None, np.array([1]), np.array([0])
1436
+ )
1437
+ params_16_bits = qtyping.UniformQuantParams(
1438
+ 16, None, np.array([1]), np.array([0])
1439
+ )
1440
+ tensor_insts = self._create_test_insts([
1441
+ self._get_test_instruction(
1442
+ _QTransf.QUANTIZE_TENSOR, qparams=params_8_bits
1443
+ ),
1444
+ self._get_test_instruction(
1445
+ _QTransf.ADD_QUANTIZE, qparams=params_16_bits
1446
+ ),
1447
+ ])
1448
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1449
+ tensor_insts
1450
+ )
1451
+ self.assertLen(tensor_insts.instructions, 2)
1452
+
1453
+ def test_no_fusion_when_producer_constrained(self):
1454
+ # Reshape op (op index 2) has same as input scale constraint.
1455
+ tensor_insts = self._create_test_insts([
1456
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer=2),
1457
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer=2),
1458
+ ])
1459
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1460
+ tensor_insts
1461
+ )
1462
+ self.assertLen(tensor_insts.instructions, 2)
1463
+
1464
+ def test_fusion_succeeds(self):
1465
+ producer = 0
1466
+ consumers = [1]
1467
+ params_0 = qtyping.UniformQuantParams(
1468
+ num_bits=8,
1469
+ quantized_dimension=None,
1470
+ scale=np.array([1]),
1471
+ zero_point=np.array([0]),
1472
+ )
1473
+ params_1 = qtyping.UniformQuantParams(
1474
+ num_bits=8,
1475
+ quantized_dimension=None,
1476
+ scale=np.array([2]),
1477
+ zero_point=np.array([1]),
1478
+ )
1479
+ inst_0 = self._get_test_instruction(
1480
+ _QTransf.QUANTIZE_TENSOR, producer, consumers, params_0
1481
+ )
1482
+ inst_1 = self._get_test_instruction(
1483
+ _QTransf.ADD_QUANTIZE, producer, consumers, params_1
1484
+ )
1485
+ tensor_insts = self._create_test_insts([inst_0, inst_1])
1486
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1487
+ tensor_insts
1488
+ )
1489
+
1490
+ self.assertLen(tensor_insts.instructions, 1)
1491
+ result_inst = tensor_insts.instructions[0]
1492
+ self.assertEqual(result_inst.transformation, _QTransf.QUANTIZE_TENSOR)
1493
+
1494
+ result_params = result_inst.parameters
1495
+ # Explicitly narrow the type for pytype.
1496
+ if not isinstance(result_params, qtyping.UniformQuantParams):
1497
+ self.fail("Fused instruction parameters are not UniformQuantParams")
1498
+
1499
+ self.assertEqual(result_params.scale, params_1.scale)
1500
+ self.assertEqual(result_params.zero_point, params_1.zero_point)
1501
+
1502
+
1367
1503
  if __name__ == "__main__":
1368
1504
  googletest.main()
@@ -24,7 +24,8 @@ from ai_edge_quantizer import qtyping
24
24
  from ai_edge_quantizer.transformations import dequant_insert
25
25
  from ai_edge_quantizer.transformations import duplicate_buffer
26
26
  from ai_edge_quantizer.transformations import duplicate_tensor
27
- from ai_edge_quantizer.transformations import emulated_subchannel
27
+ from ai_edge_quantizer.transformations import insert_decomposed_hadamard_rotation
28
+ from ai_edge_quantizer.transformations import insert_hadamard_rotation
28
29
  from ai_edge_quantizer.transformations import quant_insert
29
30
  from ai_edge_quantizer.transformations import quantize_tensor
30
31
  from ai_edge_quantizer.transformations import transformation_utils
@@ -71,7 +72,7 @@ class TransformationPerformer:
71
72
  quantize_tensor.quantize_tensor
72
73
  ),
73
74
  qtyping.QuantTransformation.EMULATED_SUBCHANNEL: (
74
- emulated_subchannel.emulated_subchannel
75
+ transformation_utils.raise_deprecated_error
75
76
  ),
76
77
  qtyping.QuantTransformation.ADD_QUANTIZE: quant_insert.insert_quant,
77
78
  qtyping.QuantTransformation.DUPLICATE_BUFFER: (
@@ -80,6 +81,12 @@ class TransformationPerformer:
80
81
  qtyping.QuantTransformation.DUPLICATE_TENSOR: (
81
82
  duplicate_tensor.duplicate_tensor
82
83
  ),
84
+ qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION: (
85
+ insert_hadamard_rotation.insert_hadamard_rotation
86
+ ),
87
+ qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION: (
88
+ insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation
89
+ ),
83
90
  }
84
91
  # transformations are seprated in two categories:
85
92
  # op_insertion_transformations are transformations that only insert ops
@@ -91,6 +98,8 @@ class TransformationPerformer:
91
98
  qtyping.QuantTransformation.ADD_QUANTIZE,
92
99
  qtyping.QuantTransformation.DUPLICATE_BUFFER,
93
100
  qtyping.QuantTransformation.DUPLICATE_TENSOR,
101
+ qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
102
+ qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION,
94
103
  ])
95
104
  self._op_replacement_transformations = set(
96
105
  [qtyping.QuantTransformation.EMULATED_SUBCHANNEL]
@@ -180,6 +189,38 @@ class TransformationPerformer:
180
189
  )
181
190
  transformation.tensor_id = trans_info.output_tensor_id
182
191
 
192
+ def _get_updated_producer_id(
193
+ self, original_producer_id: int, subgraph_id: int
194
+ ) -> int:
195
+ """Update the producer of a transformation instruction."""
196
+ if original_producer_id is None or original_producer_id < 0:
197
+ producer = -1
198
+ elif original_producer_id < len(self._original_op_id_map[subgraph_id]):
199
+ producer = self._original_op_id_map[subgraph_id][original_producer_id]
200
+ else:
201
+ # If the producer id is not in the original op map, it's an added op,
202
+ # go the added op map to find the producer.
203
+ producer = self._added_op_id_map[subgraph_id][
204
+ original_producer_id - len(self._original_op_id_map[subgraph_id])
205
+ ]
206
+ return producer
207
+
208
+ def _get_updated_consumer_ids(
209
+ self,
210
+ original_consumer_ids: list[int],
211
+ subgraph_id: int,
212
+ ) -> list[int]:
213
+ """Update the consumers of a transformation instruction."""
214
+ consumers = []
215
+ for original_op_id in original_consumer_ids:
216
+ new_consumer_id = (
217
+ -1
218
+ if original_op_id == -1
219
+ else self._original_op_id_map[subgraph_id][original_op_id]
220
+ )
221
+ consumers.append(new_consumer_id)
222
+ return consumers
223
+
183
224
  def _apply_single_transformation(
184
225
  self,
185
226
  transformation_inst: qtyping.TensorTransformationInsts,
@@ -198,28 +239,12 @@ class TransformationPerformer:
198
239
  None, update the transformation_inst & tflite_model in place
199
240
  """
200
241
  instruction = transformation_inst.instructions[transformation_index]
201
- if not instruction.producer or instruction.producer < 0:
202
- producer = -1
203
- elif instruction.producer < len(
204
- self._original_op_id_map[transformation_inst.subgraph_id]
205
- ):
206
- producer = self._original_op_id_map[transformation_inst.subgraph_id][
207
- instruction.producer
208
- ]
209
- else:
210
- # if the producer id is not in the original op map, it's an added op,
211
- # go the corresponding new maps
212
- producer = self._added_op_id_map[transformation_inst.subgraph_id][
213
- instruction.producer
214
- - len(self._original_op_id_map[transformation_inst.subgraph_id])
215
- ]
216
- consumers = []
217
- for original_op_id in instruction.consumers:
218
- consumers.append(
219
- self._original_op_id_map[transformation_inst.subgraph_id][
220
- original_op_id
221
- ]
222
- )
242
+ producer = self._get_updated_producer_id(
243
+ instruction.producer, transformation_inst.subgraph_id
244
+ )
245
+ consumers = self._get_updated_consumer_ids(
246
+ instruction.consumers, transformation_inst.subgraph_id
247
+ )
223
248
  trans_info = self._transformation_registration[instruction.transformation](
224
249
  transformation_utils.TransformationInput(
225
250
  instruction.tensor_id,
@@ -239,7 +264,12 @@ class TransformationPerformer:
239
264
  )
240
265
  self._update_op_id_map(
241
266
  transformation_inst.subgraph_id,
242
- min(instruction.consumers),
267
+ # The added op must be right before the most immediate consumer, unless
268
+ # the consumer is the graph output (id=-1), then use the producer's
269
+ # index instead.
270
+ min(instruction.consumers)
271
+ if min(instruction.consumers) >= 0
272
+ else instruction.producer + 1,
243
273
  trans_info.num_ops_added,
244
274
  )
245
275