ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,278 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Python manager for transformations to be applied to TFlite models."""
17
+
18
+ import numpy as np
19
+ from ai_edge_quantizer import qtyping
20
+ from ai_edge_quantizer.transformations import dequant_insert
21
+ from ai_edge_quantizer.transformations import emulated_subchannel
22
+ from ai_edge_quantizer.transformations import quant_insert
23
+ from ai_edge_quantizer.transformations import quantize_tensor
24
+ from ai_edge_quantizer.transformations import transformation_utils
25
+ from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
26
+
27
+
28
class TransformationPerformer:
  """Registry and driver for the graph transformations of AI Edge Quantizer.

  Every transformation supported by the AI Edge Quantizer should be registered
  in the Transformation Performer.

  The transformation to apply to a tensor is specified by the Instruction
  Generator as a key (a qtyping.QuantTransformation). Given this key, the
  performer looks up the function that applies the corresponding
  transformation to the graph.

  A transformation is a Callable that receives a single
  transformation_utils.TransformationInput built from:
    tensor_id: the id of the tensor the transformation is applied to
    operatorCodes: list of operator codes from the source TFLite ModelT
    buffers: list of buffers from the source TFLite ModelT
    subgraph: the specific subgraph where the transformation is applied
    producer: the op index of the producer of the tensor
    consumers: list of op indices of the consumers to apply the change on
    quant_param: the quantization parameters of the instruction
  and returns a qtyping.TransformationInfo describing the index where ops
  were added, how many ops were added and the id of the output tensor.

  Additionally, op additions must be consecutive.

  This class is expected to be created by the Model Modifier and nothing else.
  The Model Modifier passes in a dict of transformations to be applied, and
  this class applies them in a pre-determined static order.
  """

  def __init__(self):
    """Initializes the TransformationPerformer."""
    self._transformation_registration = {
        qtyping.QuantTransformation.ADD_DEQUANTIZE: (
            dequant_insert.insert_dequant
        ),
        qtyping.QuantTransformation.QUANTIZE_TENSOR: (
            quantize_tensor.quantize_tensor
        ),
        qtyping.QuantTransformation.EMULATED_SUBCHANNEL: (
            emulated_subchannel.emulated_subchannel
        ),
        qtyping.QuantTransformation.ADD_QUANTIZE: quant_insert.insert_quant,
    }
    # Transformations are separated in two categories:
    # op_insertion_transformations only insert ops into the graph, whereas
    # op_replacement_transformations replace one op with a pattern.
    self._op_insertion_transformations = {
        qtyping.QuantTransformation.ADD_DEQUANTIZE,
        qtyping.QuantTransformation.QUANTIZE_TENSOR,
        qtyping.QuantTransformation.ADD_QUANTIZE,
    }
    self._op_replacement_transformations = {
        qtyping.QuantTransformation.EMULATED_SUBCHANNEL,
    }
    # Per subgraph: current op index of every op of the unmodified graph.
    self._original_op_id_map = []
    # Per subgraph: op index recorded for every group of ops added so far.
    self._added_op_id_map = []

  def _create_op_id_map(self, tflite_model: schema_py_generated.ModelT):
    """Initializes the original op id to modified op id maps.

    At the beginning the graph has not been updated, so every op id maps to
    its current id. Both maps are rebuilt from scratch, so calling this method
    again never leaves stale or duplicated per-subgraph entries behind.

    Args:
      tflite_model: the model to create the op id mapping for

    Returns:
      None, replaces self._original_op_id_map and self._added_op_id_map
    """
    self._original_op_id_map = [
        list(range(len(subgraph.operators)))
        for subgraph in tflite_model.subgraphs
    ]
    self._added_op_id_map = [[] for _ in tflite_model.subgraphs]

  def _update_op_id_map(
      self, subgraph_id: int, original_op_id: int, num_ops_added: int
  ):
    """Updates the mapping between the original op ids and modified op ids.

    Args:
      subgraph_id: the index of the subgraph that we're interested in
      original_op_id: the original op id from which the mapped ids are shifted
      num_ops_added: the number of ops added, i.e. the size of the shift

    Returns:
      None, modifies self._original_op_id_map
    """
    shifted_ids = np.array(self._original_op_id_map[subgraph_id])
    # Every op from original_op_id onwards moved back by the inserted ops.
    shifted_ids[original_op_id:] += num_ops_added
    self._original_op_id_map[subgraph_id] = shifted_ids.tolist()

  def _update_instructions(
      self,
      prev_transformation_index: int,
      transformations: list[qtyping.TransformationInst],
      subgraph_id: int,
      trans_info: qtyping.TransformationInfo,
  ):
    """Updates the pending instructions after the graph is modified.

    After an op is inserted, the topology is changed and this may impact the
    following transformations to be applied. So the instructions that have yet
    to be applied need to be updated.

    Args:
      prev_transformation_index: the index of the last applied transformation
      transformations: the list of transformations we're applying
      subgraph_id: the subgraph the provided instructions belong to
      trans_info: transformation info returned by a transformation

    Returns:
      None, modifies the transformations in place
    """
    # If no ops were added, then there is nothing to update.
    if trans_info.num_ops_added == 0:
      return
    prev_consumers = transformations[prev_transformation_index].consumers
    added_op_ids = self._added_op_id_map[subgraph_id]
    # Record the index of the last op added by the previous transformation.
    added_op_ids.append(trans_info.op_id + trans_info.num_ops_added - 1)
    # Ids of added ops live outside of the range of the original op ids, so
    # that they can be told apart when the producer is resolved later on.
    new_producer = (
        len(self._original_op_id_map[subgraph_id]) + len(added_op_ids) - 1
    )
    for pending_index in range(
        prev_transformation_index + 1, len(transformations)
    ):
      pending = transformations[pending_index]
      # A pending instruction that shares a consumer with the applied one has
      # to read from the newly added op and its output tensor.
      if any(consumer in prev_consumers for consumer in pending.consumers):
        pending.producer = new_producer
        pending.tensor_id = trans_info.output_tensor_id

  def _resolve_producer(self, subgraph_id: int, producer):
    """Maps the producer id of an instruction to its current op index.

    Args:
      subgraph_id: the index of the subgraph the producer belongs to
      producer: producer id stored in the instruction; None or a negative
        value when the instruction carries no producer op

    Returns:
      The current op index of the producer, or -1 when there is none.
    """
    # Compare against None explicitly: op id 0 is a valid producer and must
    # not be mistaken for a missing one.
    if producer is None or producer < 0:
      # NOTE(review): -1 presumably tells the transformation that the tensor
      # has no producer op (e.g. a graph input) -- confirm against callers.
      return -1
    original_op_ids = self._original_op_id_map[subgraph_id]
    if producer < len(original_op_ids):
      return original_op_ids[producer]
    # If the producer id is not in the original op map, it's an added op, go
    # to the corresponding map of added ops.
    return self._added_op_id_map[subgraph_id][producer - len(original_op_ids)]

  def _apply_single_transformation(
      self,
      transformation_inst: qtyping.TensorTransformationInsts,
      transformation_index: int,
      tflite_model: schema_py_generated.ModelT,
  ):
    """Applies a single transformation.

    Args:
      transformation_inst: a TensorTransformationInsts type that contains all
        transformations on a tensor
      transformation_index: the index of the transformation to be applied
      tflite_model: source tflite model to be updated

    Returns:
      None, updates the transformation_inst & tflite_model in place
    """
    subgraph_id = transformation_inst.subgraph_id
    instruction = transformation_inst.instructions[transformation_index]
    original_op_ids = self._original_op_id_map[subgraph_id]
    producer = self._resolve_producer(subgraph_id, instruction.producer)
    consumers = [original_op_ids[op_id] for op_id in instruction.consumers]
    transformation = self._transformation_registration[
        instruction.transformation
    ]
    trans_info = transformation(
        transformation_utils.TransformationInput(
            instruction.tensor_id,
            tflite_model.operatorCodes,
            tflite_model.buffers,
            tflite_model.subgraphs[subgraph_id],
            producer,
            consumers,
            instruction.parameters,
        )
    )
    self._update_instructions(
        transformation_index,
        transformation_inst.instructions,
        subgraph_id,
        trans_info,
    )
    # Shifting by zero is a no-op, so skip it; this also keeps an instruction
    # without consumers from failing in min() when nothing was inserted.
    if trans_info.num_ops_added:
      self._update_op_id_map(
          subgraph_id,
          min(instruction.consumers),
          trans_info.num_ops_added,
      )

  def _apply_transformations(
      self,
      transformation_inst: qtyping.TensorTransformationInsts,
      tflite_model: schema_py_generated.ModelT,
  ):
    """Applies all transformations for a tensor.

    Transformations are separated in two types and applied separately in two
    different passes.

    Args:
      transformation_inst: a TensorTransformationInsts type that contains all
        transformations on a tensor
      tflite_model: source tflite model to be updated

    Returns:
      None, updates the transformation_inst & tflite_model in place
    """
    # Pass 1 applies all the op insertion transformations, because op
    # replacement may remove the consumer or producer of some tensors.
    # Pass 2 applies all the op replacement transformations.
    passes = (
        self._op_insertion_transformations,
        self._op_replacement_transformations,
    )
    for transformations_in_pass in passes:
      for index, instruction in enumerate(transformation_inst.instructions):
        if instruction.transformation in transformations_in_pass:
          self._apply_single_transformation(
              transformation_inst, index, tflite_model
          )

  def transform_graph(
      self,
      transformation_instructions: dict[str, qtyping.TensorTransformationInsts],
      tflite_model: schema_py_generated.ModelT,
  ):
    """Applies all transformations to the given tflite_model.

    Args:
      transformation_instructions: a dict of transformation instructions grouped
        by tensors, produced by transformation_instruction_generator
      tflite_model: the tflite model to apply quantization on

    Returns:
      None, modifies the input tflite_model in place
    """
    # Start from fresh op id maps so that state from a previous call on
    # another model cannot leak into this one.
    self._create_op_id_map(tflite_model)
    for transformation_inst in transformation_instructions.values():
      self._apply_transformations(transformation_inst, tflite_model)
@@ -0,0 +1,344 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Tests for transformation_performer."""
17
+
18
+ import os
19
+
20
+ import numpy as np
21
+
22
+ from tensorflow.python.platform import googletest
23
+ from absl.testing import parameterized
24
+ from ai_edge_quantizer import qtyping
25
+ from ai_edge_quantizer import transformation_performer
26
+ from ai_edge_quantizer.utils import test_utils
27
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
28
+
29
# Root directory against which the paths of the test data files are resolved.
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(".")
30
+
31
+
32
def _make_inst(transformation, tensor_id, producer, consumers):
  """Builds a TransformationInst with the quant params used by every test."""
  return qtyping.TransformationInst(
      transformation=transformation,
      tensor_id=tensor_id,
      producer=producer,
      consumers=consumers,
      parameters=qtyping.UniformQuantParams(
          8, None, np.array([1]), np.array([0])
      ),
  )


def _dequant_inst(tensor_id, producer, consumers):
  """Builds an ADD_DEQUANTIZE instruction."""
  return _make_inst(
      qtyping.QuantTransformation.ADD_DEQUANTIZE, tensor_id, producer, consumers
  )


def _quant_inst(tensor_id, producer, consumers):
  """Builds an ADD_QUANTIZE instruction."""
  return _make_inst(
      qtyping.QuantTransformation.ADD_QUANTIZE, tensor_id, producer, consumers
  )


class TransformationPerformerTest(parameterized.TestCase):
  """Unit tests for transformation_performer.TransformationPerformer."""

  def setUp(self):
    super().setUp()
    self._transformation_performer = (
        transformation_performer.TransformationPerformer()
    )
    model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "tests/models/conv_fc_mnist.tflite"
    )
    self._test_model = tfl_flatbuffer_utils.read_model(model_path)

  def test_apply_single_insert_dequant(self):
    """Applying one ADD_DEQUANTIZE instruction inserts a dequant op."""
    performer = self._transformation_performer
    performer._create_op_id_map(self._test_model)
    tensor_insts = qtyping.TensorTransformationInsts(
        tensor_name="sequential/conv2d/Relu;sequential/conv2d/BiasAdd;"
        + "sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp1",
        subgraph_id=0,
        instructions=[
            _dequant_inst(7, 0, [1]),
            _quant_inst(7, 0, [1]),
        ],
    )
    performer._apply_single_transformation(tensor_insts, 0, self._test_model)
    subgraph = self._test_model.subgraphs[0]
    # The new tensor is appended to the tensor list of the subgraph.
    self.assertIn(b"_dequant", subgraph.tensors[13].name)
    # The op code of the inserted op is the last one of the model.
    last_op_code_index = len(self._test_model.operatorCodes) - 1
    self.assertEqual(subgraph.operators[1].opcodeIndex, last_op_code_index)
    self.assertEqual(subgraph.operators[2].inputs[0], 13)

  def test_create_op_id_map(self):
    """A fresh op id map is the identity mapping."""
    performer = self._transformation_performer
    performer._create_op_id_map(self._test_model)
    op_id_map = performer._original_op_id_map
    self.assertLen(op_id_map, 1)
    self.assertLen(op_id_map[0], 6)
    for expected_op_id, op_id in enumerate(op_id_map[0]):
      self.assertEqual(op_id, expected_op_id)

  def test_update_op_id_map_changing_value(self):
    """Ops from the given id onwards are shifted by the number of added ops."""
    performer = self._transformation_performer
    performer._create_op_id_map(self._test_model)
    performer._update_op_id_map(0, 1, 6)
    op_id_map = performer._original_op_id_map
    self.assertLen(op_id_map, 1)
    self.assertLen(op_id_map[0], 6)
    for original_op_id, op_id in enumerate(op_id_map[0][1:], start=1):
      self.assertEqual(op_id, original_op_id + 6)

  def test_update_op_id_map_not_changing_value(self):
    """Adding zero ops leaves the op id map untouched."""
    performer = self._transformation_performer
    performer._create_op_id_map(self._test_model)
    performer._update_op_id_map(0, 0, 0)
    op_id_map = performer._original_op_id_map
    self.assertLen(op_id_map, 1)
    self.assertLen(op_id_map[0], 6)
    for expected_op_id, op_id in enumerate(op_id_map[0]):
      self.assertEqual(op_id, expected_op_id)

  @parameterized.named_parameters(
      dict(
          testcase_name="test_no_update",
          prev_trans_idx=0,
          instructions=[
              _dequant_inst(0, 0, [1]),
              _dequant_inst(0, 0, [1]),
          ],
          trans_info=qtyping.TransformationInfo(0, 0, 0),
          expected_instructions=[
              _dequant_inst(0, 0, [1]),
              _dequant_inst(0, 0, [1]),
          ],
          expected_added_op_id_map=[[]],
      ),
      dict(
          testcase_name="test_no_matching_consumer",
          prev_trans_idx=0,
          instructions=[
              _dequant_inst(0, 0, [1]),
              _dequant_inst(0, 0, [2]),
          ],
          trans_info=qtyping.TransformationInfo(2, 2, 13),
          expected_instructions=[
              _dequant_inst(0, 0, [1]),
              _dequant_inst(0, 0, [2]),
          ],
          expected_added_op_id_map=[[3]],
      ),
      dict(
          testcase_name="test_insert_one_op",
          prev_trans_idx=0,
          instructions=[
              _dequant_inst(0, 0, [1]),
              _quant_inst(0, 0, [1]),
          ],
          trans_info=qtyping.TransformationInfo(1, 1, 13),
          expected_instructions=[
              _dequant_inst(0, 0, [1]),
              _quant_inst(13, 6, [1]),
          ],
          expected_added_op_id_map=[[1]],
      ),
  )
  def test_update_instructions(
      self,
      prev_trans_idx,
      instructions,
      trans_info,
      expected_instructions,
      expected_added_op_id_map,
  ):
    """Pending instructions are rewired after a transformation is applied."""
    performer = self._transformation_performer
    performer._create_op_id_map(self._test_model)
    performer._update_instructions(prev_trans_idx, instructions, 0, trans_info)
    self.assertSequenceEqual(instructions, expected_instructions)
    self.assertListEqual(performer._added_op_id_map, expected_added_op_id_map)

  def test_transform_graph(self):
    """transform_graph applies the instructions of every tensor."""
    conv_tensor_name = (
        "sequential/conv2d/Relu;sequential/conv2d/BiasAdd;"
        + "sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp1"
    )
    pool_tensor_name = "sequential/average_pooling2d/AvgPool"
    instructions = {
        conv_tensor_name: qtyping.TensorTransformationInsts(
            tensor_name=conv_tensor_name,
            subgraph_id=0,
            instructions=[
                _dequant_inst(7, 0, [1]),
                _dequant_inst(7, 0, [1]),
            ],
        ),
        pool_tensor_name: qtyping.TensorTransformationInsts(
            tensor_name=pool_tensor_name,
            subgraph_id=0,
            instructions=[
                _dequant_inst(8, 1, [2]),
                _dequant_inst(8, 1, [2]),
            ],
        ),
    }
    self._transformation_performer.transform_graph(
        instructions, self._test_model
    )
    self.assertLen(self._test_model.subgraphs, 1)
    subgraph = self._test_model.subgraphs[0]
    operators = subgraph.operators
    self.assertLen(operators, 10)
    self.assertLen(subgraph.tensors, 17)
    last_op_code_index = len(self._test_model.operatorCodes) - 1
    self.assertEqual(operators[1].opcodeIndex, last_op_code_index)
    self.assertEqual(operators[2].inputs[0], 13)
    self.assertEqual(operators[2].outputs[0], 14)
    # The second inserted op feeds the original consumer of the tensor.
    self.assertEqual(operators[2].outputs[0], operators[3].inputs[0])
    self.assertEqual(operators[3].outputs[0], 8)
    self.assertEqual(operators[4].outputs[0], 15)
343
# Run the tests through googletest when this file is executed as a script.
if __name__ == "__main__":
  googletest.main()
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+