ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +224 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
- ai_edge_quantizer/calibrator.py +58 -94
- ai_edge_quantizer/calibrator_test.py +5 -74
- ai_edge_quantizer/default_policy.py +108 -16
- ai_edge_quantizer/model_modifier.py +132 -8
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +38 -10
- ai_edge_quantizer/model_validator_test.py +2 -1
- ai_edge_quantizer/params_generator.py +230 -47
- ai_edge_quantizer/params_generator_test.py +366 -261
- ai_edge_quantizer/qtyping.py +92 -6
- ai_edge_quantizer/quantizer.py +167 -23
- ai_edge_quantizer/quantizer_test.py +288 -26
- ai_edge_quantizer/recipe.py +156 -21
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +313 -46
- ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
- ai_edge_quantizer/transformation_performer.py +112 -58
- ai_edge_quantizer/transformation_performer_test.py +176 -4
- ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
- ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
- ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
- ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
- ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
- ai_edge_quantizer/transformations/transformation_utils.py +157 -11
- ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +191 -58
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
--- ai_edge_quantizer/transformation_performer.py
+++ ai_edge_quantizer/transformation_performer.py
@@ -15,10 +15,17 @@
 
 """Python manager for transformations to be applied to TFlite models."""
 
+from collections.abc import Sequence
+from typing import Optional
+
 import numpy as np
+
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer.transformations import dequant_insert
-from ai_edge_quantizer.transformations import emulated_subchannel
+from ai_edge_quantizer.transformations import duplicate_buffer
+from ai_edge_quantizer.transformations import duplicate_tensor
+from ai_edge_quantizer.transformations import insert_decomposed_hadamard_rotation
+from ai_edge_quantizer.transformations import insert_hadamard_rotation
 from ai_edge_quantizer.transformations import quant_insert
 from ai_edge_quantizer.transformations import quantize_tensor
 from ai_edge_quantizer.transformations import transformation_utils
@@ -65,9 +72,21 @@ class TransformationPerformer:
             quantize_tensor.quantize_tensor
         ),
         qtyping.QuantTransformation.EMULATED_SUBCHANNEL: (
-            emulated_subchannel.emulated_subchannel
+            transformation_utils.raise_deprecated_error
         ),
         qtyping.QuantTransformation.ADD_QUANTIZE: quant_insert.insert_quant,
+        qtyping.QuantTransformation.DUPLICATE_BUFFER: (
+            duplicate_buffer.duplicate_buffer
+        ),
+        qtyping.QuantTransformation.DUPLICATE_TENSOR: (
+            duplicate_tensor.duplicate_tensor
+        ),
+        qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION: (
+            insert_hadamard_rotation.insert_hadamard_rotation
+        ),
+        qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION: (
+            insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation
+        ),
     }
     # transformations are seprated in two categories:
     # op_insertion_transformations are transformations that only insert ops
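Every value in this registration dict follows one call contract, which is what lets the four new transformations slot in without any change to the dispatch code. A minimal sketch of that contract, inferred from this diff (the no-op function below is hypothetical, not a real registry entry):

    # Sketch of the shared transformation interface; `noop_transformation`
    # is hypothetical and not part of the package.
    from ai_edge_quantizer import qtyping
    from ai_edge_quantizer.transformations import transformation_utils


    def noop_transformation(
        transformation_input: transformation_utils.TransformationInput,
    ) -> qtyping.TransformationInfo:
      """Leaves the graph untouched and only reports bookkeeping info."""
      return qtyping.TransformationInfo(
          op_id=0,  # Position where ops were inserted (none here).
          num_ops_added=0,  # The performer uses this to shift later op ids.
          output_tensor_id=transformation_input.tensor_id,
      )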
@@ -77,6 +96,10 @@ class TransformationPerformer:
         qtyping.QuantTransformation.ADD_DEQUANTIZE,
         qtyping.QuantTransformation.QUANTIZE_TENSOR,
         qtyping.QuantTransformation.ADD_QUANTIZE,
+        qtyping.QuantTransformation.DUPLICATE_BUFFER,
+        qtyping.QuantTransformation.DUPLICATE_TENSOR,
+        qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
+        qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION,
     ])
     self._op_replacement_transformations = set(
         [qtyping.QuantTransformation.EMULATED_SUBCHANNEL]
@@ -123,44 +146,81 @@ class TransformationPerformer:
       transformations: list[qtyping.TransformationInst],
       subgraph_id: int,
       trans_info: qtyping.TransformationInfo,
-  ):
-    """Update the instructions after the graph is modified.
+  ) -> None:
+    """Update the instructions in-place after the graph is modified.
 
-    After an op is inserted, the topology is changed and this may impact the
-    following transformation to be applied. So we need to update instructions
-    that have yet to be applied.
+    After an op is inserted or a tensor is duplicated, the topology is changed
+    and this may impact the following transformation to be applied. So we need
+    to update instructions that have yet to be applied.
 
     Args:
-      prev_transformation_index:
-      transformations:
-      subgraph_id:
-      trans_info:
-
-    Returns:
-      None, modifies the transformation in place
+      prev_transformation_index: The index of the last applied transformation.
+      transformations: The list of transformations we're applying.
+      subgraph_id: The subgraph where the provided instructions belong to.
+      trans_info: Transformation info returned by a transformation.
     """
-    # if no ops were added, then no need for update
-    if trans_info.num_ops_added == 0:
-      return
     prev_transformation = transformations[prev_transformation_index]
-    self._added_op_id_map[subgraph_id].append(
-        trans_info.op_id + trans_info.num_ops_added - 1
+    is_prev_not_duplicate_tensor = (
+        prev_transformation.transformation
+        != qtyping.QuantTransformation.DUPLICATE_TENSOR
     )
+    was_op_added = trans_info.num_ops_added > 0
+    if not was_op_added and is_prev_not_duplicate_tensor:
+      return
+
+    if was_op_added:
+      self._added_op_id_map[subgraph_id].append(
+          trans_info.op_id + trans_info.num_ops_added - 1
+      )
+
     for transformations_index in range(
         prev_transformation_index + 1, len(transformations)
     ):
       transformation = transformations[transformations_index]
       for consumer_index in transformation.consumers:
-        #
+        # If the consumer needs to use newly added ops, then the new added op
         # index needs to be outside of the range of the orignal op ids.
         if consumer_index in prev_transformation.consumers:
-          transformation.producer = (
-              len(self._original_op_id_map[subgraph_id])
-              + len(self._added_op_id_map[subgraph_id])
-              - 1
-          )
+          if was_op_added:
+            transformation.producer = (
+                len(self._original_op_id_map[subgraph_id])
+                + len(self._added_op_id_map[subgraph_id])
+                - 1
+            )
           transformation.tensor_id = trans_info.output_tensor_id
 
+  def _get_updated_producer_id(
+      self, original_producer_id: int, subgraph_id: int
+  ) -> int:
+    """Update the producer of a transformation instruction."""
+    if original_producer_id is None or original_producer_id < 0:
+      producer = -1
+    elif original_producer_id < len(self._original_op_id_map[subgraph_id]):
+      producer = self._original_op_id_map[subgraph_id][original_producer_id]
+    else:
+      # If the producer id is not in the original op map, it's an added op,
+      # go the added op map to find the producer.
+      producer = self._added_op_id_map[subgraph_id][
+          original_producer_id - len(self._original_op_id_map[subgraph_id])
+      ]
+    return producer
+
+  def _get_updated_consumer_ids(
+      self,
+      original_consumer_ids: list[int],
+      subgraph_id: int,
+  ) -> list[int]:
+    """Update the consumers of a transformation instruction."""
+    consumers = []
+    for original_op_id in original_consumer_ids:
+      new_consumer_id = (
+          -1
+          if original_op_id == -1
+          else self._original_op_id_map[subgraph_id][original_op_id]
+      )
+      consumers.append(new_consumer_id)
+    return consumers
+
   def _apply_single_transformation(
       self,
       transformation_inst: qtyping.TensorTransformationInsts,
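The two new helpers centralize index arithmetic that previously lived inline in _apply_single_transformation: ids below len(_original_op_id_map[subgraph]) are original ops looked up at their shifted positions, while anything at or above that bound indexes into _added_op_id_map. A standalone sketch of the arithmetic, with hypothetical map contents:

    # Standalone sketch of the producer-id remapping; map values are
    # hypothetical examples.
    original_op_id_map = [0, 3, 6]   # Ops 0..2 after shifts from insertions.
    added_op_id_map = [1, 2, 4, 5]   # Current positions of four inserted ops.


    def get_updated_producer_id(original_producer_id):
      if original_producer_id is None or original_producer_id < 0:
        return -1  # Graph input: no producer op.
      if original_producer_id < len(original_op_id_map):
        return original_op_id_map[original_producer_id]
      # Larger ids refer to ops added by earlier transformations.
      return added_op_id_map[original_producer_id - len(original_op_id_map)]


    assert get_updated_producer_id(1) == 3  # Original op 1 moved to index 3.
    assert get_updated_producer_id(4) == 2  # Added op 4 - 3 = 1 sits at index 2.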
@@ -179,28 +239,12 @@ class TransformationPerformer:
       None, update the transformation_inst & tflite_model in place
     """
     instruction = transformation_inst.instructions[transformation_index]
-    if instruction.producer is None or instruction.producer < 0:
-      producer = -1
-    elif instruction.producer < len(
-        self._original_op_id_map[transformation_inst.subgraph_id]
-    ):
-      producer = self._original_op_id_map[transformation_inst.subgraph_id][
-          instruction.producer
-      ]
-    else:
-      # if the producer id is not in the original op map, it's an added op,
-      # go the corresponding new maps
-      producer = self._added_op_id_map[transformation_inst.subgraph_id][
-          instruction.producer
-          - len(self._original_op_id_map[transformation_inst.subgraph_id])
-      ]
-    consumers = []
-    for original_op_id in instruction.consumers:
-      consumers.append(
-          self._original_op_id_map[transformation_inst.subgraph_id][
-              original_op_id
-          ]
-      )
+    producer = self._get_updated_producer_id(
+        instruction.producer, transformation_inst.subgraph_id
+    )
+    consumers = self._get_updated_consumer_ids(
+        instruction.consumers, transformation_inst.subgraph_id
+    )
     trans_info = self._transformation_registration[instruction.transformation](
         transformation_utils.TransformationInput(
             instruction.tensor_id,
@@ -220,7 +264,12 @@
     )
     self._update_op_id_map(
         transformation_inst.subgraph_id,
-        min(instruction.consumers),
+        # The added op must be right before the most immediate consumer, unless
+        # the consumer is the graph output (id=-1), then use the producer's
+        # index instead.
+        min(instruction.consumers)
+        if min(instruction.consumers) >= 0
+        else instruction.producer + 1,
         trans_info.num_ops_added,
     )
 
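The conditional added above picks the insertion slot passed to _update_op_id_map. A small sketch of the rule, with hypothetical ids:

    # Sketch of the insertion-slot rule added above; ids are hypothetical.
    def insertion_index(consumers, producer):
      # New ops go right before their earliest consumer; if the tensor feeds a
      # graph output (consumer id -1), fall back to just after the producer.
      return min(consumers) if min(consumers) >= 0 else producer + 1


    assert insertion_index([2, 5], producer=1) == 2   # Before earliest consumer.
    assert insertion_index([-1], producer=3) == 4     # Output tensor case.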
@@ -260,19 +309,24 @@
       self,
       transformation_instructions: dict[str, qtyping.TensorTransformationInsts],
       tflite_model: schema_py_generated.ModelT,
-  ):
-    """Apply all the transformations to the given tflite_model.
+      tensor_processing_order: Optional[Sequence[str]] = None,
+  ) -> None:
+    """Apply all transformations to the given tflite_model in place.
 
     Args:
-      transformation_instructions: a dict of transformation instructions
-        grouped by tensors, produced by transformation_instruction_generator
-      tflite_model: the tflite model to apply quantization on
-
-    Returns:
-      None, modifies the tflite_model in place
+      transformation_instructions: Mapping from tensor name to its
+        transformation instructions, produced by
+        transformation_instruction_generator.
+      tflite_model: The tflite model to apply quantization to.
+      tensor_processing_order: The order of tensors to process. If not provided,
+        the order will be inferred from `transformation_instructions`.
     """
     self._original_op_id_map = []
     self._added_op_id_map = []
     self._create_op_id_map(tflite_model)
-    for transformation_inst in transformation_instructions.values():
-      self._apply_transformations(transformation_inst, tflite_model)
+    if tensor_processing_order is None:
+      tensor_processing_order = transformation_instructions.keys()
+    for tensor_name in tensor_processing_order:
+      self._apply_transformations(
+          transformation_instructions[tensor_name], tflite_model
+      )
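Callers that need a deterministic or constraint-driven order can now pass it explicitly; by default the dict's own iteration order is used. A usage sketch, assuming `instructions` and `model` come from the existing quantizer pipeline:

    # Usage sketch for the new argument; `instructions` and `model` are
    # assumed to come from the existing quantizer pipeline.
    from ai_edge_quantizer import transformation_performer

    performer = transformation_performer.TransformationPerformer()
    performer.transform_graph(
        instructions, model, tensor_processing_order=sorted(instructions)
    )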
--- ai_edge_quantizer/transformation_performer_test.py
+++ ai_edge_quantizer/transformation_performer_test.py
@@ -15,6 +15,7 @@
 
 """Tests for transformation_performer."""
 
+import copy
 import os
 
 import numpy as np
@@ -26,6 +27,9 @@ from ai_edge_quantizer import transformation_performer
 from ai_edge_quantizer.utils import test_utils
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
 
+_QTransf = qtyping.QuantTransformation
+
+
 TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(".")
 
 
@@ -108,6 +112,32 @@ class TransformationPerformerTest(parameterized.TestCase):
     for index, op_id in enumerate(op_id_map[0]):
       self.assertEqual(op_id, index)
 
+  def test_update_op_id_map_not_changing_value_single_op_model(self):
+    """test for _update_op_id_map."""
+    model = tfl_flatbuffer_utils.read_model(
+        os.path.join(
+            TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
+        )
+    )
+    self._transformation_performer._create_op_id_map(model)
+    instruction = qtyping.TransformationInst(
+        transformation=qtyping.QuantTransformation.QUANTIZE_TENSOR,
+        tensor_id=0,
+        producer=0,
+        consumers=[-1],
+        parameters=qtyping.UniformQuantParams(
+            8, None, np.array([1]), np.array([0])
+        ),
+    )
+    producer = self._transformation_performer._get_updated_producer_id(
+        instruction.producer, 0
+    )
+    consumers = self._transformation_performer._get_updated_consumer_ids(
+        instruction.consumers, 0
+    )
+    self.assertEqual(producer, 0)
+    self.assertEqual(consumers, [-1])
+
   @parameterized.named_parameters(
       dict(
           testcase_name="test_no_update",
@@ -267,6 +297,52 @@
         expected_added_op_id_map,
     )
 
+  def test_update_instructions_updates_tensor_id_after_duplicate_tensor(self):
+    def get_test_instruction(transformation, consumers):
+      return qtyping.TransformationInst(
+          transformation=transformation,
+          consumers=consumers,
+          # Dummy values below.
+          tensor_id=0,
+          producer=0,
+          parameters=qtyping.UniformQuantParams(
+              8, None, np.array([1]), np.array([0])
+          ),
+      )
+
+    instructions = [
+        get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[1]),
+        get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[1]),
+        get_test_instruction(_QTransf.ADD_DEQUANTIZE, consumers=[1]),
+        get_test_instruction(_QTransf.QUANTIZE_TENSOR, consumers=[2]),
+    ]
+    # Simulate a situation as if the first instruction (duplicate tensor) was
+    # applied.
+    subgraph_id = 0
+    duplicated_tensor_id = 13
+    prev_trans_idx = 0
+    trans_info = qtyping.TransformationInfo(
+        # Copy of what duplicate_tensor.py returns.
+        op_id=0,
+        num_ops_added=0,
+        output_tensor_id=duplicated_tensor_id,
+    )
+    self._transformation_performer._create_op_id_map(self._test_model)
+    self._transformation_performer._update_instructions(
+        prev_trans_idx, instructions, subgraph_id, trans_info
+    )
+    # Expecting the ops with the same consumers as in the DUPLICATE_TENSOR
+    # instruction to use the new tensor id.
+    expected_instructions = copy.deepcopy(instructions)
+    expected_instructions[1].tensor_id = duplicated_tensor_id
+    expected_instructions[2].tensor_id = duplicated_tensor_id
+    self.assertSequenceEqual(instructions, expected_instructions)
+    # Expecting no change to the op id map.
+    self.assertListEqual(
+        self._transformation_performer._added_op_id_map,
+        [[]],
+    )
+
   def test_transform_graph(self):
     """test for transform_graph."""
     instructions = {
@@ -275,6 +351,8 @@
         tensor_name="sequential/conv2d/Relu;sequential/conv2d/BiasAdd;"
         + "sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp1",
         subgraph_id=0,
+        # Conv2d: op_id=0, output_tensor_id=7.
+        # This should add two sequential dequants after the conv2d.
         instructions=[
             qtyping.TransformationInst(
                 transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
@@ -299,6 +377,8 @@
     "sequential/average_pooling2d/AvgPool": qtyping.TensorTransformationInsts(
         tensor_name="sequential/average_pooling2d/AvgPool",
         subgraph_id=0,
+        # Avg_pool: op_id=1, output_tensor_id=8.
+        # This should add two sequential dequants after the avg_pool.
         instructions=[
             qtyping.TransformationInst(
                 transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
@@ -326,19 +406,111 @@
     )
     self.assertLen(self._test_model.subgraphs, 1)
     self.assertLen(self._test_model.subgraphs[0].operators, 10)
+    # The original model has 13 tensors, each dequant adds 1 tensor.
     self.assertLen(self._test_model.subgraphs[0].tensors, 17)
+    # Check that the dequant opcode is added to the model.
     self.assertEqual(
         self._test_model.subgraphs[0].operators[1].opcodeIndex,
         len(self._test_model.operatorCodes) - 1,
     )
+    # Conv2d, dequant, dequant, avgpool, dequant, dequant, etc.
+    expected_builtin_op_order = [3, 6, 6, 1, 6, 6, 22, 9, 9, 25]
+    for i, op in enumerate(self._test_model.subgraphs[0].operators):
+      op_code = self._test_model.operatorCodes[op.opcodeIndex].builtinCode
+      self.assertEqual(op_code, expected_builtin_op_order[i])
+    # Check that the first dequant input is connected to the conv2d output.
+    self.assertEqual(self._test_model.subgraphs[0].operators[1].inputs[0], 7)
+    # Output is a new tensor just added.
+    self.assertEqual(self._test_model.subgraphs[0].operators[1].outputs[0], 13)
+    # Second dequant has new tensors.
     self.assertEqual(self._test_model.subgraphs[0].operators[2].inputs[0], 13)
     self.assertEqual(self._test_model.subgraphs[0].operators[2].outputs[0], 14)
-    self.assertEqual(
-        self._test_model.subgraphs[0].operators[3].inputs[0],
-        14,
-    )
+    # Avgpool's input is second dequant's output.
+    self.assertEqual(self._test_model.subgraphs[0].operators[3].inputs[0], 14)
+    # Avgpool's output remains the same.
     self.assertEqual(self._test_model.subgraphs[0].operators[3].outputs[0], 8)
+    # Third dequant's output is a new tensor.
     self.assertEqual(self._test_model.subgraphs[0].operators[4].outputs[0], 15)
+    # Fourth dequant.
+    self.assertEqual(self._test_model.subgraphs[0].operators[5].inputs[0], 15)
+    self.assertEqual(self._test_model.subgraphs[0].operators[5].outputs[0], 16)
+
+    # Avgpool (op_id=1) and reshape (op_id=2) are bumped by 2 due to the two
+    # dequants added after it.
+    expected_op_id_map = [0, 3, 6, 7, 8, 9]
+    self.assertEqual(
+        self._transformation_performer._original_op_id_map[0],
+        expected_op_id_map,
+    )
+    # New dequants are added at these indices.
+    expected_added_op_id_map = [1, 2, 4, 5]
+    self.assertEqual(
+        self._transformation_performer._added_op_id_map[0],
+        expected_added_op_id_map,
+    )
+
+  def test_op_insertion_at_input_and_output(self):
+    """test for _update_op_id_map."""
+    model = tfl_flatbuffer_utils.read_model(
+        os.path.join(
+            TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
+        )
+    )
+    self._transformation_performer._create_op_id_map(model)
+    instructions = {
+        # Fully_connected: op_id=0, input_tensor_id=0, output_tensor_id=3.
+        # Add a new quantize op to the input of the fully_connected.
+        "serving_default_input_2:0": qtyping.TensorTransformationInsts(
+            tensor_name="serving_default_input_2:0",
+            subgraph_id=0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
+                    tensor_id=0,
+                    producer=-1,
+                    consumers=[0],
+                    parameters=qtyping.UniformQuantParams(
+                        8, None, np.array([1]), np.array([0])
+                    ),
+                ),
+            ],
+        ),
+        # Add a new dequantize op to the output of the fully_connected.
+        "StatefulPartitionedCall:0": qtyping.TensorTransformationInsts(
+            tensor_name="StatefulPartitionedCall:0",
+            subgraph_id=0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
+                    tensor_id=3,
+                    producer=0,
+                    consumers=[-1],
+                    parameters=qtyping.UniformQuantParams(
+                        8, None, np.array([1]), np.array([0])
+                    ),
+                ),
+            ],
+        ),
+    }
+    self._transformation_performer.transform_graph(instructions, model)
+
+    # Original fc (op_id=0) should be bumped to op_id=1.
+    self.assertEqual(
+        self._transformation_performer._original_op_id_map[0],
+        [1],
+    )
+    # New quantize added at op_id=0, dequantize added at op_id=1.
+    expected_added_op_id_map = [0, 2]
+    self.assertEqual(
+        self._transformation_performer._added_op_id_map[0],
+        expected_added_op_id_map,
+    )
+    # Quantize, fully_connected, dequantize.
+    expected_builtin_op_order = [114, 9, 6]
+    for i, op in enumerate(model.subgraphs[0].operators):
+      op_code = model.operatorCodes[op.opcodeIndex].builtinCode
+      self.assertEqual(op_code, expected_builtin_op_order[i])
+
 
 if __name__ == "__main__":
   googletest.main()
--- /dev/null
+++ ai_edge_quantizer/transformations/duplicate_buffer.py
@@ -0,0 +1,46 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Duplicate buffer transformation."""
+
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.transformations import transformation_utils
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+
+def duplicate_buffer(
+    transformation_input: transformation_utils.TransformationInput,
+) -> qtyping.TransformationInfo:
+  """Duplicates the buffer of the tensor."""
+  tensor_id = transformation_input.tensor_id
+  tensor = transformation_input.subgraph.tensors[tensor_id]
+  buffer_data = transformation_input.buffers[tensor.buffer].data
+  if buffer_data is None:
+    tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
+    raise ValueError(
+        'Duplicate Buffer transformation supports only constant tensors.'
+        f' Tensor {tensor_name} is not constant.'
+    )
+
+  duplicated_buffer_id = transformation_utils.get_constant_buffer(
+      data=buffer_data,
+      buffers=transformation_input.buffers,
+      force_duplicate_buffer=True,
+  )
+  tensor.buffer = duplicated_buffer_id
+
+  return qtyping.TransformationInfo(
+      op_id=0, num_ops_added=0, output_tensor_id=tensor_id
+  )
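duplicate_buffer leaves the tensor id untouched and only repoints the tensor at a copied buffer, so tensors that previously shared one buffer can later be quantized differently. A minimal driver sketch, assuming a constant weight tensor at index 1 of the first subgraph (the model path and indices are illustrative only):

    # Minimal driver for duplicate_buffer; the model path and tensor index
    # are illustrative only.
    import numpy as np
    from ai_edge_quantizer import qtyping
    from ai_edge_quantizer.transformations import duplicate_buffer
    from ai_edge_quantizer.transformations import transformation_utils
    from ai_edge_quantizer.utils import tfl_flatbuffer_utils

    model = tfl_flatbuffer_utils.read_model('/path/to/model.tflite')
    info = duplicate_buffer.duplicate_buffer(
        transformation_utils.TransformationInput(
            tensor_id=1,  # Must be a constant (weight) tensor.
            op_codes=model.operatorCodes,
            buffers=model.buffers,
            subgraph=model.subgraphs[0],
            producer=-1,
            consumers=[],
            quant_params=qtyping.UniformQuantParams(
                num_bits=8,
                quantized_dimension=None,
                scale=np.ones(1),
                zero_point=np.zeros(1),
            ),
        )
    )
    # The tensor keeps its id; only its buffer was duplicated.
    assert info.output_tensor_id == 1 and info.num_ops_added == 0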
--- /dev/null
+++ ai_edge_quantizer/transformations/duplicate_buffer_test.py
@@ -0,0 +1,106 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+import numpy as np
+from tensorflow.python.platform import googletest
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.transformations import duplicate_buffer
+from ai_edge_quantizer.transformations import transformation_utils
+from ai_edge_quantizer.utils import test_utils
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('..')
+
+
+class DuplicateBufferTest(googletest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH, 'tests/models/weight_sharing_fcs.tflite'
+    )
+    self.model = tfl_flatbuffer_utils.read_model(model_path)
+
+  def _get_transformation_input(
+      self, subgraph_idx: int, tensor_idx: int
+  ) -> transformation_utils.TransformationInput:
+    return transformation_utils.TransformationInput(
+        tensor_id=tensor_idx,
+        buffers=self.model.buffers,
+        # Dummy params below.
+        op_codes=self.model.operatorCodes,
+        subgraph=self.model.subgraphs[subgraph_idx],
+        producer=-1,
+        consumers=[],
+        quant_params=qtyping.UniformQuantParams(
+            num_bits=8,
+            quantized_dimension=None,
+            scale=np.ones(1),
+            zero_point=np.zeros(1),
+        ),
+    )
+
+  def test_constant_buffer_is_correctly_duplicated(self):
+    # Duplicate the FC weight tensor in the second subgraph.
+    subgraph_idx = 1
+    subgraph = self.model.subgraphs[subgraph_idx]
+    weight_tensor_idx = 1
+    prev_buffer_id = subgraph.tensors[weight_tensor_idx].buffer
+    prev_num_buffers = len(self.model.buffers)
+    transformation_input = self._get_transformation_input(
+        subgraph_idx, weight_tensor_idx
+    )
+    transformation_info = duplicate_buffer.duplicate_buffer(
+        transformation_input
+    )
+    self.assertEqual(transformation_info.op_id, 0)
+    self.assertEqual(transformation_info.num_ops_added, 0)
+    self.assertEqual(transformation_info.output_tensor_id, 1)
+    # Check that a new buffer was added.
+    self.assertLen(self.model.buffers, prev_num_buffers + 1)
+    # Check that the new buffer is used by the weight tensor.
+    new_buffer_id = len(self.model.buffers) - 1
+    self.assertEqual(subgraph.tensors[weight_tensor_idx].buffer, new_buffer_id)
+    # Check that the new buffer has the same data as the original one.
+    self.assertTrue(
+        np.all(
+            np.frombuffer(
+                self.model.buffers[new_buffer_id].data,
+                dtype=np.float32,
+            )
+            == np.frombuffer(
+                self.model.buffers[prev_buffer_id].data,
+                dtype=np.float32,
+            )
+        )
+    )
+
+  def test_duplicate_buffer_raises_error_when_tensor_is_not_constant(self):
+    # Duplicate the FC input tensor in the second subgraph.
+    subgraph_idx = 1
+    weight_tensor_idx = 0
+    transformation_input = self._get_transformation_input(
+        subgraph_idx, weight_tensor_idx
+    )
+    with self.assertRaisesRegex(
+        ValueError,
+        'Duplicate Buffer transformation supports only constant tensors.',
+    ):
+      duplicate_buffer.duplicate_buffer(transformation_input)
+
+
+if __name__ == '__main__':
+  googletest.main()