ai-edge-quantizer-nightly 0.0.1.dev20250115 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +149 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,284 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Tests for various transformations used by the quantization toolkit."""
+
+ import os
+ import numpy as np
+ from tensorflow.python.platform import googletest
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.transformations import quant_insert
+ from ai_edge_quantizer.transformations import transformation_utils
+ from ai_edge_quantizer.utils import test_utils
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+ from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+
+ TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")
+
+
+ class QuantInsertTest(googletest.TestCase):
+
+   def setUp(self):
+     super().setUp()
+     self._orig_test_model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
+     )
+     self._model = tfl_flatbuffer_utils.read_model(self._orig_test_model_path)
+
+   def test_quant_insert_constant(self):
+     """Test quant insert lib on a constant tensor."""
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     quant_opcode = schema_py_generated.BuiltinOperator.QUANTIZE
+     # insert quant on the constant before the add node
+     quant_insert.insert_quant(
+         transformation_utils.TransformationInput(
+             7,
+             model.operatorCodes,
+             model.buffers,
+             subgraph,
+             -1,
+             [4],
+             qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
+         )
+     )
+
+     # check quant op code is added to the model
+     self.assertEqual(
+         model.operatorCodes[0].builtinCode,
+         quant_opcode,
+     )
+
+     # check new tensor is correctly created
+     self.assertIn(b"_quantized", subgraph.tensors[9].name)
+     self.assertEqual(
+         subgraph.tensors[9].type, schema_py_generated.TensorType.INT8
+     )
+     self.assertEqual(
+         subgraph.tensors[7].type, schema_py_generated.TensorType.UINT8
+     )
+     # checking if consumer has the correct input
+     self.assertEqual(subgraph.operators[5].inputs[0], 6)
+     self.assertEqual(subgraph.operators[5].inputs[1], 9)
+
+     # checking the inserted node has the correct input/output
+     self.assertEqual(subgraph.operators[4].outputs[0], 9)
+     self.assertEqual(subgraph.operators[4].inputs[0], 7)
+     # checking inserted node is the quant node
+     self.assertEqual(subgraph.operators[4].opcodeIndex, 0)
+
+   def test_quant_insert_activation(self):
+     """Test quant insert lib on activation tensors."""
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     quant_opcode = schema_py_generated.BuiltinOperator.QUANTIZE
+     # insert quant on the output of a conv node
+     quant_insert.insert_quant(
+         transformation_utils.TransformationInput(
+             4,
+             model.operatorCodes,
+             model.buffers,
+             subgraph,
+             1,
+             [3],
+             qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
+         )
+     )
+
+     # check quant op code is added to the model
+     self.assertEqual(
+         model.operatorCodes[0].builtinCode,
+         quant_opcode,
+     )
+
+     # check new tensor is correctly created
+     self.assertIn(b"_quantized", subgraph.tensors[9].name)
+     self.assertEqual(
+         subgraph.tensors[9].type, schema_py_generated.TensorType.INT8
+     )
+     # check original source tensor is updated
+     self.assertEqual(
+         subgraph.tensors[4].type, schema_py_generated.TensorType.UINT8
+     )
+
+     # checking if consumer has the correct input
+     self.assertEqual(subgraph.operators[4].inputs[0], 9)
+     self.assertEqual(subgraph.operators[4].inputs[1], 5)
+
+     # checking the inserted node has the correct input/output
+     self.assertEqual(subgraph.operators[3].outputs[0], 9)
+     self.assertEqual(subgraph.operators[3].inputs[0], 4)
+     # checking inserted node is the quant node
+     self.assertEqual(subgraph.operators[3].opcodeIndex, 0)
+
+   def test_quant_insert_constant_multiple_consumers(self):
+     """Test quant insert lib on tensors with multiple consumers."""
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     quant_opcode = schema_py_generated.BuiltinOperator.QUANTIZE
+     # insert quant on the input of a conv node
+     post_trans_info = quant_insert.insert_quant(
+         transformation_utils.TransformationInput(
+             2,
+             model.operatorCodes,
+             model.buffers,
+             subgraph,
+             -1,
+             [1, 2],
+             qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
+         )
+     )
+     self.assertEqual(post_trans_info.op_id, 1)
+     self.assertEqual(post_trans_info.num_ops_added, 1)
+
+     # check quant op code is added to the model
+     self.assertEqual(
+         model.operatorCodes[0].builtinCode,
+         quant_opcode,
+     )
+
+     # check new tensor is correctly created
+     self.assertIn(b"_quantized", subgraph.tensors[9].name)
+     self.assertEqual(
+         subgraph.tensors[9].type, schema_py_generated.TensorType.INT8
+     )
+     # check original source tensor has the correct type
+     self.assertEqual(
+         subgraph.tensors[2].type, schema_py_generated.TensorType.UINT8
+     )
+
+     # checking the inserted node has the correct input/output
+     self.assertEqual(subgraph.operators[1].outputs[0], 9)
+     self.assertEqual(subgraph.operators[1].inputs[0], 2)
+     # checking inserted node is the quant node
+     self.assertEqual(subgraph.operators[1].opcodeIndex, 0)
+
+     # checking if consumers have the correct input
+     self.assertEqual(subgraph.operators[2].inputs[1], 9)
+     self.assertEqual(subgraph.operators[3].inputs[1], 9)
+
+   def test_quant_insert_activation_multiple_consumers(self):
+     """Test quant insert lib on tensors with multiple consumers."""
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     quant_opcode = schema_py_generated.BuiltinOperator.QUANTIZE
+     # insert quant on the output of a conv node
+     quant_insert.insert_quant(
+         transformation_utils.TransformationInput(
+             1,
+             model.operatorCodes,
+             model.buffers,
+             subgraph,
+             0,
+             [1, 2],
+             qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
+         )
+     )
+
+     # check quant op code is added to the model
+     self.assertEqual(
+         model.operatorCodes[0].builtinCode,
+         quant_opcode,
+     )
+
+     # check new tensor is correctly created
+     self.assertIn(b"_quantized", subgraph.tensors[9].name)
+     self.assertEqual(
+         subgraph.tensors[9].type, schema_py_generated.TensorType.INT8
+     )
+     # check original source tensor is updated
+     self.assertEqual(
+         subgraph.tensors[1].type, schema_py_generated.TensorType.UINT8
+     )
+
+     # checking the inserted node has the correct input/output
+     self.assertEqual(subgraph.operators[1].outputs[0], 9)
+     self.assertEqual(subgraph.operators[1].inputs[0], 1)
+     # checking inserted node is the quant node
+     self.assertEqual(subgraph.operators[1].opcodeIndex, 0)
+
+     # checking if consumers have the correct input
+     self.assertEqual(subgraph.operators[2].inputs[0], 9)
+     self.assertEqual(subgraph.operators[3].inputs[0], 9)
+
+   def test_quant_insert_activation_multiple_consumers_select(self):
+     """Test quant insert lib on tensors with multiple consumers but only insert for one of them."""
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     quant_opcode = schema_py_generated.BuiltinOperator.QUANTIZE
+     # insert quant on the output of a conv node
+     quant_insert.insert_quant(
+         transformation_utils.TransformationInput(
+             1,
+             model.operatorCodes,
+             model.buffers,
+             subgraph,
+             0,
+             [1],
+             qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
+         )
+     )
+
+     # check quant op code is added to the model
+     self.assertEqual(
+         model.operatorCodes[0].builtinCode,
+         quant_opcode,
+     )
+
+     # check new tensor is correctly created
+     self.assertIn(b"_quantized", subgraph.tensors[9].name)
+     self.assertEqual(
+         subgraph.tensors[9].type, schema_py_generated.TensorType.INT8
+     )
+     # check original source tensor is updated
+     self.assertEqual(
+         subgraph.tensors[1].type, schema_py_generated.TensorType.UINT8
+     )
+
+     # checking inserted node is the quant node
+     self.assertEqual(subgraph.operators[1].opcodeIndex, 0)
+
+     # checking if consumers have the correct input
+     self.assertEqual(subgraph.operators[2].inputs[0], 9)
+     self.assertEqual(subgraph.operators[3].inputs[0], 1)
+
+     # checking the inserted node has the correct input/output
+     self.assertEqual(subgraph.operators[1].outputs[0], 9)
+     self.assertEqual(subgraph.operators[1].inputs[0], 1)
+
+   def test_quant_insert_on_graph_output(self):
+     """Test quant insert lib on a graph output."""
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     # insert quant on the graph output
+     quant_insert.insert_quant(
+         transformation_utils.TransformationInput(
+             8,
+             model.operatorCodes,
+             model.buffers,
+             subgraph,
+             4,
+             [-1],
+             qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
+         )
+     )
+     # checking inserted node is the quant node
+     self.assertEqual(subgraph.operators[5].opcodeIndex, 0)
+     # check if the graph output is updated
+     self.assertEqual(subgraph.outputs[0], 9)
+
+
+ if __name__ == "__main__":
+   googletest.main()
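
The tests above exercise quant_insert.insert_quant through a positional
transformation_utils.TransformationInput: (tensor_id, op_codes, buffers,
subgraph, producer, consumers, quant_params), where a producer of -1 marks a
constant tensor and a consumer of -1 marks a graph output. A minimal usage
sketch of that pattern follows; the model path and the tensor/consumer
indices are illustrative placeholders, not part of the package.

    import numpy as np

    from ai_edge_quantizer import qtyping
    from ai_edge_quantizer.transformations import quant_insert
    from ai_edge_quantizer.transformations import transformation_utils
    from ai_edge_quantizer.utils import tfl_flatbuffer_utils

    # Load a TFLite flatbuffer model object (hypothetical path).
    model = tfl_flatbuffer_utils.read_model("path/to/model.tflite")
    subgraph = model.subgraphs[0]
    info = quant_insert.insert_quant(
        transformation_utils.TransformationInput(
            7,                    # tensor_id: tensor to quantize via an inserted op
            model.operatorCodes,  # op_codes: the QUANTIZE opcode is registered here
            model.buffers,        # buffers: backing storage for the new tensor
            subgraph,             # subgraph being rewritten
            -1,                   # producer: -1 means a constant (no producing op)
            [4],                  # consumers: operator indices that read the tensor
            qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
        )
    )
    # Where the new op landed and how many ops were added (1 for an insert).
    print(info.op_id, info.num_ops_added)
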
@@ -0,0 +1,149 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Quantize a given tensor."""
+
+ from typing import cast
+ import numpy as np
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.transformations import transformation_utils
+ from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+
+
+ # TODO: b/335014051 - Support distinguishing INT, FLOAT & UINT, BFLOAT.
+ def quant_params_to_tflite_type(
+     bitwidth: int,
+ ) -> schema_py_generated.TensorType:
+   """Returns the TFLite tensor type corresponding to the given bit width.
+
+   Args:
+     bitwidth: Bit width from UniformQuantParams.
+
+   Returns:
+     The corresponding TFLite tensor type.
+   """
+   if bitwidth == 4:
+     return schema_py_generated.TensorType.INT4
+   elif bitwidth <= 8:
+     return schema_py_generated.TensorType.INT8
+   elif bitwidth <= 16:
+     return schema_py_generated.TensorType.INT16
+   elif bitwidth <= 32:
+     return schema_py_generated.TensorType.INT32
+   elif bitwidth <= 64:
+     return schema_py_generated.TensorType.INT64
+   else:
+     raise ValueError(f"Unsupported quant params bit width: {bitwidth}")
+
+
+ def nonlinear_quant_params_to_tflite_type(
+     bitwidth: int,
+ ) -> schema_py_generated.TensorType:
+   """Returns the TFLite tensor type for the given nonlinear quant bit width.
+
+   Args:
+     bitwidth: Bit width from NonLinearQuantParams.
+
+   Returns:
+     The corresponding TFLite tensor type.
+   """
+   if bitwidth == 16:
+     return schema_py_generated.TensorType.FLOAT16
+   elif bitwidth == 32:
+     return schema_py_generated.TensorType.FLOAT32
+   else:
+     raise ValueError(f"Unsupported nonlinear params bit width: {bitwidth}")
+
+
+ def _pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
+   """Packs the data to the corresponding bit width.
+
+   Currently only 4 bits is supported. If no packing is needed, the original
+   data is returned.
+
+   Args:
+     bitwidth: Bit width from UniformQuantParams.
+     flattened_data: The data to be packed.
+
+   Returns:
+     Packed data.
+   """
+   if bitwidth == 4:
+     even_data = flattened_data[::2] & 0x0F
+     odd_data = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
+     if odd_data.shape[0] == even_data.shape[0] - 1:
+       odd_data = np.pad(odd_data, (0, 1), constant_values=0)
+     return np.bitwise_or(even_data, odd_data)
+   else:
+     return flattened_data
+
+
+ def quantize_tensor(
+     transformation_input: transformation_utils.TransformationInput,
+ ) -> qtyping.TransformationInfo:
+   """Quantizes the tensor at tensor_id in the given subgraph.
+
+   Args:
+     transformation_input: Input structure that contains all information needed
+       for the transformation.
+
+   Returns:
+     TransformationInfo:
+       op_id: the producer index for the tensor.
+       num_ops_added: the total number of ops inserted by this operation, which
+         is 0.
+   """
+   tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]
+   # TODO: b/336385820 - Support quantizing the buffer directly when
+   # quantized_data is not provided.
+   if tensor.buffer:
+     if transformation_input.quant_params.quantized_data is not None:
+       transformation_input.buffers[tensor.buffer].data = _pack_data(
+           transformation_input.quant_params.num_bits,
+           np.frombuffer(
+               cast(
+                   np.ndarray, transformation_input.quant_params.quantized_data
+               ).tobytes(),
+               dtype=np.uint8,
+           ).flatten(),
+       )
+
+   if isinstance(transformation_input.quant_params, qtyping.UniformQuantParams):
+     flatbuffer_quantization = schema_py_generated.QuantizationParametersT()
+     flatbuffer_quantization.scale = list(
+         transformation_input.quant_params.scale.flatten().astype(np.float32)
+     )  # flatbuffer requires scale as list[float]
+     flatbuffer_quantization.zeroPoint = list(
+         transformation_input.quant_params.zero_point.flatten().astype(np.int64)
+     )  # flatbuffer requires zeroPoint as list[int64]
+     if transformation_input.quant_params.quantized_dimension is not None:
+       flatbuffer_quantization.quantizedDimension = (
+           transformation_input.quant_params.quantized_dimension
+       )
+     tensor.quantization = flatbuffer_quantization
+     tensor.type = quant_params_to_tflite_type(
+         transformation_input.quant_params.num_bits
+     )
+
+   if isinstance(
+       transformation_input.quant_params, qtyping.NonLinearQuantParams
+   ):
+     tensor.type = nonlinear_quant_params_to_tflite_type(
+         transformation_input.quant_params.num_bits
+     )
+
+   return qtyping.TransformationInfo(
+       0, num_ops_added=0, output_tensor_id=transformation_input.tensor_id
+   )
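
The int4 branch of _pack_data above packs two 4-bit values per byte, low
nibble first, zero-padding the final high nibble when the element count is
odd. A standalone sketch of that packing arithmetic (mirroring the function
body on a uint8 view, which is how quantize_tensor feeds it):

    import numpy as np

    def pack_int4(flattened_data: np.ndarray) -> np.ndarray:
      """Packs pairs of 4-bit values into bytes, low nibble first."""
      even_data = flattened_data[::2] & 0x0F  # low nibbles
      odd_data = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)  # high nibbles
      if odd_data.shape[0] == even_data.shape[0] - 1:  # odd element count
        odd_data = np.pad(odd_data, (0, 1), constant_values=0)
      return np.bitwise_or(even_data, odd_data)

    values = np.arange(15, dtype=np.uint8)  # 0x0 .. 0xE, odd count
    print([hex(v) for v in pack_int4(values)])
    # ['0x10', '0x32', '0x54', '0x76', '0x98', '0xba', '0xdc', '0xe']

The printed bytes match the `expected` array in test_int4_constant_packed_correctly
below.
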
@@ -0,0 +1,227 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Tests for quantize_tensor."""
+
+ import os
+ import numpy as np
+ from tensorflow.python.platform import googletest
+ from absl.testing import parameterized
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.transformations import quantize_tensor
+ from ai_edge_quantizer.transformations import transformation_utils
+ from ai_edge_quantizer.utils import test_utils
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+ from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+
+ TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")
+
+
+ class QuantizeTensorTest(parameterized.TestCase):
+
+   def setUp(self):
+     super().setUp()
+     self._orig_test_model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
+     )
+     self._model = tfl_flatbuffer_utils.read_model(self._orig_test_model_path)
+
+   def test_quantize_constant_tensor(self):
+     """Test quantizing a constant tensor."""
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     data = np.ones([1, 112, 112, 3], dtype=np.int8)
+     ret = quantize_tensor.quantize_tensor(
+         transformation_utils.TransformationInput(
+             7,
+             model.operatorCodes,
+             model.buffers,
+             subgraph,
+             -1,
+             [4],
+             qtyping.UniformQuantParams(
+                 8, None, np.ones(1), np.ones(1), True, data
+             ),
+         )
+     )
+     self.assertEqual(ret.op_id, 0)
+     self.assertEqual(ret.num_ops_added, 0)
+     self.assertListEqual(
+         np.array(model.buffers[8].data).tolist(), data.flatten().tolist()
+     )
+     quant_param = subgraph.tensors[7].quantization
+     self.assertListEqual(np.array(quant_param.scale).tolist(), [1])
+     self.assertEqual(np.array(quant_param.zeroPoint).tolist(), [1])
+     self.assertEqual(quant_param.quantizedDimension, 0)
+
+   def test_quantize_activation_tensor(self):
+     """Test quantizing an activation tensor."""
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     ret = quantize_tensor.quantize_tensor(
+         transformation_utils.TransformationInput(
+             4,
+             model.operatorCodes,
+             model.buffers,
+             subgraph,
+             1,
+             [3],
+             qtyping.UniformQuantParams(
+                 8, None, np.array([22]), np.array([127])
+             ),
+         )
+     )
+     self.assertEqual(ret.op_id, 0)
+     self.assertEqual(ret.num_ops_added, 0)
+     quant_param = subgraph.tensors[4].quantization
+     self.assertListEqual(np.array(quant_param.scale).tolist(), [22])
+     self.assertListEqual(np.array(quant_param.zeroPoint).tolist(), [127])
+     self.assertEqual(quant_param.quantizedDimension, 0)
+
+   def test_quantize_tensor_with_per_channel_quantization(self):
+     """Test quantizing a tensor with per-channel quantization."""
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     ret = quantize_tensor.quantize_tensor(
+         transformation_utils.TransformationInput(
+             4,
+             model.operatorCodes,
+             model.buffers,
+             subgraph,
+             1,
+             [3],
+             qtyping.UniformQuantParams(8, 3, np.ones([22]), np.zeros([22])),
+         )
+     )
+     self.assertEqual(ret.op_id, 0)
+     self.assertEqual(ret.num_ops_added, 0)
+     quant_param = subgraph.tensors[4].quantization
+     self.assertListEqual(
+         np.array(quant_param.scale).tolist(), np.ones([22]).tolist()
+     )
+     self.assertListEqual(
+         np.array(quant_param.zeroPoint).tolist(), np.zeros([22]).tolist()
+     )
+     self.assertEqual(quant_param.quantizedDimension, 3)
+
+   def test_quantize_tensor_with_nonlinear_quantization(self):
+     """Test quantizing an activation tensor with non-linear quantization."""
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     quantize_tensor.quantize_tensor(
+         transformation_utils.TransformationInput(
+             4,
+             model.operatorCodes,
+             model.buffers,
+             subgraph,
+             1,
+             [3],
+             qtyping.NonLinearQuantParams(16, None),
+         )
+     )
+     self.assertEqual(
+         subgraph.tensors[4].type, schema_py_generated.TensorType.FLOAT16
+     )
+
+   def test_int4_constant_packed_correctly(self):
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     data = np.array(
+         [
+             0x0,
+             0x1,
+             0x2,
+             0x3,
+             0x4,
+             0x5,
+             0x6,
+             0x7,
+             0x8,
+             0x9,
+             0xA,
+             0xB,
+             0xC,
+             0xD,
+             0xE,
+         ],
+         dtype=np.int8,
+     )
+     expected = np.array([0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0x0E])
+     ret = quantize_tensor.quantize_tensor(
+         transformation_utils.TransformationInput(
+             tensor_id=7,
+             op_codes=model.operatorCodes,
+             buffers=model.buffers,
+             subgraph=subgraph,
+             producer=-1,
+             consumers=[4],
+             quant_params=qtyping.UniformQuantParams(
+                 4, None, np.ones(1), np.ones(1), True, data
+             ),
+         )
+     )
+     self.assertEqual(ret.op_id, 0)
+     self.assertEqual(ret.num_ops_added, 0)
+     np.testing.assert_array_equal(model.buffers[8].data, expected)
+     quant_param = subgraph.tensors[7].quantization
+     np.testing.assert_array_equal(quant_param.scale, [1])
+     np.testing.assert_array_equal(quant_param.zeroPoint, [1])
+     self.assertEqual(quant_param.quantizedDimension, 0)
+
+   @parameterized.named_parameters(
+       dict(
+           testcase_name="int5",
+           num_bits=5,
+       ),
+       dict(
+           testcase_name="int2",
+           num_bits=2,
+       ),
+   )
+   def test_int_constant_not_packed(self, num_bits):
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     tensor_id = 7
+     data = np.array([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7], dtype=np.int8)
+     expected = np.array([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7])
+     ret = quantize_tensor.quantize_tensor(
+         transformation_utils.TransformationInput(
+             tensor_id=tensor_id,
+             op_codes=model.operatorCodes,
+             buffers=model.buffers,
+             subgraph=subgraph,
+             producer=-1,
+             consumers=[4],
+             quant_params=qtyping.UniformQuantParams(
+                 num_bits=num_bits,
+                 quantized_dimension=None,
+                 scale=np.ones(1),
+                 zero_point=np.ones(1),
+                 symmetric=True,
+                 quantized_data=data,
+             ),
+         )
+     )
+     self.assertEqual(ret.op_id, 0)
+     self.assertEqual(ret.num_ops_added, 0)
+     np.testing.assert_array_equal(model.buffers[8].data, expected)
+     quant_param = subgraph.tensors[tensor_id].quantization
+     np.testing.assert_array_equal(quant_param.scale, [1])
+     np.testing.assert_array_equal(quant_param.zeroPoint, [1])
+     self.assertEqual(quant_param.quantizedDimension, 0)
+
+
+ if __name__ == "__main__":
+   googletest.main()
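
Read together, quantize_tensor.py and its tests pin down the dtype mapping:
only 4-bit data is packed into INT4, other widths round up to the next
supported signed integer type (so int2 and int5 constants are stored
unpacked as INT8), and NonLinearQuantParams widths map to float types. A
small sanity sketch of that mapping, assuming the wheel and ai-edge-litert
are importable:

    from ai_edge_litert import schema_py_generated
    from ai_edge_quantizer.transformations import quantize_tensor

    tensor_type = schema_py_generated.TensorType
    assert quantize_tensor.quant_params_to_tflite_type(4) == tensor_type.INT4
    assert quantize_tensor.quant_params_to_tflite_type(2) == tensor_type.INT8   # unpacked
    assert quantize_tensor.quant_params_to_tflite_type(5) == tensor_type.INT8   # unpacked
    assert quantize_tensor.quant_params_to_tflite_type(16) == tensor_type.INT16
    assert quantize_tensor.quant_params_to_tflite_type(64) == tensor_type.INT64
    assert quantize_tensor.nonlinear_quant_params_to_tflite_type(16) == tensor_type.FLOAT16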