ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,278 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Python manager for transformations to be applied to TFlite models."""
|
17
|
+
|
18
|
+
import numpy as np
|
19
|
+
from ai_edge_quantizer import qtyping
|
20
|
+
from ai_edge_quantizer.transformations import dequant_insert
|
21
|
+
from ai_edge_quantizer.transformations import emulated_subchannel
|
22
|
+
from ai_edge_quantizer.transformations import quant_insert
|
23
|
+
from ai_edge_quantizer.transformations import quantize_tensor
|
24
|
+
from ai_edge_quantizer.transformations import transformation_utils
|
25
|
+
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
|
26
|
+
|
27
|
+
|
28
|
+
class TransformationPerformer:
  """Wrapper class for transformations.

  All transformations supported by the AI Edge Quantizer should be registered
  in the Transformation Performer.

  The transformation to be applied to a tensor is specified by the Instruction
  Generator as a key. Given this key, the transformation performer looks up the
  function that applies the corresponding transformation on the graph.

  A transformation is a Callable that receives a
  transformation_utils.TransformationInput carrying:
    tensor_id: the id of the tensor the transformation is applied to
    operatorCodes: the operator codes of the source TFLite ModelT
    buffers: the buffers of the source TFLite ModelT
    subgraph: the specific subgraph where the transformation should be applied
    producer: the op index of the producer of the tensor (-1 if there is none)
    consumers: a list of op indices of the consumers to apply the change on
    parameters: the quantization parameters of the instruction
  and returns a qtyping.TransformationInfo which contains the index where the
  ops are added, how many ops are added and the id of the output tensor.

  Additionally, op additions must be consecutive.

  This class is expected to be created by the Model Modifier and nothing else.
  The Model Modifier passes in a dict of transformations to be applied; this
  class applies them in a pre-determined static order.
  """

  def __init__(self):
    """Initializes the TransformationPerformer."""
    # Maps each supported transformation key to the function implementing it.
    self._transformation_registration = {
        qtyping.QuantTransformation.ADD_DEQUANTIZE: (
            dequant_insert.insert_dequant
        ),
        qtyping.QuantTransformation.QUANTIZE_TENSOR: (
            quantize_tensor.quantize_tensor
        ),
        qtyping.QuantTransformation.EMULATED_SUBCHANNEL: (
            emulated_subchannel.emulated_subchannel
        ),
        qtyping.QuantTransformation.ADD_QUANTIZE: quant_insert.insert_quant,
    }
    # Transformations are separated into two categories:
    # op_insertion_transformations only insert ops into the graph, whereas
    # op_replacement_transformations replace one op with a pattern.
    self._op_insertion_transformations = {
        qtyping.QuantTransformation.ADD_DEQUANTIZE,
        qtyping.QuantTransformation.QUANTIZE_TENSOR,
        qtyping.QuantTransformation.ADD_QUANTIZE,
    }
    self._op_replacement_transformations = {
        qtyping.QuantTransformation.EMULATED_SUBCHANNEL
    }
    # Per subgraph: original op index -> current op index in the modified graph.
    self._original_op_id_map: list[list[int]] = []
    # Per subgraph: recorded index of the last op added by each insertion.
    self._added_op_id_map: list[list[int]] = []

  def _create_op_id_map(self, tflite_model: schema_py_generated.ModelT):
    """Initializes the original op_id to modified op_id map.

    At the beginning the graph has not been updated, so each op_id maps to its
    current id. Both maps are rebuilt from scratch, so calling this more than
    once does not accumulate stale per-subgraph entries.

    Args:
      tflite_model: the model we're creating the op_id mapping for.

    Returns:
      None, replaces self._original_op_id_map and self._added_op_id_map.
    """
    self._original_op_id_map = [
        list(range(len(subgraph.operators)))
        for subgraph in tflite_model.subgraphs
    ]
    self._added_op_id_map = [[] for _ in tflite_model.subgraphs]

  def _update_op_id_map(
      self, subgraph_id: int, original_op_id: int, num_ops_added: int
  ):
    """Updates the mapping between the original op ids and modified op ids.

    Every original op whose original index is >= original_op_id is shifted by
    num_ops_added.

    Args:
      subgraph_id: the index of the subgraph that we're interested in.
      original_op_id: the original id before which the first op is added.
      num_ops_added: the number of ops added starting from the op id.

    Returns:
      None, modifies self._original_op_id_map.
    """
    np_op_id_map = np.array(self._original_op_id_map[subgraph_id])
    np_op_id_map[original_op_id:] += num_ops_added
    self._original_op_id_map[subgraph_id] = np_op_id_map.tolist()

  def _update_instructions(
      self,
      prev_transformation_index: int,
      transformations: list[qtyping.TransformationInst],
      subgraph_id: int,
      trans_info: qtyping.TransformationInfo,
  ):
    """Updates the pending instructions after the graph is modified.

    After an op is inserted, the topology is changed and this may impact the
    transformations that have yet to be applied. Any later instruction that
    shares a consumer with the just-applied one is redirected to read from the
    newly added op and its output tensor.

    Args:
      prev_transformation_index: the index of the last applied transformation.
      transformations: the list of transformations we're applying.
      subgraph_id: the subgraph the provided instructions belong to.
      trans_info: transformation info returned by a transformation.

    Returns:
      None, modifies the transformations in place.
    """
    # If no ops were added, the topology is unchanged.
    if trans_info.num_ops_added == 0:
      return
    prev_consumers = transformations[prev_transformation_index].consumers
    added_op_ids = self._added_op_id_map[subgraph_id]
    # Record the current index of the last op added by this transformation.
    added_op_ids.append(trans_info.op_id + trans_info.num_ops_added - 1)
    # Added ops are addressed with ids past the range of the original op ids,
    # so the op just recorded gets the last such id.
    new_producer_id = (
        len(self._original_op_id_map[subgraph_id]) + len(added_op_ids) - 1
    )
    for transformation in transformations[prev_transformation_index + 1 :]:
      if any(
          consumer in prev_consumers for consumer in transformation.consumers
      ):
        transformation.producer = new_producer_id
        transformation.tensor_id = trans_info.output_tensor_id

  def _apply_single_transformation(
      self,
      transformation_inst: qtyping.TensorTransformationInsts,
      transformation_index: int,
      tflite_model: schema_py_generated.ModelT,
  ):
    """Applies a single transformation.

    Args:
      transformation_inst: a TensorTransformationInsts type that contains all
        transformations on a tensor.
      transformation_index: the index of the transformation to be applied.
      tflite_model: source tflite model to be updated.

    Returns:
      None, updates the transformation_inst & tflite_model in place.
    """
    subgraph_id = transformation_inst.subgraph_id
    num_original_ops = len(self._original_op_id_map[subgraph_id])
    instruction = transformation_inst.instructions[transformation_index]
    # Op index 0 is a valid producer, so only a missing (None) or negative
    # producer means the tensor has no producing op.
    if instruction.producer is None or instruction.producer < 0:
      producer = -1
    elif instruction.producer < num_original_ops:
      producer = self._original_op_id_map[subgraph_id][instruction.producer]
    else:
      # Ids at or beyond the original op count refer to ops added by earlier
      # transformations; look them up in the added-op map instead.
      producer = self._added_op_id_map[subgraph_id][
          instruction.producer - num_original_ops
      ]
    consumers = [
        self._original_op_id_map[subgraph_id][original_op_id]
        for original_op_id in instruction.consumers
    ]
    transform = self._transformation_registration[instruction.transformation]
    trans_info = transform(
        transformation_utils.TransformationInput(
            instruction.tensor_id,
            tflite_model.operatorCodes,
            tflite_model.buffers,
            tflite_model.subgraphs[subgraph_id],
            producer,
            consumers,
            instruction.parameters,
        )
    )
    self._update_instructions(
        transformation_index,
        transformation_inst.instructions,
        subgraph_id,
        trans_info,
    )
    # Shift original ops from the first consumer onwards by the number of ops
    # added (assumes new ops are inserted before the first consumer).
    self._update_op_id_map(
        subgraph_id,
        min(instruction.consumers),
        trans_info.num_ops_added,
    )

  def _apply_transformations(
      self,
      transformation_inst: qtyping.TensorTransformationInsts,
      tflite_model: schema_py_generated.ModelT,
  ):
    """Applies all transformations for a tensor.

    Transformations are separated into two types and applied in two
    different passes.

    Args:
      transformation_inst: a TensorTransformationInsts type that contains all
        transformations on a tensor.
      tflite_model: source tflite model to be updated.

    Returns:
      None, updates the transformation_inst & tflite_model in place.
    """
    # Pass 1 applies every op insertion, because op replacement may remove
    # the consumer or producer of some tensors; pass 2 applies replacements.
    passes = (
        self._op_insertion_transformations,
        self._op_replacement_transformations,
    )
    for transformations_in_pass in passes:
      for index, instruction in enumerate(transformation_inst.instructions):
        if instruction.transformation in transformations_in_pass:
          self._apply_single_transformation(
              transformation_inst, index, tflite_model
          )

  def transform_graph(
      self,
      transformation_instructions: dict[str, qtyping.TensorTransformationInsts],
      tflite_model: schema_py_generated.ModelT,
  ):
    """Applies all transformations to the given tflite_model.

    Args:
      transformation_instructions: a dict of transformation instructions grouped
        by tensors, produced by transformation_instruction_generator.
      tflite_model: the tflite model to apply quantization on.

    Returns:
      None, modifies the input tflite_model in place.
    """
    # Rebuilds both op-id maps from scratch for this model.
    self._create_op_id_map(tflite_model)
    for transformation_inst in transformation_instructions.values():
      self._apply_transformations(transformation_inst, tflite_model)
|
@@ -0,0 +1,344 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Tests for transformation_performer."""
|
17
|
+
|
18
|
+
import os
|
19
|
+
|
20
|
+
import numpy as np
|
21
|
+
|
22
|
+
from tensorflow.python.platform import googletest
|
23
|
+
from absl.testing import parameterized
|
24
|
+
from ai_edge_quantizer import qtyping
|
25
|
+
from ai_edge_quantizer import transformation_performer
|
26
|
+
from ai_edge_quantizer.utils import test_utils
|
27
|
+
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
28
|
+
|
29
|
+
# Base path joined with relative test-data paths such as
# "tests/models/conv_fc_mnist.tflite" in setUp. Presumably the package's
# test-data root as resolved by test_utils.get_path_to_datafile — confirm.
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(".")
|
30
|
+
|
31
|
+
|
32
|
+
class TransformationPerformerTest(parameterized.TestCase):
  """Tests for transformation_performer.TransformationPerformer."""

  def setUp(self):
    super().setUp()
    self._transformation_performer = (
        transformation_performer.TransformationPerformer()
    )
    self._test_model = tfl_flatbuffer_utils.read_model(
        os.path.join(TEST_DATA_PREFIX_PATH, "tests/models/conv_fc_mnist.tflite")
    )

  def test_apply_single_insert_dequant(self):
    """Test for _apply_single_transformation with an ADD_DEQUANTIZE."""
    self._transformation_performer._create_op_id_map(self._test_model)
    instructions = qtyping.TensorTransformationInsts(
        tensor_name="sequential/conv2d/Relu;sequential/conv2d/BiasAdd;"
        + "sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp1",
        subgraph_id=0,
        instructions=[
            qtyping.TransformationInst(
                transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                tensor_id=7,
                producer=0,
                consumers=[1],
                parameters=qtyping.UniformQuantParams(
                    8, None, np.array([1]), np.array([0])
                ),
            ),
            qtyping.TransformationInst(
                transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
                tensor_id=7,
                producer=0,
                consumers=[1],
                parameters=qtyping.UniformQuantParams(
                    8, None, np.array([1]), np.array([0])
                ),
            ),
        ],
    )
    # Only the first instruction (index 0) is applied.
    self._transformation_performer._apply_single_transformation(
        instructions, 0, self._test_model
    )
    subgraph = self._test_model.subgraphs[0]
    self.assertIn(b"_dequant", subgraph.tensors[13].name)
    self.assertEqual(
        subgraph.operators[1].opcodeIndex,
        len(self._test_model.operatorCodes) - 1,
    )
    self.assertEqual(subgraph.operators[2].inputs[0], 13)

  def test_create_op_id_map(self):
    """Test for _create_op_id_map."""
    self._transformation_performer._create_op_id_map(self._test_model)
    op_id_map = self._transformation_performer._original_op_id_map
    self.assertLen(op_id_map, 1)
    # Initially every op maps to its own index.
    self.assertListEqual(op_id_map[0], list(range(6)))

  def test_update_op_id_map_changing_value(self):
    """Test for _update_op_id_map when ops are added."""
    self._transformation_performer._create_op_id_map(self._test_model)
    self._transformation_performer._update_op_id_map(0, 1, 6)
    op_id_map = self._transformation_performer._original_op_id_map
    self.assertLen(op_id_map, 1)
    # Ops before the insertion point keep their ids; the rest shift by 6.
    self.assertListEqual(op_id_map[0], [0, 7, 8, 9, 10, 11])

  def test_update_op_id_map_not_changing_value(self):
    """Test for _update_op_id_map when no ops are added."""
    self._transformation_performer._create_op_id_map(self._test_model)
    self._transformation_performer._update_op_id_map(0, 0, 0)
    op_id_map = self._transformation_performer._original_op_id_map
    self.assertLen(op_id_map, 1)
    self.assertListEqual(op_id_map[0], list(range(6)))

  @parameterized.named_parameters(
      dict(
          testcase_name="test_no_update",
          prev_trans_idx=0,
          instructions=[
              qtyping.TransformationInst(
                  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                  tensor_id=0,
                  producer=0,
                  consumers=[1],
                  parameters=qtyping.UniformQuantParams(
                      8, None, np.array([1]), np.array([0])
                  ),
              ),
              qtyping.TransformationInst(
                  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                  tensor_id=0,
                  producer=0,
                  consumers=[1],
                  parameters=qtyping.UniformQuantParams(
                      8, None, np.array([1]), np.array([0])
                  ),
              ),
          ],
          trans_info=qtyping.TransformationInfo(0, 0, 0),
          expected_instructions=[
              qtyping.TransformationInst(
                  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                  tensor_id=0,
                  producer=0,
                  consumers=[1],
                  parameters=qtyping.UniformQuantParams(
                      8, None, np.array([1]), np.array([0])
                  ),
              ),
              qtyping.TransformationInst(
                  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                  tensor_id=0,
                  producer=0,
                  consumers=[1],
                  parameters=qtyping.UniformQuantParams(
                      8, None, np.array([1]), np.array([0])
                  ),
              ),
          ],
          expected_added_op_id_map=[[]],
      ),
      dict(
          testcase_name="test_no_matching_consumer",
          prev_trans_idx=0,
          instructions=[
              qtyping.TransformationInst(
                  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                  tensor_id=0,
                  producer=0,
                  consumers=[1],
                  parameters=qtyping.UniformQuantParams(
                      8, None, np.array([1]), np.array([0])
                  ),
              ),
              qtyping.TransformationInst(
                  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                  tensor_id=0,
                  producer=0,
                  consumers=[2],
                  parameters=qtyping.UniformQuantParams(
                      8, None, np.array([1]), np.array([0])
                  ),
              ),
          ],
          trans_info=qtyping.TransformationInfo(2, 2, 13),
          expected_instructions=[
              qtyping.TransformationInst(
                  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                  tensor_id=0,
                  producer=0,
                  consumers=[1],
                  parameters=qtyping.UniformQuantParams(
                      8, None, np.array([1]), np.array([0])
                  ),
              ),
              qtyping.TransformationInst(
                  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                  tensor_id=0,
                  producer=0,
                  consumers=[2],
                  parameters=qtyping.UniformQuantParams(
                      8, None, np.array([1]), np.array([0])
                  ),
              ),
          ],
          expected_added_op_id_map=[[3]],
      ),
      dict(
          testcase_name="test_insert_one_op",
          prev_trans_idx=0,
          instructions=[
              qtyping.TransformationInst(
                  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                  tensor_id=0,
                  producer=0,
                  consumers=[1],
                  parameters=qtyping.UniformQuantParams(
                      8, None, np.array([1]), np.array([0])
                  ),
              ),
              qtyping.TransformationInst(
                  transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
                  tensor_id=0,
                  producer=0,
                  consumers=[1],
                  parameters=qtyping.UniformQuantParams(
                      8, None, np.array([1]), np.array([0])
                  ),
              ),
          ],
          trans_info=qtyping.TransformationInfo(1, 1, 13),
          expected_instructions=[
              qtyping.TransformationInst(
                  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                  tensor_id=0,
                  producer=0,
                  consumers=[1],
                  parameters=qtyping.UniformQuantParams(
                      8, None, np.array([1]), np.array([0])
                  ),
              ),
              qtyping.TransformationInst(
                  transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
                  tensor_id=13,
                  producer=6,
                  consumers=[1],
                  parameters=qtyping.UniformQuantParams(
                      8, None, np.array([1]), np.array([0])
                  ),
              ),
          ],
          expected_added_op_id_map=[[1]],
      ),
  )
  def test_update_instructions(
      self,
      prev_trans_idx,
      instructions,
      trans_info,
      expected_instructions,
      expected_added_op_id_map,
  ):
    """Test for _update_instructions."""
    self._transformation_performer._create_op_id_map(self._test_model)
    self._transformation_performer._update_instructions(
        prev_trans_idx, instructions, 0, trans_info
    )
    self.assertSequenceEqual(instructions, expected_instructions)
    self.assertListEqual(
        self._transformation_performer._added_op_id_map,
        expected_added_op_id_map,
    )

  def test_transform_graph(self):
    """Test for transform_graph."""
    instructions = {
        "sequential/conv2d/Relu;sequential/conv2d/BiasAdd;"
        + "sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp1": qtyping.TensorTransformationInsts(
            tensor_name="sequential/conv2d/Relu;sequential/conv2d/BiasAdd;"
            + "sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp1",
            subgraph_id=0,
            instructions=[
                qtyping.TransformationInst(
                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                    tensor_id=7,
                    producer=0,
                    consumers=[1],
                    parameters=qtyping.UniformQuantParams(
                        8, None, np.array([1]), np.array([0])
                    ),
                ),
                qtyping.TransformationInst(
                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                    tensor_id=7,
                    producer=0,
                    consumers=[1],
                    parameters=qtyping.UniformQuantParams(
                        8, None, np.array([1]), np.array([0])
                    ),
                ),
            ],
        ),
        "sequential/average_pooling2d/AvgPool": qtyping.TensorTransformationInsts(
            tensor_name="sequential/average_pooling2d/AvgPool",
            subgraph_id=0,
            instructions=[
                qtyping.TransformationInst(
                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                    tensor_id=8,
                    producer=1,
                    consumers=[2],
                    parameters=qtyping.UniformQuantParams(
                        8, None, np.array([1]), np.array([0])
                    ),
                ),
                qtyping.TransformationInst(
                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
                    tensor_id=8,
                    producer=1,
                    consumers=[2],
                    parameters=qtyping.UniformQuantParams(
                        8, None, np.array([1]), np.array([0])
                    ),
                ),
            ],
        ),
    }
    self._transformation_performer.transform_graph(
        instructions, self._test_model
    )
    self.assertLen(self._test_model.subgraphs, 1)
    self.assertLen(self._test_model.subgraphs[0].operators, 10)
    self.assertLen(self._test_model.subgraphs[0].tensors, 17)
    self.assertEqual(
        self._test_model.subgraphs[0].operators[1].opcodeIndex,
        len(self._test_model.operatorCodes) - 1,
    )
    self.assertEqual(self._test_model.subgraphs[0].operators[2].inputs[0], 13)
    self.assertEqual(self._test_model.subgraphs[0].operators[2].outputs[0], 14)
    self.assertEqual(
        self._test_model.subgraphs[0].operators[2].outputs[0],
        self._test_model.subgraphs[0].operators[3].inputs[0],
    )
    self.assertEqual(self._test_model.subgraphs[0].operators[3].outputs[0], 8)
    self.assertEqual(self._test_model.subgraphs[0].operators[4].outputs[0], 15)
|
342
|
+
|
343
|
+
if __name__ == "__main__":
  # Run this module's tests when the file is executed directly.
  googletest.main()
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|