ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +224 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
- ai_edge_quantizer/calibrator.py +58 -94
- ai_edge_quantizer/calibrator_test.py +5 -74
- ai_edge_quantizer/default_policy.py +108 -16
- ai_edge_quantizer/model_modifier.py +132 -8
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +38 -10
- ai_edge_quantizer/model_validator_test.py +2 -1
- ai_edge_quantizer/params_generator.py +230 -47
- ai_edge_quantizer/params_generator_test.py +366 -261
- ai_edge_quantizer/qtyping.py +92 -6
- ai_edge_quantizer/quantizer.py +167 -23
- ai_edge_quantizer/quantizer_test.py +288 -26
- ai_edge_quantizer/recipe.py +156 -21
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +313 -46
- ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
- ai_edge_quantizer/transformation_performer.py +112 -58
- ai_edge_quantizer/transformation_performer_test.py +176 -4
- ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
- ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
- ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
- ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
- ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
- ai_edge_quantizer/transformations/transformation_utils.py +157 -11
- ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +191 -58
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/transformation_instruction_generator_test.py
@@ -15,7 +15,9 @@
 
 """Tests for instruction_generator."""
 
+from collections.abc import Sequence
 import os
+from typing import Optional
 
 import numpy as np
 
@@ -27,6 +29,8 @@ from ai_edge_quantizer.utils import test_utils
 
 TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(".")
 
+_QTransf = qtyping.QuantTransformation
+
 
 class InstructionGeneratorTest(parameterized.TestCase):
 
@@ -951,33 +955,6 @@ class InstructionGeneratorTest(parameterized.TestCase):
         instructions["StatefulPartitionedCall:0"], output_transformation
     )
 
-  def test_raise_error_on_op_replacement_transformation_is_not_unique(self):
-    test_model_path = os.path.join(
-        TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
-    )
-    quant_parameters = {}
-    quant_parameters["tfl.quantize"] = qtyping.TensorTransformationParams(
-        "tfl.quantize",
-        qtyping.OpToTensorParams(
-            subgraph_op_id=0,
-            transformations=[
-                qtyping.QuantTransformation.ADD_DEQUANTIZE,
-                qtyping.QuantTransformation.EMULATED_SUBCHANNEL,
-            ],
-            parameters=qtyping.UniformQuantParams(
-                8, None, np.array([1]), np.array([0])
-            ),
-        ),
-        [],
-    )
-    ins_gen = instruction_generator.TransformationInstructionsGenerator(
-        test_model_path
-    )
-    with self.assertRaisesRegex(
-        ValueError, "op replacement transformation can not be combined"
-    ):
-      ins_gen.quant_params_to_transformation_insts(quant_parameters)
-
   def test_raise_error_on_no_quant_conflict(self):
     test_model_path = os.path.join(
         TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
@@ -1077,6 +1054,451 @@ class InstructionGeneratorTest(parameterized.TestCase):
     self.assertLen(instructions, 1)
     self.assertEqual(instructions["tfl.quantize"], expected_instructions)
 
+  def test_instruction_generator_keeps_buffer_duplication_as_first_transformation(
+      self,
+  ):
+    test_tensor_name = "test_tensor"
+
+    dummy_quant_params = qtyping.UniformQuantParams(
+        8, None, np.array([1]), np.array([0])
+    )
+    consumer_params_1 = qtyping.OpToTensorParams(
+        subgraph_op_id=0,
+        transformations=[
+            qtyping.QuantTransformation.DUPLICATE_BUFFER,
+            qtyping.QuantTransformation.ADD_QUANTIZE,
+        ],
+        parameters=dummy_quant_params,
+    )
+    consumer_params_2 = qtyping.OpToTensorParams(
+        subgraph_op_id=2,
+        transformations=[
+            qtyping.QuantTransformation.DUPLICATE_BUFFER,
+            qtyping.QuantTransformation.ADD_QUANTIZE,
+            qtyping.QuantTransformation.ADD_DEQUANTIZE,
+        ],
+        parameters=dummy_quant_params,
+    )
+
+    quant_parameters = {
+        test_tensor_name: qtyping.TensorTransformationParams(
+            tensor_name=test_tensor_name,
+            producer=None,
+            consumers=[consumer_params_1, consumer_params_2],
+        ),
+    }
+    instruction_gen = (
+        instruction_generator.TransformationInstructionsGenerator()
+    )
+    # _tensor_name_to_graph_info has to have an entry for the test tensor for
+    # `quant_params_to_transformation_insts` to work. But the values do not
+    # matter for this test.
+    instruction_gen._tensor_name_to_graph_info[test_tensor_name] = (
+        instruction_generator.TransformationInstructionsGenerator.TensorGraphInfo(
+            tensor_id=1,
+            subgraph_id=0,
+            producer=0,
+            consumers=[2],
+        )
+    )
+    instructions = instruction_gen.quant_params_to_transformation_insts(
+        quant_parameters
+    )
+    self.assertLen(instructions, 1)
+    instructions = instructions[test_tensor_name].instructions
+    self.assertGreater(len(instructions), 1)
+    self.assertEqual(instructions[0].transformation, _QTransf.DUPLICATE_BUFFER)
+    self.assertNotIn(_QTransf.DUPLICATE_BUFFER, instructions[1:])
+
+  def _get_test_instruction(self, transformation, consumers=None):
+    if consumers is None:
+      consumers = []
+    return qtyping.TransformationInst(
+        transformation=transformation,
+        consumers=consumers,
+        # Dummy values below.
+        tensor_id=0,
+        producer=None,
+        parameters=None,
+    )
+
+  def test__remove_last_tensor_duplication_succeeds(self):
+    tensor_instructions = qtyping.TensorTransformationInsts(
+        tensor_name="test_tensor",
+        subgraph_id=0,
+        instructions=[
+            self._get_test_instruction(_QTransf.DUPLICATE_TENSOR),
+            self._get_test_instruction(_QTransf.ADD_QUANTIZE),
+            self._get_test_instruction(_QTransf.DUPLICATE_TENSOR),
+            self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
+        ],
+    )
+    instruction_gen = (
+        instruction_generator.TransformationInstructionsGenerator()
+    )
+    instruction_gen._remove_last_tensor_duplication(tensor_instructions)
+
+    self.assertLen(tensor_instructions.instructions, 3)
+    expected_transformations = [
+        _QTransf.DUPLICATE_TENSOR,
+        _QTransf.ADD_QUANTIZE,
+        _QTransf.ADD_DEQUANTIZE,
+    ]
+    got_transformations = [
+        instruction.transformation
+        for instruction in tensor_instructions.instructions
+    ]
+    self.assertEqual(got_transformations, expected_transformations)
+
+  def test__remove_unnecessary_buffer_duplication_succeeds(
+      self,
+  ):
+    instructions = [
+        self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[1]),
+        self._get_test_instruction(_QTransf.DUPLICATE_BUFFER, consumers=[1]),
+        self._get_test_instruction(_QTransf.ADD_QUANTIZE),
+        self._get_test_instruction(_QTransf.DUPLICATE_BUFFER, consumers=[1]),
+        self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
+        self._get_test_instruction(_QTransf.DUPLICATE_BUFFER, consumers=[2]),
+        self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[3, 4]),
+        self._get_test_instruction(_QTransf.ADD_QUANTIZE),
+        self._get_test_instruction(_QTransf.DUPLICATE_BUFFER, consumers=[3, 4]),
+    ]
+    tensor_instructions = qtyping.TensorTransformationInsts(
+        tensor_name="test_tensor",
+        subgraph_id=0,
+        instructions=instructions,
+    )
+    instruction_gen = (
+        instruction_generator.TransformationInstructionsGenerator()
+    )
+    instruction_gen._remove_unnecessary_buffer_duplication(tensor_instructions)
+
+    self.assertLen(tensor_instructions.instructions, 6)
+    expected_transformations = [
+        _QTransf.DUPLICATE_TENSOR,
+        _QTransf.ADD_QUANTIZE,
+        _QTransf.ADD_DEQUANTIZE,
+        _QTransf.DUPLICATE_BUFFER,
+        _QTransf.DUPLICATE_TENSOR,
+        _QTransf.ADD_QUANTIZE,
+    ]
+    got_transformations = [
+        instruction.transformation
+        for instruction in tensor_instructions.instructions
+    ]
+    self.assertEqual(got_transformations, expected_transformations)
+
+  def test__instruction_generator_removes_unnecessary_tensor_and_buffer_duplication(
+      self,
+  ):
+    test_model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        "tests/models/constant_tensor_and_buffer_only_sharing_weight_fcs.tflite",
+    )
+    params_4_bits = qtyping.UniformQuantParams(
+        4, None, np.array([1]), np.array([0])
+    )
+    params_8_bits = qtyping.UniformQuantParams(
+        8, None, np.array([1]), np.array([0])
+    )
+    quant_parameters = {}
+    # Two FCs share a weight tensor `arith.constant`.
+    quant_parameters["arith.constant"] = qtyping.TensorTransformationParams(
+        tensor_name="arith.constant",
+        producer=None,
+        consumers=[
+            qtyping.OpToTensorParams(
+                subgraph_op_id=0,
+                transformations=[
+                    _QTransf.DUPLICATE_TENSOR,
+                    _QTransf.DUPLICATE_BUFFER,  # Expected to be removed.
+                    _QTransf.QUANTIZE_TENSOR,
+                ],
+                parameters=params_8_bits,
+            ),
+            qtyping.OpToTensorParams(
+                subgraph_op_id=1,
+                transformations=[
+                    _QTransf.DUPLICATE_TENSOR,  # Expected to be removed.
+                    _QTransf.DUPLICATE_BUFFER,
+                    _QTransf.QUANTIZE_TENSOR,
+                ],
+                parameters=params_4_bits,
+            ),
+        ],
+    )
+    instruction_gen = instruction_generator.TransformationInstructionsGenerator(
+        test_model_path
+    )
+    instructions = instruction_gen.quant_params_to_transformation_insts(
+        quant_parameters
+    )
+
+    def get_expected_instruction(transformation, consumers, params):
+      return qtyping.TransformationInst(
+          transformation=transformation,
+          consumers=consumers,
+          tensor_id=1,
+          producer=-1,
+          parameters=params,
+      )
+
+    expected_instructions = qtyping.TensorTransformationInsts(
+        tensor_name="arith.constant",
+        subgraph_id=0,
+        instructions=[
+            get_expected_instruction(
+                _QTransf.DUPLICATE_TENSOR, consumers=[0], params=params_8_bits
+            ),
+            get_expected_instruction(
+                _QTransf.DUPLICATE_BUFFER, consumers=[1], params=params_4_bits
+            ),
+            get_expected_instruction(
+                _QTransf.QUANTIZE_TENSOR, consumers=[0], params=params_8_bits
+            ),
+            get_expected_instruction(
+                _QTransf.QUANTIZE_TENSOR, consumers=[1], params=params_4_bits
+            ),
+        ],
+    )
+    self.assertLen(instructions, 1)
+    self.assertEqual(instructions["arith.constant"], expected_instructions)
+
+  def test__split_instructions_by_tensor_duplication_returns_expected_subsets(
+      self,
+  ):
+    instructions = [
+        self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[1, 2, 3]),  # pylint: disable=line-too-long
+        self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[4]),
+        self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[1, 2]),
+        self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[3]),
+        self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[4]),
+        self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[5]),
+    ]
+    tensor_instructions = qtyping.TensorTransformationInsts(
+        tensor_name="test_tensor", subgraph_id=0, instructions=instructions
+    )
+    instruction_gen = (
+        instruction_generator.TransformationInstructionsGenerator()
+    )
+    got = instruction_gen._split_instructions_by_tensor_duplication(
+        tensor_instructions
+    )
+    expected = [
+        [self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[5])],
+        [
+            self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[1, 2, 3]),  # pylint: disable=line-too-long
+            self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[1, 2]),
+            self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[3]),
+        ],
+        [
+            self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[4]),  # pylint: disable=line-too-long
+            self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[4]),
+        ],
+    ]
+    self.assertEqual(got, expected)
+
+  def test__check_tensor_transformation_instructions_valid_succeeds_on_q_dq_with_duplication(
+      self,
+  ):
+    instructions = [
+        self._get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[1]),
+        self._get_test_instruction(_QTransf.NO_QUANTIZE, consumers=[1]),
+        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, consumers=[2]),
+    ]
+    tensor_instructions = qtyping.TensorTransformationInsts(
+        tensor_name="test_tensor", subgraph_id=0, instructions=instructions
+    )
+    instruction_gen = (
+        instruction_generator.TransformationInstructionsGenerator()
+    )
+    instruction_gen._check_tensor_transformation_instructions_valid(
+        tensor_instructions
+    )
+
+  def test__check_tensor_transformation_instructions_valid_fails_when_q_noq_wo_duplication(
+      self,
+  ):
+    tensor_instructions = qtyping.TensorTransformationInsts(
+        tensor_name="test_tensor",
+        subgraph_id=0,
+        instructions=[
+            self._get_test_instruction(_QTransf.NO_QUANTIZE, consumers=[1]),
+            self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, consumers=[2]),
+        ],
+    )
+    instruction_gen = (
+        instruction_generator.TransformationInstructionsGenerator()
+    )
+    with self.assertRaisesRegex(
+        ValueError, "can not be both quantized and unquantized"
+    ):
+      instruction_gen._check_tensor_transformation_instructions_valid(
+          tensor_instructions
+      )
+
+
+class EliminateUnnecessaryRequantizationTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.ins_gen = instruction_generator.TransformationInstructionsGenerator(
+        os.path.join(TEST_DATA_PREFIX_PATH, "tests/models/conv_fc_mnist.tflite")
+    )
+
+  def _get_test_instruction(
+      self,
+      transformation: qtyping.QuantTransformation,
+      producer: int = -1,
+      consumers: Optional[Sequence[int]] = None,
+      qparams: Optional[qtyping.UniformQuantParams] = None,
+  ) -> qtyping.TransformationInst:
+    if consumers is None:
+      consumers = []
+    if qparams is None:
+      qparams = qtyping.UniformQuantParams(
+          num_bits=8,
+          quantized_dimension=None,
+          scale=np.array([1]),
+          zero_point=np.array([0]),
+      )
+    return qtyping.TransformationInst(
+        transformation=transformation,
+        producer=producer,
+        consumers=consumers,
+        parameters=qparams,
+        # Dummy values below.
+        tensor_id=0,
+    )
+
+  def _create_test_insts(
+      self, instructions: list[qtyping.TransformationInst]
+  ) -> qtyping.TensorTransformationInsts:
+    return qtyping.TensorTransformationInsts(
+        tensor_name="test_tensor", subgraph_id=0, instructions=instructions
+    )
+
+  def test_no_fusion_when_too_few_instructions(self):
+    tensor_insts = self._create_test_insts([
+        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
+    ])
+    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
+        tensor_insts
+    )
+    self.assertLen(tensor_insts.instructions, 1)
+
+  def test_no_fusion_when_too_many_instructions(self):
+    tensor_insts = self._create_test_insts([
+        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
+        self._get_test_instruction(_QTransf.ADD_QUANTIZE),
+        self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
+    ])
+    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
+        tensor_insts
+    )
+    self.assertLen(tensor_insts.instructions, 3)
+
+  def test_no_fusion_when_invalid_transformation_pair(self):
+    tensor_insts = self._create_test_insts([
+        self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
+        self._get_test_instruction(_QTransf.ADD_QUANTIZE),
+    ])
+    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
+        tensor_insts
+    )
+    self.assertLen(tensor_insts.instructions, 2)
+
+  def test_no_fusion_when_consumers_mismatch(self):
+    tensor_insts = self._create_test_insts([
+        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, consumers=[0]),
+        self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[1]),
+    ])
+    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
+        tensor_insts
+    )
+    self.assertLen(tensor_insts.instructions, 2)
+
+  def test_no_fusion_when_no_producer(self):
+    producer = -1
+    tensor_insts = self._create_test_insts([
+        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer),
+        self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer),
+    ])
+    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
+        tensor_insts
+    )
+    self.assertLen(tensor_insts.instructions, 2)
+
+  def test_no_fusion_when_quant_params_are_incompatible(self):
+    params_8_bits = qtyping.UniformQuantParams(
+        8, None, np.array([1]), np.array([0])
+    )
+    params_16_bits = qtyping.UniformQuantParams(
+        16, None, np.array([1]), np.array([0])
+    )
+    tensor_insts = self._create_test_insts([
+        self._get_test_instruction(
+            _QTransf.QUANTIZE_TENSOR, qparams=params_8_bits
+        ),
+        self._get_test_instruction(
+            _QTransf.ADD_QUANTIZE, qparams=params_16_bits
+        ),
+    ])
+    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
+        tensor_insts
+    )
+    self.assertLen(tensor_insts.instructions, 2)
+
+  def test_no_fusion_when_producer_constrained(self):
+    # Reshape op (op index 2) has same as input scale constraint.
+    tensor_insts = self._create_test_insts([
+        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer=2),
+        self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer=2),
+    ])
+    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
+        tensor_insts
+    )
+    self.assertLen(tensor_insts.instructions, 2)
+
+  def test_fusion_succeeds(self):
+    producer = 0
+    consumers = [1]
+    params_0 = qtyping.UniformQuantParams(
+        num_bits=8,
+        quantized_dimension=None,
+        scale=np.array([1]),
+        zero_point=np.array([0]),
+    )
+    params_1 = qtyping.UniformQuantParams(
+        num_bits=8,
+        quantized_dimension=None,
+        scale=np.array([2]),
+        zero_point=np.array([1]),
+    )
+    inst_0 = self._get_test_instruction(
+        _QTransf.QUANTIZE_TENSOR, producer, consumers, params_0
+    )
+    inst_1 = self._get_test_instruction(
+        _QTransf.ADD_QUANTIZE, producer, consumers, params_1
+    )
+    tensor_insts = self._create_test_insts([inst_0, inst_1])
+    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
+        tensor_insts
+    )
+
+    self.assertLen(tensor_insts.instructions, 1)
+    result_inst = tensor_insts.instructions[0]
+    self.assertEqual(result_inst.transformation, _QTransf.QUANTIZE_TENSOR)
+
+    result_params = result_inst.parameters
+    # Explicitly narrow the type for pytype.
+    if not isinstance(result_params, qtyping.UniformQuantParams):
+      self.fail("Fused instruction parameters are not UniformQuantParams")
+
+    self.assertEqual(result_params.scale, params_1.scale)
+    self.assertEqual(result_params.zero_point, params_1.zero_point)
+
 
 if __name__ == "__main__":
   googletest.main()