ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,284 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Test for various transformations used by quantization toolkit."""
|
17
|
+
|
18
|
+
import os
|
19
|
+
import numpy as np
|
20
|
+
from tensorflow.python.platform import googletest
|
21
|
+
from ai_edge_quantizer import qtyping
|
22
|
+
from ai_edge_quantizer.transformations import quant_insert
|
23
|
+
from ai_edge_quantizer.transformations import transformation_utils
|
24
|
+
from ai_edge_quantizer.utils import test_utils
|
25
|
+
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
26
|
+
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
|
27
|
+
|
28
|
+
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")
|
29
|
+
|
30
|
+
|
31
|
+
class QuantInsertTest(googletest.TestCase):
  """Tests for inserting QUANTIZE ops via quant_insert.insert_quant.

  Every test loads the same fixture model and checks, by hard-coded index,
  the tensors and operators of subgraph 0 after the insertion.
  """

  def setUp(self):
    super().setUp()
    self._orig_test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
    )
    self._model = tfl_flatbuffer_utils.read_model(self._orig_test_model_path)

  def _insert_quant(self, tensor_id, producer, consumers):
    """Runs insert_quant on subgraph 0 with fixed 8-bit quant params.

    Args:
      tensor_id: Index of the tensor to insert the quantize op after.
      producer: Producer op index passed through to TransformationInput.
      consumers: Consumer op indices passed through to TransformationInput.

    Returns:
      Whatever quant_insert.insert_quant returns.
    """
    model = self._model
    return quant_insert.insert_quant(
        transformation_utils.TransformationInput(
            tensor_id,
            model.operatorCodes,
            model.buffers,
            model.subgraphs[0],
            producer,
            consumers,
            qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
        )
    )

  def _assert_quantized_tensor_created(self, source_tensor_id):
    """Checks the op code, the new tensor and the source tensor type.

    Args:
      source_tensor_id: Index of the tensor the quantize op was inserted on.
    """
    model = self._model
    tensor_types = schema_py_generated.TensorType
    # The quant op code must have been added to the model.
    self.assertEqual(
        model.operatorCodes[0].builtinCode,
        schema_py_generated.BuiltinOperator.QUANTIZE,
    )
    # The new tensor must be correctly created at index 9.
    created = model.subgraphs[0].tensors[9]
    self.assertIn(b"_quantized", created.name)
    self.assertEqual(created.type, tensor_types.INT8)
    # The original source tensor must have the expected type.
    self.assertEqual(
        model.subgraphs[0].tensors[source_tensor_id].type, tensor_types.UINT8
    )

  def test_quant_insert_constant(self):
    """Test quant insert lib on a constant tensor."""
    # Insert quant on the constant before the add node.
    self._insert_quant(7, -1, [4])
    self._assert_quantized_tensor_created(7)

    operators = self._model.subgraphs[0].operators
    # The consumer must read the new tensor as its second input.
    consumer = operators[5]
    self.assertEqual(consumer.inputs[0], 6)
    self.assertEqual(consumer.inputs[1], 9)

    # The inserted node must have the correct input/output ...
    inserted = operators[4]
    self.assertEqual(inserted.outputs[0], 9)
    self.assertEqual(inserted.inputs[0], 7)
    # ... and must be the quant node.
    self.assertEqual(inserted.opcodeIndex, 0)

  def test_quant_insert_activation(self):
    """Test quant insert lib on activation tensors."""
    # Insert quant on the output of a conv node.
    self._insert_quant(4, 1, [3])
    self._assert_quantized_tensor_created(4)

    operators = self._model.subgraphs[0].operators
    # The consumer must read the new tensor as its first input.
    consumer = operators[4]
    self.assertEqual(consumer.inputs[0], 9)
    self.assertEqual(consumer.inputs[1], 5)

    # The inserted node must have the correct input/output ...
    inserted = operators[3]
    self.assertEqual(inserted.outputs[0], 9)
    self.assertEqual(inserted.inputs[0], 4)
    # ... and must be the quant node.
    self.assertEqual(inserted.opcodeIndex, 0)

  def test_quant_insert_constant_multiple_consumers(self):
    """Test quant insert lib on a constant tensor with multiple consumers."""
    # Insert quant on the input of a conv node.
    post_trans_info = self._insert_quant(2, -1, [1, 2])
    self.assertEqual(post_trans_info.op_id, 1)
    self.assertEqual(post_trans_info.num_ops_added, 1)
    self._assert_quantized_tensor_created(2)

    operators = self._model.subgraphs[0].operators
    # The inserted node must have the correct input/output ...
    inserted = operators[1]
    self.assertEqual(inserted.outputs[0], 9)
    self.assertEqual(inserted.inputs[0], 2)
    # ... and must be the quant node.
    self.assertEqual(inserted.opcodeIndex, 0)

    # Both consumers must now read the new tensor.
    self.assertEqual(operators[2].inputs[1], 9)
    self.assertEqual(operators[3].inputs[1], 9)

  def test_quant_insert_activation_multiple_consumers(self):
    """Test quant insert lib on an activation with multiple consumers."""
    # Insert quant on the output of a conv node.
    self._insert_quant(1, 0, [1, 2])
    self._assert_quantized_tensor_created(1)

    operators = self._model.subgraphs[0].operators
    # The inserted node must have the correct input/output ...
    inserted = operators[1]
    self.assertEqual(inserted.outputs[0], 9)
    self.assertEqual(inserted.inputs[0], 1)
    # ... and must be the quant node.
    self.assertEqual(inserted.opcodeIndex, 0)

    # Both consumers must now read the new tensor.
    self.assertEqual(operators[2].inputs[0], 9)
    self.assertEqual(operators[3].inputs[0], 9)

  def test_quant_insert_activation_multiple_consumers_select(self):
    """Test quant insert lib on tensors with multiple consumers but only insert for one of them."""
    # Insert quant on the output of a conv node, for one consumer only.
    self._insert_quant(1, 0, [1])
    self._assert_quantized_tensor_created(1)

    operators = self._model.subgraphs[0].operators
    # The inserted node must be the quant node.
    self.assertEqual(operators[1].opcodeIndex, 0)

    # Only the selected consumer is rewired; the other keeps the original
    # tensor as its input.
    self.assertEqual(operators[2].inputs[0], 9)
    self.assertEqual(operators[3].inputs[0], 1)

    # The inserted node must have the correct input/output.
    self.assertEqual(operators[1].outputs[0], 9)
    self.assertEqual(operators[1].inputs[0], 1)

  def test_dequant_insert_on_graph_output(self):
    """Test quant insert lib on a graph output tensor."""
    # Insert quant on the graph output.
    self._insert_quant(8, 4, [-1])

    graph = self._model.subgraphs[0]
    # The inserted node must be the quant node.
    self.assertEqual(graph.operators[5].opcodeIndex, 0)
    # The graph output must now point at the new tensor.
    self.assertEqual(graph.outputs[0], 9)
|
281
|
+
|
282
|
+
|
283
|
+
if __name__ == "__main__":
|
284
|
+
googletest.main()
|
@@ -0,0 +1,156 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""quantize a given tensor."""
|
17
|
+
|
18
|
+
from typing import Optional, cast
|
19
|
+
import numpy as np
|
20
|
+
from ai_edge_quantizer import qtyping
|
21
|
+
from ai_edge_quantizer.transformations import transformation_utils
|
22
|
+
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
|
23
|
+
|
24
|
+
|
25
|
+
# TODO: b/335014051 - Support distinguishing INT, FLOAT & UINT, BFLOAT.
|
26
|
+
def quant_params_to_tflite_type(
    bitwidth: int,
) -> Optional[schema_py_generated.TensorType]:
  """Map a uniform-quantization bit width to a TFLite integer tensor type.

  A width of exactly 4 maps to INT4. Any other width maps to the narrowest of
  INT8, INT16, INT32 and INT64 whose size is at least `bitwidth`.

  Args:
    bitwidth: Bit width from UniformQuantParams.

  Returns:
    The corresponding TFLite tensor type.

  Raises:
    ValueError: If `bitwidth` is greater than 64.
  """
  tensor_type = schema_py_generated.TensorType
  # INT4 is selected only on an exact match; every other width (including
  # widths below 4) falls through to the upper-bound table below.
  if bitwidth == 4:
    return tensor_type.INT4

  # Ordered from narrowest to widest so the first match is the smallest fit.
  upper_bounds = (
      (8, tensor_type.INT8),
      (16, tensor_type.INT16),
      (32, tensor_type.INT32),
      (64, tensor_type.INT64),
  )
  for max_bits, tflite_type in upper_bounds:
    if bitwidth <= max_bits:
      return tflite_type
  raise ValueError(f"Unsupported quant params: {bitwidth}")
|
49
|
+
|
50
|
+
|
51
|
+
def nonlinear_quant_params_to_tflite_type(
    bitwidth: int,
) -> Optional[schema_py_generated.TensorType]:
  """Map a non-linear quantization bit width to a TFLite float tensor type.

  Args:
    bitwidth: Bit width from NonLinearQuantParams.

  Returns:
    FLOAT16 for 16 bits, FLOAT32 for 32 bits.

  Raises:
    ValueError: If `bitwidth` is neither 16 nor 32.
  """
  tensor_type = schema_py_generated.TensorType
  if bitwidth == 16:
    return tensor_type.FLOAT16
  if bitwidth == 32:
    return tensor_type.FLOAT32
  raise ValueError(f"Unsupported nonlinear params: {bitwidth}")
|
68
|
+
|
69
|
+
|
70
|
+
def _pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
  """Pack flattened data according to the given bit width.

  Only 4-bit packing is implemented: two consecutive elements are combined
  into one value, the even-indexed element in the low nibble and the
  odd-indexed element in the high nibble. For any other bit width the input
  array is returned as-is.

  Args:
    bitwidth: Bit width of the quantized values.
    flattened_data: The 1-D data to be packed.

  Returns:
    The packed data for 4 bits, otherwise `flattened_data` itself.
  """
  if bitwidth != 4:
    return flattened_data

  low_nibbles = flattened_data[::2] & 0x0F
  high_nibbles = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
  # An odd number of elements leaves the last low nibble without a partner;
  # pad the high nibbles with a zero so both halves line up.
  if low_nibbles.shape[0] - 1 == high_nibbles.shape[0]:
    high_nibbles = np.pad(high_nibbles, (0, 1), constant_values=0)
  return np.bitwise_or(low_nibbles, high_nibbles)
|
91
|
+
|
92
|
+
|
93
|
+
def quantize_tensor(
    transformation_input: transformation_utils.TransformationInput,
) -> qtyping.TransformationInfo:
  """Quantize the tensor at the tensor_id in the given subgraph.

  The tensor is modified in place:
    * If the tensor has a truthy buffer index and the quant params carry
      quantized data, the buffer content is replaced by the bytes of that
      data (packed two values per byte when num_bits is 4).
    * For UniformQuantParams, the scale, zero point and (when given) the
      quantized dimension are written to the tensor, and its type becomes
      the integer type matching num_bits.
    * For NonLinearQuantParams, only the tensor type is updated to the float
      type matching num_bits.

  Args:
    transformation_input: input structure that contains all information needed
      for the transformation.

  Returns:
    TransformationInfo:
      op_id: always 0 (no op is associated with this transformation)
      num_ops_added: the total number of ops inserted by this operation, which
        is 0
      output_tensor_id: the id of the tensor that was quantized
  """
  quant_params = transformation_input.quant_params
  tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]

  # TODO: b/336385820 - support quantizing the buffer directly when
  # quantized_data is not provided.
  # NOTE(review): a buffer index of 0 is skipped here; presumably that is the
  # TFLite empty-buffer sentinel used by non-constant tensors - confirm.
  if tensor.buffer and quant_params.quantized_data is not None:
    # Reinterpret the quantized values as raw bytes before packing.
    raw_bytes = np.frombuffer(
        cast(np.ndarray, quant_params.quantized_data).tobytes(),
        dtype=np.uint8,
    ).flatten()
    transformation_input.buffers[tensor.buffer].data = _pack_data(
        quant_params.num_bits, raw_bytes
    )

  if isinstance(quant_params, qtyping.UniformQuantParams):
    flatbuffer_quantization = schema_py_generated.QuantizationParametersT()
    flatbuffer_quantization.scale = list(
        quant_params.scale.flatten().astype(np.float32)
    )  # flatbuffer requires scale as list[float]
    flatbuffer_quantization.zeroPoint = list(
        quant_params.zero_point.flatten().astype(np.int64)
    )  # flatbuffer requires zeroPoint as list[int64]
    if quant_params.quantized_dimension is not None:
      flatbuffer_quantization.quantizedDimension = (
          quant_params.quantized_dimension
      )
    tensor.quantization = flatbuffer_quantization
    tensor.type = quant_params_to_tflite_type(quant_params.num_bits)

  # The original contained this branch twice (copy-paste); one is enough
  # because the assignment is idempotent.
  if isinstance(quant_params, qtyping.NonLinearQuantParams):
    tensor.type = nonlinear_quant_params_to_tflite_type(quant_params.num_bits)

  return qtyping.TransformationInfo(
      0, num_ops_added=0, output_tensor_id=transformation_input.tensor_id
  )
|
@@ -0,0 +1,227 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""test for quantize tensor."""
|
17
|
+
|
18
|
+
import os
|
19
|
+
import numpy as np
|
20
|
+
from tensorflow.python.platform import googletest
|
21
|
+
from absl.testing import parameterized
|
22
|
+
from ai_edge_quantizer import qtyping
|
23
|
+
from ai_edge_quantizer.transformations import quantize_tensor
|
24
|
+
from ai_edge_quantizer.transformations import transformation_utils
|
25
|
+
from ai_edge_quantizer.utils import test_utils
|
26
|
+
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
27
|
+
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
|
28
|
+
|
29
|
+
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")
|
30
|
+
|
31
|
+
|
32
|
+
class QuantizeTensorTest(parameterized.TestCase):
  """Tests for quantize_tensor.quantize_tensor on the fixture model."""

  def setUp(self):
    super().setUp()
    self._orig_test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
    )
    self._model = tfl_flatbuffer_utils.read_model(self._orig_test_model_path)

  def _quantize(self, tensor_id, producer, consumers, quant_params):
    """Runs quantize_tensor on subgraph 0 of the fixture model.

    Args:
      tensor_id: Index of the tensor to quantize.
      producer: Producer op index passed through to TransformationInput.
      consumers: Consumer op indices passed through to TransformationInput.
      quant_params: Quantization parameters to apply.

    Returns:
      Whatever quantize_tensor.quantize_tensor returns.
    """
    model = self._model
    return quantize_tensor.quantize_tensor(
        transformation_utils.TransformationInput(
            tensor_id=tensor_id,
            op_codes=model.operatorCodes,
            buffers=model.buffers,
            subgraph=model.subgraphs[0],
            producer=producer,
            consumers=consumers,
            quant_params=quant_params,
        )
    )

  def _assert_no_op_added(self, info):
    """Checks that the returned info reports op id 0 and no added ops."""
    self.assertEqual(info.op_id, 0)
    self.assertEqual(info.num_ops_added, 0)

  def test_quantize_constant_tensor(self):
    """Test quantizing a constant tensor."""
    data = np.ones([1, 112, 112, 3], dtype=np.int8)
    ret = self._quantize(
        7,
        -1,
        [4],
        qtyping.UniformQuantParams(8, None, np.ones(1), np.ones(1), True, data),
    )
    self._assert_no_op_added(ret)
    # The buffer must now hold the quantized data unchanged (no packing).
    self.assertListEqual(
        np.array(self._model.buffers[8].data).tolist(), data.flatten().tolist()
    )
    quant_param = self._model.subgraphs[0].tensors[7].quantization
    self.assertListEqual(np.array(quant_param.scale).tolist(), [1])
    self.assertEqual(np.array(quant_param.zeroPoint).tolist(), [1])
    self.assertEqual(quant_param.quantizedDimension, 0)

  def test_quantize_activation_tensor(self):
    """Test quantizing an activation tensor."""
    ret = self._quantize(
        4,
        1,
        [3],
        qtyping.UniformQuantParams(8, None, np.array([22]), np.array([127])),
    )
    self._assert_no_op_added(ret)
    quant_param = self._model.subgraphs[0].tensors[4].quantization
    self.assertListEqual(np.array(quant_param.scale).tolist(), [22])
    self.assertListEqual(np.array(quant_param.zeroPoint).tolist(), [127])
    self.assertEqual(quant_param.quantizedDimension, 0)

  def test_quantize_tensor_with_per_channel_quantization(self):
    """Test quantizing a tensor with per-channel quantization params."""
    ret = self._quantize(
        4,
        1,
        [3],
        qtyping.UniformQuantParams(8, 3, np.ones([22]), np.zeros([22])),
    )
    self._assert_no_op_added(ret)
    quant_param = self._model.subgraphs[0].tensors[4].quantization
    self.assertListEqual(
        np.array(quant_param.scale).tolist(), np.ones([22]).tolist()
    )
    self.assertListEqual(
        np.array(quant_param.zeroPoint).tolist(), np.zeros([22]).tolist()
    )
    self.assertEqual(quant_param.quantizedDimension, 3)

  def test_quantize_tensor_with_nonlinear_quantization(self):
    """Test quantizing an activation tensor with non-linear quantization."""
    self._quantize(4, 1, [3], qtyping.NonLinearQuantParams(16, None))
    self.assertEqual(
        self._model.subgraphs[0].tensors[4].type,
        schema_py_generated.TensorType.FLOAT16,
    )

  def test_int4_constant_packed_correctly(self):
    """Test that 4-bit constant data is packed two values per byte."""
    # Fifteen values 0x0..0xE: an odd count, so the last byte is half-filled.
    data = np.arange(0xF, dtype=np.int8)
    expected = np.array([0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0x0E])
    ret = self._quantize(
        7,
        -1,
        [4],
        qtyping.UniformQuantParams(4, None, np.ones(1), np.ones(1), True, data),
    )
    self._assert_no_op_added(ret)
    np.testing.assert_array_equal(self._model.buffers[8].data, expected)
    quant_param = self._model.subgraphs[0].tensors[7].quantization
    np.testing.assert_array_equal(quant_param.scale, [1])
    np.testing.assert_array_equal(quant_param.zeroPoint, [1])
    self.assertEqual(quant_param.quantizedDimension, 0)

  @parameterized.named_parameters(
      dict(testcase_name="int5", num_bits=5),
      dict(testcase_name="int2", num_bits=2),
  )
  def test_int_constant_not_packed(self, num_bits):
    """Test that constant data is left unpacked for non-4-bit widths."""
    tensor_id = 7
    data = np.arange(8, dtype=np.int8)
    expected = np.arange(8)
    ret = self._quantize(
        tensor_id,
        -1,
        [4],
        qtyping.UniformQuantParams(
            num_bits=num_bits,
            quantized_dimension=None,
            scale=np.ones(1),
            zero_point=np.ones(1),
            symmetric=True,
            quantized_data=data,
        ),
    )
    self._assert_no_op_added(ret)
    np.testing.assert_array_equal(self._model.buffers[8].data, expected)
    quant_param = self._model.subgraphs[0].tensors[tensor_id].quantization
    np.testing.assert_array_equal(quant_param.scale, [1])
    np.testing.assert_array_equal(quant_param.zeroPoint, [1])
    self.assertEqual(quant_param.quantizedDimension, 0)
|
224
|
+
|
225
|
+
|
226
|
+
if __name__ == "__main__":
|
227
|
+
googletest.main()
|