ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,9 @@
15
15
 
16
16
  """Tests for instruction_generator."""
17
17
 
18
+ from collections.abc import Sequence
18
19
  import os
20
+ from typing import Optional
19
21
 
20
22
  import numpy as np
21
23
 
@@ -27,6 +29,8 @@ from ai_edge_quantizer.utils import test_utils
27
29
 
28
30
  TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(".")
29
31
 
32
+ _QTransf = qtyping.QuantTransformation
33
+
30
34
 
31
35
  class InstructionGeneratorTest(parameterized.TestCase):
32
36
 
@@ -951,33 +955,6 @@ class InstructionGeneratorTest(parameterized.TestCase):
951
955
  instructions["StatefulPartitionedCall:0"], output_transformation
952
956
  )
953
957
 
954
- def test_raise_error_on_op_replacement_transformation_is_not_unique(self):
955
- test_model_path = os.path.join(
956
- TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
957
- )
958
- quant_parameters = {}
959
- quant_parameters["tfl.quantize"] = qtyping.TensorTransformationParams(
960
- "tfl.quantize",
961
- qtyping.OpToTensorParams(
962
- subgraph_op_id=0,
963
- transformations=[
964
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
965
- qtyping.QuantTransformation.EMULATED_SUBCHANNEL,
966
- ],
967
- parameters=qtyping.UniformQuantParams(
968
- 8, None, np.array([1]), np.array([0])
969
- ),
970
- ),
971
- [],
972
- )
973
- ins_gen = instruction_generator.TransformationInstructionsGenerator(
974
- test_model_path
975
- )
976
- with self.assertRaisesRegex(
977
- ValueError, "op replacement transformation can not be combined"
978
- ):
979
- ins_gen.quant_params_to_transformation_insts(quant_parameters)
980
-
981
958
  def test_raise_error_on_no_quant_conflict(self):
982
959
  test_model_path = os.path.join(
983
960
  TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
@@ -1077,6 +1054,451 @@ class InstructionGeneratorTest(parameterized.TestCase):
1077
1054
  self.assertLen(instructions, 1)
1078
1055
  self.assertEqual(instructions["tfl.quantize"], expected_instructions)
1079
1056
 
1057
+ def test_instruction_generator_keeps_buffer_duplication_as_first_transformation(
1058
+ self,
1059
+ ):
1060
+ test_tensor_name = "test_tensor"
1061
+
1062
+ dummy_quant_params = qtyping.UniformQuantParams(
1063
+ 8, None, np.array([1]), np.array([0])
1064
+ )
1065
+ consumer_params_1 = qtyping.OpToTensorParams(
1066
+ subgraph_op_id=0,
1067
+ transformations=[
1068
+ qtyping.QuantTransformation.DUPLICATE_BUFFER,
1069
+ qtyping.QuantTransformation.ADD_QUANTIZE,
1070
+ ],
1071
+ parameters=dummy_quant_params,
1072
+ )
1073
+ consumer_params_2 = qtyping.OpToTensorParams(
1074
+ subgraph_op_id=2,
1075
+ transformations=[
1076
+ qtyping.QuantTransformation.DUPLICATE_BUFFER,
1077
+ qtyping.QuantTransformation.ADD_QUANTIZE,
1078
+ qtyping.QuantTransformation.ADD_DEQUANTIZE,
1079
+ ],
1080
+ parameters=dummy_quant_params,
1081
+ )
1082
+
1083
+ quant_parameters = {
1084
+ test_tensor_name: qtyping.TensorTransformationParams(
1085
+ tensor_name=test_tensor_name,
1086
+ producer=None,
1087
+ consumers=[consumer_params_1, consumer_params_2],
1088
+ ),
1089
+ }
1090
+ instruction_gen = (
1091
+ instruction_generator.TransformationInstructionsGenerator()
1092
+ )
1093
+ # _tensor_name_to_graph_info has to have an entry for the test tensor for
1094
+ # `quant_params_to_transformation_insts` to work. But the values do not
1095
+ # matter for this test.
1096
+ instruction_gen._tensor_name_to_graph_info[test_tensor_name] = (
1097
+ instruction_generator.TransformationInstructionsGenerator.TensorGraphInfo(
1098
+ tensor_id=1,
1099
+ subgraph_id=0,
1100
+ producer=0,
1101
+ consumers=[2],
1102
+ )
1103
+ )
1104
+ instructions = instruction_gen.quant_params_to_transformation_insts(
1105
+ quant_parameters
1106
+ )
1107
+ self.assertLen(instructions, 1)
1108
+ instructions = instructions[test_tensor_name].instructions
1109
+ self.assertGreater(len(instructions), 1)
1110
+ self.assertEqual(instructions[0].transformation, _QTransf.DUPLICATE_BUFFER)
1111
+ self.assertNotIn(_QTransf.DUPLICATE_BUFFER, instructions[1:])
1112
+
1113
+ def _get_test_instruction(self, transformation, consumers=None):
1114
+ if consumers is None:
1115
+ consumers = []
1116
+ return qtyping.TransformationInst(
1117
+ transformation=transformation,
1118
+ consumers=consumers,
1119
+ # Dummy values below.
1120
+ tensor_id=0,
1121
+ producer=None,
1122
+ parameters=None,
1123
+ )
1124
+
1125
+ def test__remove_last_tensor_duplication_succeeds(self):
1126
+ tensor_instructions = qtyping.TensorTransformationInsts(
1127
+ tensor_name="test_tensor",
1128
+ subgraph_id=0,
1129
+ instructions=[
1130
+ self._get_test_instruction(_QTransf.DUPLICATE_TENSOR),
1131
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE),
1132
+ self._get_test_instruction(_QTransf.DUPLICATE_TENSOR),
1133
+ self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
1134
+ ],
1135
+ )
1136
+ instruction_gen = (
1137
+ instruction_generator.TransformationInstructionsGenerator()
1138
+ )
1139
+ instruction_gen._remove_last_tensor_duplication(tensor_instructions)
1140
+
1141
+ self.assertLen(tensor_instructions.instructions, 3)
1142
+ expected_transformations = [
1143
+ _QTransf.DUPLICATE_TENSOR,
1144
+ _QTransf.ADD_QUANTIZE,
1145
+ _QTransf.ADD_DEQUANTIZE,
1146
+ ]
1147
+ got_transformations = [
1148
+ instruction.transformation
1149
+ for instruction in tensor_instructions.instructions
1150
+ ]
1151
+ self.assertEqual(got_transformations, expected_transformations)
1152
+
1153
+ def test__remove_unnecessary_buffer_duplication_succeeds(
1154
+ self,
1155
+ ):
1156
+ instructions = [
1157
+ self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[1]),
1158
+ self._get_test_instruction(_QTransf.DUPLICATE_BUFFER, consumers=[1]),
1159
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE),
1160
+ self._get_test_instruction(_QTransf.DUPLICATE_BUFFER, consumers=[1]),
1161
+ self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
1162
+ self._get_test_instruction(_QTransf.DUPLICATE_BUFFER, consumers=[2]),
1163
+ self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[3, 4]),
1164
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE),
1165
+ self._get_test_instruction(_QTransf.DUPLICATE_BUFFER, consumers=[3, 4]),
1166
+ ]
1167
+ tensor_instructions = qtyping.TensorTransformationInsts(
1168
+ tensor_name="test_tensor",
1169
+ subgraph_id=0,
1170
+ instructions=instructions,
1171
+ )
1172
+ instruction_gen = (
1173
+ instruction_generator.TransformationInstructionsGenerator()
1174
+ )
1175
+ instruction_gen._remove_unnecessary_buffer_duplication(tensor_instructions)
1176
+
1177
+ self.assertLen(tensor_instructions.instructions, 6)
1178
+ expected_transformations = [
1179
+ _QTransf.DUPLICATE_TENSOR,
1180
+ _QTransf.ADD_QUANTIZE,
1181
+ _QTransf.ADD_DEQUANTIZE,
1182
+ _QTransf.DUPLICATE_BUFFER,
1183
+ _QTransf.DUPLICATE_TENSOR,
1184
+ _QTransf.ADD_QUANTIZE,
1185
+ ]
1186
+ got_transformations = [
1187
+ instruction.transformation
1188
+ for instruction in tensor_instructions.instructions
1189
+ ]
1190
+ self.assertEqual(got_transformations, expected_transformations)
1191
+
1192
+ def test__instruction_generator_removes_unnecessary_tensor_and_buffer_duplication(
1193
+ self,
1194
+ ):
1195
+ test_model_path = os.path.join(
1196
+ TEST_DATA_PREFIX_PATH,
1197
+ "tests/models/constant_tensor_and_buffer_only_sharing_weight_fcs.tflite",
1198
+ )
1199
+ params_4_bits = qtyping.UniformQuantParams(
1200
+ 4, None, np.array([1]), np.array([0])
1201
+ )
1202
+ params_8_bits = qtyping.UniformQuantParams(
1203
+ 8, None, np.array([1]), np.array([0])
1204
+ )
1205
+ quant_parameters = {}
1206
+ # Two FCs share a weight tensor `arith.constant`.
1207
+ quant_parameters["arith.constant"] = qtyping.TensorTransformationParams(
1208
+ tensor_name="arith.constant",
1209
+ producer=None,
1210
+ consumers=[
1211
+ qtyping.OpToTensorParams(
1212
+ subgraph_op_id=0,
1213
+ transformations=[
1214
+ _QTransf.DUPLICATE_TENSOR,
1215
+ _QTransf.DUPLICATE_BUFFER, # Expected to be removed.
1216
+ _QTransf.QUANTIZE_TENSOR,
1217
+ ],
1218
+ parameters=params_8_bits,
1219
+ ),
1220
+ qtyping.OpToTensorParams(
1221
+ subgraph_op_id=1,
1222
+ transformations=[
1223
+ _QTransf.DUPLICATE_TENSOR, # Expected to be removed.
1224
+ _QTransf.DUPLICATE_BUFFER,
1225
+ _QTransf.QUANTIZE_TENSOR,
1226
+ ],
1227
+ parameters=params_4_bits,
1228
+ ),
1229
+ ],
1230
+ )
1231
+ instruction_gen = instruction_generator.TransformationInstructionsGenerator(
1232
+ test_model_path
1233
+ )
1234
+ instructions = instruction_gen.quant_params_to_transformation_insts(
1235
+ quant_parameters
1236
+ )
1237
+
1238
+ def get_expected_instruction(transformation, consumers, params):
1239
+ return qtyping.TransformationInst(
1240
+ transformation=transformation,
1241
+ consumers=consumers,
1242
+ tensor_id=1,
1243
+ producer=-1,
1244
+ parameters=params,
1245
+ )
1246
+
1247
+ expected_instructions = qtyping.TensorTransformationInsts(
1248
+ tensor_name="arith.constant",
1249
+ subgraph_id=0,
1250
+ instructions=[
1251
+ get_expected_instruction(
1252
+ _QTransf.DUPLICATE_TENSOR, consumers=[0], params=params_8_bits
1253
+ ),
1254
+ get_expected_instruction(
1255
+ _QTransf.DUPLICATE_BUFFER, consumers=[1], params=params_4_bits
1256
+ ),
1257
+ get_expected_instruction(
1258
+ _QTransf.QUANTIZE_TENSOR, consumers=[0], params=params_8_bits
1259
+ ),
1260
+ get_expected_instruction(
1261
+ _QTransf.QUANTIZE_TENSOR, consumers=[1], params=params_4_bits
1262
+ ),
1263
+ ],
1264
+ )
1265
+ self.assertLen(instructions, 1)
1266
+ self.assertEqual(instructions["arith.constant"], expected_instructions)
1267
+
1268
+ def test__split_instructions_by_tensor_duplication_returns_expected_subsets(
1269
+ self,
1270
+ ):
1271
+ instructions = [
1272
+ self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[1, 2, 3]), # pylint: disable=line-too-long
1273
+ self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[4]),
1274
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[1, 2]),
1275
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[3]),
1276
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[4]),
1277
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[5]),
1278
+ ]
1279
+ tensor_instructions = qtyping.TensorTransformationInsts(
1280
+ tensor_name="test_tensor", subgraph_id=0, instructions=instructions
1281
+ )
1282
+ instruction_gen = (
1283
+ instruction_generator.TransformationInstructionsGenerator()
1284
+ )
1285
+ got = instruction_gen._split_instructions_by_tensor_duplication(
1286
+ tensor_instructions
1287
+ )
1288
+ expected = [
1289
+ [self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[5])],
1290
+ [
1291
+ self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[1, 2, 3]), # pylint: disable=line-too-long
1292
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[1, 2]),
1293
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[3]),
1294
+ ],
1295
+ [
1296
+ self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[4]), # pylint: disable=line-too-long
1297
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[4]),
1298
+ ],
1299
+ ]
1300
+ self.assertEqual(got, expected)
1301
+
1302
+ def test__check_tensor_transformation_instructions_valid_succeeds_on_q_dq_with_duplication(
1303
+ self,
1304
+ ):
1305
+ instructions = [
1306
+ self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[1]),
1307
+ self._get_test_instruction(_QTransf.NO_QUANTIZE, consumers=[1]),
1308
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, consumers=[2]),
1309
+ ]
1310
+ tensor_instructions = qtyping.TensorTransformationInsts(
1311
+ tensor_name="test_tensor", subgraph_id=0, instructions=instructions
1312
+ )
1313
+ instruction_gen = (
1314
+ instruction_generator.TransformationInstructionsGenerator()
1315
+ )
1316
+ instruction_gen._check_tensor_transformation_instructions_valid(
1317
+ tensor_instructions
1318
+ )
1319
+
1320
+ def test__check_tensor_transformation_instructions_valid_fails_when_q_noq_wo_duplication(
1321
+ self,
1322
+ ):
1323
+ tensor_instructions = qtyping.TensorTransformationInsts(
1324
+ tensor_name="test_tensor",
1325
+ subgraph_id=0,
1326
+ instructions=[
1327
+ self._get_test_instruction(_QTransf.NO_QUANTIZE, consumers=[1]),
1328
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, consumers=[2]),
1329
+ ],
1330
+ )
1331
+ instruction_gen = (
1332
+ instruction_generator.TransformationInstructionsGenerator()
1333
+ )
1334
+ with self.assertRaisesRegex(
1335
+ ValueError, "can not be both quantized and unquantized"
1336
+ ):
1337
+ instruction_gen._check_tensor_transformation_instructions_valid(
1338
+ tensor_instructions
1339
+ )
1340
+
1341
+
1342
+ class EliminateUnnecessaryRequantizationTest(parameterized.TestCase):
1343
+
1344
+ def setUp(self):
1345
+ super().setUp()
1346
+ self.ins_gen = instruction_generator.TransformationInstructionsGenerator(
1347
+ os.path.join(TEST_DATA_PREFIX_PATH, "tests/models/conv_fc_mnist.tflite")
1348
+ )
1349
+
1350
+ def _get_test_instruction(
1351
+ self,
1352
+ transformation: qtyping.QuantTransformation,
1353
+ producer: int = -1,
1354
+ consumers: Optional[Sequence[int]] = None,
1355
+ qparams: Optional[qtyping.UniformQuantParams] = None,
1356
+ ) -> qtyping.TransformationInst:
1357
+ if consumers is None:
1358
+ consumers = []
1359
+ if qparams is None:
1360
+ qparams = qtyping.UniformQuantParams(
1361
+ num_bits=8,
1362
+ quantized_dimension=None,
1363
+ scale=np.array([1]),
1364
+ zero_point=np.array([0]),
1365
+ )
1366
+ return qtyping.TransformationInst(
1367
+ transformation=transformation,
1368
+ producer=producer,
1369
+ consumers=consumers,
1370
+ parameters=qparams,
1371
+ # Dummy values below.
1372
+ tensor_id=0,
1373
+ )
1374
+
1375
+ def _create_test_insts(
1376
+ self, instructions: list[qtyping.TransformationInst]
1377
+ ) -> qtyping.TensorTransformationInsts:
1378
+ return qtyping.TensorTransformationInsts(
1379
+ tensor_name="test_tensor", subgraph_id=0, instructions=instructions
1380
+ )
1381
+
1382
+ def test_no_fusion_when_too_few_instructions(self):
1383
+ tensor_insts = self._create_test_insts([
1384
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
1385
+ ])
1386
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1387
+ tensor_insts
1388
+ )
1389
+ self.assertLen(tensor_insts.instructions, 1)
1390
+
1391
+ def test_no_fusion_when_too_many_instructions(self):
1392
+ tensor_insts = self._create_test_insts([
1393
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
1394
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE),
1395
+ self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
1396
+ ])
1397
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1398
+ tensor_insts
1399
+ )
1400
+ self.assertLen(tensor_insts.instructions, 3)
1401
+
1402
+ def test_no_fusion_when_invalid_transformation_pair(self):
1403
+ tensor_insts = self._create_test_insts([
1404
+ self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
1405
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE),
1406
+ ])
1407
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1408
+ tensor_insts
1409
+ )
1410
+ self.assertLen(tensor_insts.instructions, 2)
1411
+
1412
+ def test_no_fusion_when_consumers_mismatch(self):
1413
+ tensor_insts = self._create_test_insts([
1414
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, consumers=[0]),
1415
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[1]),
1416
+ ])
1417
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1418
+ tensor_insts
1419
+ )
1420
+ self.assertLen(tensor_insts.instructions, 2)
1421
+
1422
+ def test_no_fusion_when_no_producer(self):
1423
+ producer = -1
1424
+ tensor_insts = self._create_test_insts([
1425
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer),
1426
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer),
1427
+ ])
1428
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1429
+ tensor_insts
1430
+ )
1431
+ self.assertLen(tensor_insts.instructions, 2)
1432
+
1433
+ def test_no_fusion_when_quant_params_are_incompatible(self):
1434
+ params_8_bits = qtyping.UniformQuantParams(
1435
+ 8, None, np.array([1]), np.array([0])
1436
+ )
1437
+ params_16_bits = qtyping.UniformQuantParams(
1438
+ 16, None, np.array([1]), np.array([0])
1439
+ )
1440
+ tensor_insts = self._create_test_insts([
1441
+ self._get_test_instruction(
1442
+ _QTransf.QUANTIZE_TENSOR, qparams=params_8_bits
1443
+ ),
1444
+ self._get_test_instruction(
1445
+ _QTransf.ADD_QUANTIZE, qparams=params_16_bits
1446
+ ),
1447
+ ])
1448
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1449
+ tensor_insts
1450
+ )
1451
+ self.assertLen(tensor_insts.instructions, 2)
1452
+
1453
+ def test_no_fusion_when_producer_constrained(self):
1454
+ # Reshape op (op index 2) has same as input scale constraint.
1455
+ tensor_insts = self._create_test_insts([
1456
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer=2),
1457
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer=2),
1458
+ ])
1459
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1460
+ tensor_insts
1461
+ )
1462
+ self.assertLen(tensor_insts.instructions, 2)
1463
+
1464
+ def test_fusion_succeeds(self):
1465
+ producer = 0
1466
+ consumers = [1]
1467
+ params_0 = qtyping.UniformQuantParams(
1468
+ num_bits=8,
1469
+ quantized_dimension=None,
1470
+ scale=np.array([1]),
1471
+ zero_point=np.array([0]),
1472
+ )
1473
+ params_1 = qtyping.UniformQuantParams(
1474
+ num_bits=8,
1475
+ quantized_dimension=None,
1476
+ scale=np.array([2]),
1477
+ zero_point=np.array([1]),
1478
+ )
1479
+ inst_0 = self._get_test_instruction(
1480
+ _QTransf.QUANTIZE_TENSOR, producer, consumers, params_0
1481
+ )
1482
+ inst_1 = self._get_test_instruction(
1483
+ _QTransf.ADD_QUANTIZE, producer, consumers, params_1
1484
+ )
1485
+ tensor_insts = self._create_test_insts([inst_0, inst_1])
1486
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1487
+ tensor_insts
1488
+ )
1489
+
1490
+ self.assertLen(tensor_insts.instructions, 1)
1491
+ result_inst = tensor_insts.instructions[0]
1492
+ self.assertEqual(result_inst.transformation, _QTransf.QUANTIZE_TENSOR)
1493
+
1494
+ result_params = result_inst.parameters
1495
+ # Explicitly narrow the type for pytype.
1496
+ if not isinstance(result_params, qtyping.UniformQuantParams):
1497
+ self.fail("Fused instruction parameters are not UniformQuantParams")
1498
+
1499
+ self.assertEqual(result_params.scale, params_1.scale)
1500
+ self.assertEqual(result_params.zero_point, params_1.zero_point)
1501
+
1080
1502
 
1081
1503
  if __name__ == "__main__":
1082
1504
  googletest.main()