ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,87 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Apply dequantization transformations to the given op/tensor.
17
+
18
+ Inserts dequantize node after the given tensor to enable float execution of
19
+ the tensor consumer
20
+ """
21
+
22
+ from ai_edge_quantizer import qtyping
23
+ from ai_edge_quantizer.transformations import quantize_tensor
24
+ from ai_edge_quantizer.transformations import transformation_utils
25
+ from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
26
+
27
+
28
def insert_dequant(
    transformation_input: transformation_utils.TransformationInput,
) -> qtyping.TransformationInfo:
  """Insert dequant op after the given tensor in the subgraph.

  The source tensor is quantized (via `quantize_tensor.quantize_tensor`), a new
  FLOAT32 tensor is created to hold the dequantized values, and every listed
  consumer is rewired to read that new tensor. If the source tensor is also a
  subgraph output, the subgraph output is redirected to the new tensor as well.

  Args:
    transformation_input: input structure that contains all information needed
      for the transformation. A consumer id of -1 denotes the graph output; it
      is not an operator index and is never used to look up an operator.

  Returns:
    TransformationInfo:
      op_id: the index where the dequant op is added
      num_ops_added: the total number of ops inserted by this operation, which
        is 1
      output_tensor_id: the id of the new float tensor produced by the dequant
        op

  Raises:
    ValueError: if `transformation_input.consumers` is empty. The check runs
      before the model is modified, so the model is left untouched.
  """
  if not transformation_input.consumers:
    raise ValueError(
        'insert_dequant requires at least one consumer for tensor'
        f' {transformation_input.tensor_id}.'
    )

  dequant_op_code_idx = transformation_utils.add_op_code(
      schema_py_generated.BuiltinOperator.DEQUANTIZE,
      transformation_input.op_codes,
  )
  # create output tensor for the dequant op
  tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]
  new_tensor_id = transformation_utils.add_new_activation_tensor(
      tensor.name + b'_dequant',
      tensor.shape,
      schema_py_generated.TensorType.FLOAT32,
      transformation_input.subgraph,
  )

  # create dequantize_op
  dequant_op = schema_py_generated.OperatorT()
  dequant_op.opcodeIndex = dequant_op_code_idx
  dequant_op.outputs = [new_tensor_id]
  dequant_op.inputs = [transformation_input.tensor_id]

  # quantize the source tensor
  quantize_tensor.quantize_tensor(transformation_input)

  # Rewire the original consumers to read the dequantized tensor. Negative ids
  # (-1 is the graph-output sentinel) are not operator indices: indexing
  # operators[-1] would silently rewrite the last op in the subgraph.
  for consumer_id in transformation_input.consumers:
    if consumer_id < 0:
      continue
    op = transformation_input.subgraph.operators[consumer_id]
    for input_idx, input_tensor_id in enumerate(op.inputs):
      if input_tensor_id == transformation_input.tensor_id:
        op.inputs[input_idx] = new_tensor_id

  # if the output is also an output to the graph, we need to update that as well
  for output_idx, output in enumerate(transformation_input.subgraph.outputs):
    if output == transformation_input.tensor_id:
      transformation_input.subgraph.outputs[output_idx] = new_tensor_id

  # Add the dequant op to the subgraph op list. It must be inserted right
  # before its first consumer. When a consumer is the graph output (-1), min()
  # yields -1, so the producer bound below places the op right after the
  # producer instead.
  first_consumer_id = min(transformation_input.consumers)
  op_id = max(transformation_input.producer + 1, first_consumer_id)
  transformation_input.subgraph.operators.insert(op_id, dequant_op)
  return qtyping.TransformationInfo(
      op_id=op_id, num_ops_added=1, output_tensor_id=new_tensor_id
  )
@@ -0,0 +1,304 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Test for various transformations used by quantizer."""
17
+
18
+ import os
19
+ import numpy as np
20
+ from tensorflow.python.platform import googletest
21
+ from ai_edge_quantizer import qtyping
22
+ from ai_edge_quantizer.transformations import dequant_insert
23
+ from ai_edge_quantizer.transformations import transformation_utils
24
+ from ai_edge_quantizer.utils import test_utils
25
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
26
+ from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
27
+
28
# Base path for test data. setUp joins model paths such as
# "tests/models/insert_dequant_test.tflite" onto it. The ".." is resolved by
# test_utils.get_path_to_datafile (presumably relative to this test module's
# directory — confirm in test_utils).
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")
29
+
30
+
31
class DequantInsertTest(googletest.TestCase):
  """Tests for `dequant_insert.insert_dequant`.

  Each test reloads insert_dequant_test.tflite in setUp, so tests are
  independent. In every scenario the new dequantized tensor is expected at
  index 9 of subgraph 0.
  """

  def setUp(self):
    super().setUp()
    self._orig_test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
    )
    self._model = tfl_flatbuffer_utils.read_model(self._orig_test_model_path)

  def _insert_dequant(self, tensor_id, producer, consumers):
    """Runs insert_dequant on subgraph 0 with the shared 8-bit params."""
    return dequant_insert.insert_dequant(
        transformation_utils.TransformationInput(
            tensor_id,
            self._model.operatorCodes,
            self._model.buffers,
            self._model.subgraphs[0],
            producer,
            consumers,
            qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
        )
    )

  def _assert_dequant_op_code_appended(self):
    """Checks that DEQUANTIZE is the most recently added op code."""
    self.assertEqual(
        self._model.operatorCodes[-1].builtinCode,
        schema_py_generated.BuiltinOperator.DEQUANTIZE,
    )

  def _assert_dequant_tensor(self, subgraph, tensor_id):
    """Checks `tensor_id` is the new FLOAT32 tensor named *_dequant."""
    self.assertIn(b"_dequant", subgraph.tensors[tensor_id].name)
    self.assertEqual(
        subgraph.tensors[tensor_id].type,
        schema_py_generated.TensorType.FLOAT32,
    )

  def _assert_int8(self, subgraph, tensor_id):
    """Checks the source tensor was quantized to INT8."""
    self.assertEqual(
        subgraph.tensors[tensor_id].type, schema_py_generated.TensorType.INT8
    )

  def _assert_dequant_op(self, subgraph, op_index, input_id, output_id):
    """Checks the op at `op_index` is the inserted dequant op and its I/O."""
    op = subgraph.operators[op_index]
    self.assertEqual(op.outputs[0], output_id)
    self.assertEqual(op.inputs[0], input_id)
    self.assertEqual(op.opcodeIndex, len(self._model.operatorCodes) - 1)

  def test_dequant_insert_constant(self):
    """Test dequant insert lib on a constant tensor."""
    subgraph = self._model.subgraphs[0]
    # insert dequant on the constant before the add node
    self._insert_dequant(7, -1, [4])

    self._assert_dequant_op_code_appended()
    self._assert_dequant_tensor(subgraph, 9)
    self._assert_int8(subgraph, 7)
    # check the consumer (shifted from index 4 to 5) has the correct inputs
    self.assertEqual(subgraph.operators[5].inputs[0], 6)
    self.assertEqual(subgraph.operators[5].inputs[1], 9)
    # check the inserted node is the dequant op with the correct input/output
    self._assert_dequant_op(subgraph, op_index=4, input_id=7, output_id=9)

  def test_dequant_insert_activation(self):
    """Test dequant insert lib on activation tensors."""
    subgraph = self._model.subgraphs[0]
    # insert dequant on the output of a conv node
    self._insert_dequant(4, 1, [3])

    self._assert_dequant_op_code_appended()
    self._assert_dequant_tensor(subgraph, 9)
    # check original source tensor is updated
    self._assert_int8(subgraph, 4)
    # check the consumer (shifted from index 3 to 4) has the correct inputs
    self.assertEqual(subgraph.operators[4].inputs[0], 9)
    self.assertEqual(subgraph.operators[4].inputs[1], 5)
    # check the inserted node is the dequant op with the correct input/output
    self._assert_dequant_op(subgraph, op_index=3, input_id=4, output_id=9)

  def test_dequant_insert_constant_multiple_consumers(self):
    """Test dequant insert lib on a constant tensor with multiple consumers."""
    subgraph = self._model.subgraphs[0]
    # insert dequant on the input of a conv node
    post_trans_info = self._insert_dequant(2, -1, [1, 2])
    self.assertEqual(post_trans_info.op_id, 1)
    self.assertEqual(post_trans_info.num_ops_added, 1)

    self._assert_dequant_op_code_appended()
    self._assert_dequant_tensor(subgraph, 9)
    # check original source tensor has the correct type
    self._assert_int8(subgraph, 2)
    # check the inserted node is the dequant op with the correct input/output
    self._assert_dequant_op(subgraph, op_index=1, input_id=2, output_id=9)
    # check both consumers now read the dequantized tensor
    self.assertEqual(subgraph.operators[2].inputs[1], 9)
    self.assertEqual(subgraph.operators[3].inputs[1], 9)

  def test_dequant_insert_activation_multiple_consumers(self):
    """Test dequant insert lib on an activation with multiple consumers."""
    subgraph = self._model.subgraphs[0]
    # insert dequant on the output of a conv node
    self._insert_dequant(1, 0, [1, 2])

    self._assert_dequant_op_code_appended()
    self._assert_dequant_tensor(subgraph, 9)
    # check original source tensor is updated
    self._assert_int8(subgraph, 1)
    # check the inserted node is the dequant op with the correct input/output
    self._assert_dequant_op(subgraph, op_index=1, input_id=1, output_id=9)
    # check both consumers now read the dequantized tensor
    self.assertEqual(subgraph.operators[2].inputs[0], 9)
    self.assertEqual(subgraph.operators[3].inputs[0], 9)

  def test_dequant_insert_activation_multiple_consumers_select(self):
    """Test dequant insert on a multi-consumer tensor, rewiring only one."""
    subgraph = self._model.subgraphs[0]
    # insert dequant on the output of a conv node, for the first consumer only
    self._insert_dequant(1, 0, [1])

    self._assert_dequant_op_code_appended()
    self._assert_dequant_tensor(subgraph, 9)
    # check original source tensor is updated
    self._assert_int8(subgraph, 1)
    # check the inserted node is the dequant op with the correct input/output
    self._assert_dequant_op(subgraph, op_index=1, input_id=1, output_id=9)
    # only the selected consumer is rewired; the other still reads tensor 1
    self.assertEqual(subgraph.operators[2].inputs[0], 9)
    self.assertEqual(subgraph.operators[3].inputs[0], 1)

  def test_dequant_insert_on_graph_output(self):
    """Test dequant insert lib on graph output."""
    subgraph = self._model.subgraphs[0]
    # insert dequant on the graph output; -1 is the graph-output consumer
    self._insert_dequant(8, 4, [-1])

    self._assert_dequant_op_code_appended()
    self._assert_dequant_tensor(subgraph, 9)
    self._assert_int8(subgraph, 8)
    # the dequant op goes right after the producer (op 4)
    self._assert_dequant_op(subgraph, op_index=5, input_id=8, output_id=9)
    # the producer's inputs must not be rewired by the -1 sentinel
    self.assertEqual(subgraph.operators[4].inputs[0], 6)
    self.assertEqual(subgraph.operators[4].inputs[1], 7)
    # check if the graph output is updated
    self.assertEqual(subgraph.outputs[0], 9)
301
+
302
+
303
if __name__ == "__main__":
  # Run the test cases above via TensorFlow's googletest runner (imported from
  # tensorflow.python.platform at the top of this file).
  googletest.main()