ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- ai_edge_quantizer/algorithm_manager.py +158 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
- ai_edge_quantizer/calibrator.py +11 -60
- ai_edge_quantizer/calibrator_test.py +4 -73
- ai_edge_quantizer/default_policy.py +61 -26
- ai_edge_quantizer/model_modifier.py +97 -7
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +31 -8
- ai_edge_quantizer/params_generator.py +17 -10
- ai_edge_quantizer/params_generator_test.py +2 -7
- ai_edge_quantizer/qtyping.py +86 -6
- ai_edge_quantizer/quantizer.py +166 -21
- ai_edge_quantizer/quantizer_test.py +284 -16
- ai_edge_quantizer/recipe.py +154 -42
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +118 -13
- ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
- ai_edge_quantizer/transformation_performer.py +55 -25
- ai_edge_quantizer/transformation_performer_test.py +127 -5
- ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
- ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
- ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
- ai_edge_quantizer/transformations/transformation_utils.py +129 -6
- ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +75 -2
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/transformations/quantize_tensor_test.py:

@@ -170,7 +170,7 @@ class QuantizeTensorTest(parameterized.TestCase):
     # Check if the scale and zero point tensors are inserted correctly.
     self.assertEqual(quant_param.details.scales, 9)
     # So far we don't have zero point in blockwise quantization.
-    self.assertEqual(quant_param.details.zeroPoints,
+    self.assertEqual(quant_param.details.zeroPoints, -1)
 
   def test_int4_constant_packed_correctly(self):
     subgraph = self._model.subgraphs[0]
ai_edge_quantizer/transformations/transformation_utils.py:

@@ -15,6 +15,7 @@
 
 """Utility functions for graph transformations."""
 
+import copy
 import dataclasses
 from typing import Optional, Union
 
@@ -51,39 +52,78 @@ class TransformationInput:
 def add_op_code(
     op_code: schema_py_generated.OperatorCodeT,
     model_op_codes: list[schema_py_generated.OperatorCodeT],
+    custom_op_name: Optional[str] = None,
 ) -> int:
   """Add an op code into a model if it's not present.
 
   Args:
     op_code: The op code to be added.
     model_op_codes: The op codes of the model.
+    custom_op_name: The custom string of the op code. If None, the op code will
+      be added as a builtin op code.
 
   Returns:
     The index of the op code in the model.
   """
+  if (
+      op_code == schema_py_generated.BuiltinOperator.CUSTOM
+      and custom_op_name is None
+  ):
+    raise ValueError('Custom string is required for custom op code.')
+
   for i, model_op_code in enumerate(model_op_codes):
+    # If the model already has the op code, just return the index.
     if model_op_code.builtinCode == op_code:
-
+      if custom_op_name is not None:
+        if model_op_code.customCode == custom_op_name:
+          return i
+      else:
+        # Built-in op
+        return i
+
   model_op_codes.append(schema_py_generated.OperatorCodeT())
   model_op_codes[-1].builtinCode = op_code
+  if custom_op_name is not None:
+    model_op_codes[-1].customCode = custom_op_name
   return len(model_op_codes) - 1
 
 
-def
+def get_constant_buffer(
     data: np.ndarray,
     buffers: list[schema_py_generated.BufferT],
+    force_duplicate_buffer: bool = False,
 ) -> int:
-  """
+  """Get the index of the constant buffer that contains the given data.
+
+  creating new buffer if provided data is not found in buffers list.
 
   Args:
     data: The data of the new tensor.
     buffers: The buffers of the model.
+    force_duplicate_buffer: Whether to add a new buffer even if the same buffer
+      already exists.
 
   Returns:
     The index of the new buffer in the model.
   """
+
+  if isinstance(data, np.ndarray):
+    # in the case where the data is passed from quantization_params.
+    new_data = np.frombuffer(data.tobytes(), dtype=np.uint8).flatten()
+  elif isinstance(data, bytes):
+    # in the case where the data is coming from duplicating buffers, we need to
+    # make a copy of the data to avoid having two buffers pointing to the same
+    # data.
+    new_data = copy.deepcopy(data)
+  else:
+    raise ValueError('data passed in must be either np.ndarray or bytes.')
+  # TODO: b/417811116 - we should make this more efficient.
+  if not force_duplicate_buffer:
+    for index, buffer in enumerate(buffers):
+      if np.array_equal(buffer.data, new_data):
+        return index
   new_buffer = schema_py_generated.BufferT()
-  new_buffer.data =
+  new_buffer.data = new_data
   new_buffer.offset = 0
   new_buffer.size = 0
   new_buffer_id = len(buffers)
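The new `custom_op_name` handling above changes the dedup key: CUSTOM op codes are matched on `customCode`, while builtin op codes still match on `builtinCode` alone. A minimal standalone sketch of that lookup logic, using a hypothetical `OperatorCode` dataclass in place of the flatbuffer-generated `OperatorCodeT` (CUSTOM's enum value is 32 in the TFLite schema):

```python
import dataclasses
from typing import Optional

CUSTOM = 32  # BuiltinOperator.CUSTOM in the TFLite schema.


@dataclasses.dataclass
class OperatorCode:  # Hypothetical stand-in for schema_py_generated.OperatorCodeT.
  builtinCode: int
  customCode: Optional[str] = None


def add_op_code(op_code, model_op_codes, custom_op_name=None):
  # Custom ops require a name and are deduplicated per (builtinCode, customCode).
  if op_code == CUSTOM and custom_op_name is None:
    raise ValueError('Custom string is required for custom op code.')
  for i, code in enumerate(model_op_codes):
    if code.builtinCode == op_code:
      if custom_op_name is not None:
        if code.customCode == custom_op_name:
          return i
      else:
        return i  # Builtin ops match on builtinCode alone.
  model_op_codes.append(OperatorCode(op_code, custom_op_name))
  return len(model_op_codes) - 1


codes = []
assert add_op_code(CUSTOM, codes, 'op_a') == 0
assert add_op_code(CUSTOM, codes, 'op_a') == 0  # Deduplicated.
assert add_op_code(CUSTOM, codes, 'op_b') == 1  # New entry per custom name.
```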
@@ -99,6 +139,7 @@ def add_new_constant_tensor(
     subgraph: schema_py_generated.SubGraphT,
     buffers: list[schema_py_generated.BufferT],
     tensor_shape: Optional[list[int]] = None,
+    force_duplicate_buffer: bool = False,
 ) -> int:
   """Add a new constant tensor to the model.
 
@@ -110,11 +151,13 @@ def add_new_constant_tensor(
     buffers: The buffers of the model.
     tensor_shape: The shape of the new tensor. If not provided, the shape of the
       data will be used.
+    force_duplicate_buffer: Whether to add a new buffer even if the same buffer
+      already exists.
 
   Returns:
     The index of the new tensor in the subgraph.
   """
-  new_buffer_id =
+  new_buffer_id = get_constant_buffer(data, buffers, force_duplicate_buffer)
 
   new_tensor = schema_py_generated.TensorT()
   if tensor_shape is None:
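`get_constant_buffer` deduplicates by byte content: an `np.ndarray` input is flattened to a raw `uint8` view before comparison, so arrays with different dtypes or shapes but identical bytes resolve to the same buffer unless `force_duplicate_buffer` is set. A self-contained sketch of that normalization, with a plain list of `uint8` arrays standing in for the model's `BufferT` list:

```python
import numpy as np


def find_or_add_buffer(data: np.ndarray, buffers: list[np.ndarray]) -> int:
  # Normalize to a flat uint8 view so equal bytes compare equal regardless
  # of the source dtype or shape.
  new_data = np.frombuffer(data.tobytes(), dtype=np.uint8).flatten()
  for index, buffer in enumerate(buffers):
    if np.array_equal(buffer, new_data):
      return index  # Reuse the existing buffer.
  buffers.append(new_data)
  return len(buffers) - 1


buffers = []
a = np.array([1, 2], dtype=np.int32)
b = np.array([[1], [2]], dtype=np.int32)  # Same bytes, different shape.
assert find_or_add_buffer(a, buffers) == 0
assert find_or_add_buffer(b, buffers) == 0  # Deduplicated by byte content.
```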
@@ -146,10 +189,90 @@ def add_new_activation_tensor(
     The index of the new tensor in the subgraph.
   """
   new_tensor = schema_py_generated.TensorT()
-
+  # If there's a dynamic shape, we need to read from the shapeSignature field
+  # instead of shape. Shape should contain just 1 for the dynamic dimension but
+  # shapeSignature should contain the true shape.
+  if -1 in shape:
+    new_tensor.shapeSignature = shape
+    new_tensor.shape = [1 if i == -1 else i for i in shape]
+  else:
+    new_tensor.shape = shape
   new_tensor.type = tensor_type
   new_tensor.name = tensor_name
   new_tensor.buffer = 0
   new_tensor_id = len(subgraph.tensors)
   subgraph.tensors.append(new_tensor)
   return new_tensor_id
+
+
+def raise_deprecated_error(_: TransformationInput):
+  raise NotImplementedError(
+      'This transformation is deprecated. Please contact AI Edge Quantizer team'
+      ' if you see this error.'
+  )
+
+
+def pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
+  """Pack the data to the corresponding bit width.
+
+  Currently only support 4 bits. If no packing is needed, the original data is
+  returned.
+
+  Args:
+    bitwidth: Bit width from NonLinearQuantParams.
+    flattened_data: The data to be packed.
+
+  Returns:
+    Packed data.
+  """
+  if bitwidth == 4:
+    even_data = flattened_data[::2] & 0x0F
+    odd_data = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
+    if odd_data.shape[0] == even_data.shape[0] - 1:
+      odd_data = np.pad(odd_data, (0, 1), constant_values=0)
+    return np.bitwise_or(even_data, odd_data)
+  else:
+    return flattened_data
+
+
+def get_producer_schema_op_id(
+    transformation: TransformationInput,
+) -> int:
+  """Checks if the tensor's producer matches the given op.
+
+  Args:
+    transformation: The transformation input to check the producer of.
+
+  Returns:
+    The schema op id of the producer op. E.g.
+    schema_py_generated.BuiltinOperator.FULLY_CONNECTED.
+  """
+  if transformation.producer == -1:
+    return False
+  else:
+    return (
+        transformation.op_codes[
+            transformation.subgraph.operators[
+                transformation.producer
+            ].opcodeIndex
+        ].builtinCode
+    )
+
+
+def get_schema_op_id(
+    transformation: TransformationInput, op_id: int
+) -> bool:
+  """Returns the schema op id of the given op.
+
+  Args:
+    transformation: The transformation input to check the consumers of.
+    op_id: The op id in the list of operators to check for.
+
+  Returns:
+    The schema op id of the given op.
+  """
+  return (
+      transformation.op_codes[
+          transformation.subgraph.operators[op_id].opcodeIndex
+      ].builtinCode
+  )
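The new `pack_data` helper stores two int4 values per byte: elements at even indices go to the low nibble, odd indices to the high nibble, with zero padding for odd-length input. A quick round-trip sketch; the `unpack_int4` inverse is illustrative only (unsigned nibbles) and not part of the package:

```python
import numpy as np


def pack_int4(flattened_data: np.ndarray) -> np.ndarray:
  # Mirrors pack_data(bitwidth=4, ...): even elements -> low nibble,
  # odd elements -> high nibble; odd-length input is zero-padded.
  even = flattened_data[::2] & 0x0F
  odd = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
  if odd.shape[0] == even.shape[0] - 1:
    odd = np.pad(odd, (0, 1), constant_values=0)
  return np.bitwise_or(even, odd)


def unpack_int4(packed: np.ndarray, n: int) -> np.ndarray:
  # Hypothetical inverse for illustration: split each byte back into nibbles.
  out = np.empty(2 * packed.shape[0], dtype=np.uint8)
  out[::2] = packed & 0x0F
  out[1::2] = packed >> 4
  return out[:n]


vals = np.array([1, 2, 3, 4, 5], dtype=np.uint8)  # Odd length exercises padding.
packed = pack_int4(vals)
assert packed.tolist() == [0x21, 0x43, 0x05]
assert unpack_int4(packed, 5).tolist() == [1, 2, 3, 4, 5]
```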
ai_edge_quantizer/transformations/transformation_utils_test.py:

@@ -41,19 +41,62 @@ class TransformationUtilsTest(parameterized.TestCase):
           testcase_name="add_new_op_code",
           op_code=schema_py_generated.BuiltinOperator.LOGISTIC,
           expected=1,
+          custom_op_name=None,
       ),
       dict(
           testcase_name="add_existing_op_code",
           op_code=schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
           expected=0,
+          custom_op_name=None,
+      ),
+      dict(
+          testcase_name="add_new_custom_op_code",
+          op_code=schema_py_generated.BuiltinOperator.CUSTOM,
+          expected=1,
+          custom_op_name="random_new_custom_op",
       ),
   )
-  def test_add_op_code(self, op_code, expected):
+  def test_add_op_code(self, op_code, expected, custom_op_name):
     """Tests if the op code is added to the model."""
     got = transformation_utils.add_op_code(
-        op_code=op_code,
+        op_code=op_code,
+        model_op_codes=self.model.operatorCodes,
+        custom_op_name=custom_op_name,
     )
     self.assertEqual(expected, got)
+    if custom_op_name is not None:
+      self.assertEqual(self.model.operatorCodes[got].customCode, custom_op_name)
+
+  def test_add_custom_op_code_without_op_string_raises_error(self):
+    with self.assertRaisesRegex(ValueError, "Custom string is required"):
+      transformation_utils.add_op_code(
+          op_code=schema_py_generated.BuiltinOperator.CUSTOM,
+          model_op_codes=self.model.operatorCodes,
+          custom_op_name=None,
+      )
+
+  def test_add_two_custom_op_codes(self):
+    custom_op_name = "random_new_custom_op"
+    added_index = transformation_utils.add_op_code(
+        op_code=schema_py_generated.BuiltinOperator.CUSTOM,
+        model_op_codes=self.model.operatorCodes,
+        custom_op_name=custom_op_name,
+    )
+    self.assertEqual(1, added_index)
+    self.assertEqual(
+        self.model.operatorCodes[added_index].customCode, custom_op_name
+    )
+
+    custom_op_name_2 = "random_new_custom_op_2"
+    added_index = transformation_utils.add_op_code(
+        op_code=schema_py_generated.BuiltinOperator.CUSTOM,
+        model_op_codes=self.model.operatorCodes,
+        custom_op_name=custom_op_name_2,
+    )
+    self.assertEqual(2, added_index)
+    self.assertEqual(
+        self.model.operatorCodes[added_index].customCode, custom_op_name_2
+    )
 
   @parameterized.named_parameters(
       dict(
@@ -68,7 +111,7 @@ class TransformationUtilsTest(parameterized.TestCase):
   def test_add_new_constant_buffer(self, data):
     """Tests if the constant buffer is added to the model."""
     prev_num_buffers = len(self.model.buffers) - 1
-    new_buffer_idx = transformation_utils.
+    new_buffer_idx = transformation_utils.get_constant_buffer(
         data=data,
         buffers=self.model.buffers,
     )
@@ -189,6 +232,25 @@ class TransformationUtilsTest(parameterized.TestCase):
         self.model.subgraphs[0].tensors[-1].shape,
     )
 
+  def test_add_new_activation_tensor_with_dynamic_shape(self):
+    """Tests adding an activation tensor with dynamic shape."""
+    subgraph = self.model.subgraphs[0]
+    new_id = transformation_utils.add_new_activation_tensor(
+        tensor_name="test_tensor",
+        shape=[1, -1, -1, 1],
+        tensor_type=schema_py_generated.TensorType.FLOAT32,
+        subgraph=subgraph,
+    )
+    # Originally had 4 tensors, new tensor is added at index 4.
+    self.assertEqual(new_id, 4)
+    self.assertLen(subgraph.tensors, 5)
+    self.assertEqual(subgraph.tensors[-1].name, "test_tensor")
+    self.assertEqual(
+        subgraph.tensors[-1].type, schema_py_generated.TensorType.FLOAT32
+    )
+    self.assertEqual(subgraph.tensors[-1].shape, [1, 1, 1, 1])
+    self.assertEqual(subgraph.tensors[-1].shapeSignature, [1, -1, -1, 1])
+
 
 if __name__ == "__main__":
   googletest.main()
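The dynamic-shape test above exercises the `shape`/`shapeSignature` split from `add_new_activation_tensor`: the concrete `shape` gets 1 in each dynamic dimension while `shapeSignature` preserves the -1 markers. A minimal sketch of that mapping:

```python
def split_dynamic_shape(shape: list[int]) -> tuple[list[int], list[int] | None]:
  # Mirrors the add_new_activation_tensor logic: dynamic (-1) dimensions become
  # 1 in the stored shape; the true signature is kept only when dynamic.
  if -1 in shape:
    return [1 if d == -1 else d for d in shape], shape
  return shape, None


assert split_dynamic_shape([1, -1, -1, 1]) == ([1, 1, 1, 1], [1, -1, -1, 1])
assert split_dynamic_shape([1, 224, 224, 3]) == ([1, 224, 224, 3], None)
```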
ai_edge_quantizer/utils/calibration_utils.py:

@@ -15,9 +15,24 @@
 
 """Utilities for model calibration."""
 
-
+import copy
+from typing import Any, Union
+
 import numpy as np
+
+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.algorithms.utils import common_utils
+from ai_edge_quantizer.utils import constrained_ops_utils
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+from ai_edge_quantizer.utils import tfl_interpreter_utils
+
+
+_SignatureInput = dict[str, Any]
+_OpQuantConstraint = common_utils.OpQuantConstraint
+_SignatureData = dict[
+    str, list[str]
+]  # signature_key -> list of signature_names.
 
 
 def _update_moving_average(
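The `_SignatureData` alias maps each signature key to the signature input/output names whose tensors should be aligned. A hypothetical example for a two-signature LLM-style model that shares a KV-cache tensor across signatures:

```python
# Hypothetical signature_data: both signatures' `cache_k` I/O should end up
# with identical min/max calibration statistics.
signature_data = {
    "prefill": ["cache_k"],
    "decode": ["cache_k"],
}
```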
@@ -84,3 +99,250 @@ def min_max_update(qsv: qtyping.QSV, new_qsv: qtyping.QSV) -> qtyping.QSV:
   updated_qsv["min"] = np.minimum(qsv["min"], new_qsv["min"])
   updated_qsv["max"] = np.maximum(qsv["max"], new_qsv["max"])
   return updated_qsv
+
+
+def _find_overall_min_max(
+    qsv: qtyping.QSV, tensor_names: list[str]
+) -> tuple[np.ndarray, np.ndarray]:
+  """Finds the overall minimum and maximum values for the given tensors.
+
+  Args:
+    qsv: The quantization statistical value of the tensor (min/max).
+    tensor_names: The list of tensor names to find the minimum and maximum
+      values.
+
+  Returns:
+    The minimum and maximum values for the given tensors.
+  """
+  min_value = np.inf
+  max_value = -np.inf
+  for tensor_name in tensor_names:
+    min_value = min(min_value, qsv[tensor_name]["min"])
+    max_value = max(max_value, qsv[tensor_name]["max"])
+  return min_value, max_value
+
+
+class CalibrationQsvAlignmentUtils:
+  """Calibration utils for alignment of QSVs.
+
+  This class is used to align QSVs for a given model. It builds a list of ops
+  that need to be constrained to the same scale as the input. Based on this
+  list, it finds the corresponding tensor names for a given signature data.
+  """
+
+  def __init__(self, model_path: str):
+    self._same_as_input_scale_ops = (
+        constrained_ops_utils.get_constrained_op_list(
+            _OpQuantConstraint.SAME_AS_INPUT_SCALE
+        )
+    )
+
+    tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_path)
+    self._flatbuffer_object = tfl_flatbuffer_utils.read_model(model_path)
+
+    signature_keys = list(tfl_interpreter.get_signature_list().keys())
+
+    # Build a dict of signature runners.
+    self._signature_runners = {}
+    for signature_key in signature_keys:
+      signature_runner = tfl_interpreter.get_signature_runner(signature_key)
+      self._signature_runners[signature_key] = signature_runner
+
+  def _search_tensor_by_signature_name(
+      self, signature_key: str, signature_input_output_name: str, verbose=False
+  ) -> list[str]:
+    """Searches for a tensor name for a given signature by signature input or output name.
+
+    Args:
+      signature_key: Name of the signature.
+      signature_input_output_name: Name of the input or output in the signature.
+      verbose: Flag to enable verbose output.
+
+    Returns:
+      The list with one or two tensor names. The first one is the input tensor
+      name, and the second one is the output tensor name.
+    """
+
+    if verbose:
+      print("Searching tensor by signature name.")
+
+    tensor_names = []
+
+    # Search among inputs.
+    input_details = self._signature_runners[signature_key].get_input_details()
+    if signature_input_output_name in input_details.keys():
+      tensor_names.append(input_details[signature_input_output_name]["name"])
+
+    # Search among outputs.
+    output_details = self._signature_runners[signature_key].get_output_details()
+    if signature_input_output_name not in output_details:
+      if not tensor_names:
+        raise ValueError(
+            f"Signature {signature_key} does not have input or output"
+            f" `{signature_input_output_name}`"
+        )
+      return tensor_names
+
+    output_tensor_name = output_details[signature_input_output_name]["name"]
+    if verbose:
+      print(
+          ">> Starting recursive search for the output tensor name:"
+          f" {output_tensor_name}"
+      )
+
+    idx = self._signature_runners[signature_key]._subgraph_index  # pylint: disable=protected-access
+    subgraph = self._flatbuffer_object.subgraphs[idx]
+    graph_info = qtyping.GraphInfo(
+        subgraph.tensors, self._flatbuffer_object.buffers
+    )
+
+    # Recursively search the graph for the output tensor name since it may be
+    # `SAME_AS_INPUT` constrainted.
+    operators = copy.deepcopy(subgraph.operators)
+    tensor_name = self._search_reverse_order_recursively(
+        graph_info, operators, output_tensor_name, indent=" ", verbose=verbose
+    )
+    tensor_names.append(tensor_name)
+
+    if verbose:
+      print(f"\n\nFound tensor name: {tensor_name}")
+
+    return tensor_names
+
+  def _search_reverse_order_recursively(
+      self,
+      graph_info: qtyping.GraphInfo,
+      operators: list[Any],
+      output_tensor_name: str,
+      indent: str,
+      verbose: bool = False,
+  ):
+    """Searches for a tensor name in reverse order recursively.
+
+    Stop criteria is when the tensor belongs to an operator that is not
+    `SAME_AS_INPUT` constrainted.
+
+    Args:
+      graph_info: Graph information.
+      operators: List of operators.
+      output_tensor_name: Name of the output tensor to search for.
+      indent: Indentation string for debug output.
+      verbose: Flag to enable verbose output.
+
+    Returns:
+      The name of the tensor found, or None if not found.
+    """
+    op_codes = self._flatbuffer_object.operatorCodes
+
+    while operators:
+      op = operators.pop()
+      op_code = op_codes[op.opcodeIndex].builtinCode
+      op_name = flatbuffer_utils.opcode_to_name(
+          self._flatbuffer_object, op.opcodeIndex
+      )
+      if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
+        continue
+      for output_idx in op.outputs:
+        if output_tensor_name == tfl_flatbuffer_utils.get_tensor_name(
+            graph_info.subgraph_tensors[output_idx]
+        ):
+          dbg_str = (
+              f"{indent}>> Found `{op_name}`, output tensor"
+              f" '{output_tensor_name}'"
+          )
+
+          if op_name not in self._same_as_input_scale_ops:
+            if verbose:
+              print(f"{dbg_str}, returning...")
+            return output_tensor_name
+
+          if verbose:
+            print(f"{dbg_str}, with SAME_AS_INPUT, search recursively among:")
+          for input_idx in op.inputs:
+            input_tensor_name = graph_info.subgraph_tensors[
+                input_idx
+            ].name.decode("utf-8")
+
+            if verbose:
+              print(f"{indent} Input: {input_tensor_name}")
+
+            return self._search_reverse_order_recursively(
+                graph_info,
+                operators,
+                input_tensor_name,
+                indent=f"{indent} ",
+                verbose=verbose,
+            )
+    return output_tensor_name
+
+  def align_quant_stats(
+      self, qsv: dict[str, Any], signature_data: _SignatureData
+  ) -> tuple[np.ndarray, np.ndarray]:
+    """Aligns quantization statistics for a given signature data.
+
+    This function takes quantization statistics and signature data as input,
+    identifies the tensors associated with the signature data, and aligns
+    the quantization statistics of these tensors by setting their minimum
+    and maximum values to the same value. This ensures that the tensors
+    have the same quantization parameters.
+
+    Args:
+      qsv: Quantization statistics.
+      signature_data: Signature data.
+
+    Returns:
+      Tuple of min and max values.
+    """
+    # Go over all signature info and find the corresponding tensor names.
+    tensor_names = []
+    for signature_key, signature_names in signature_data.items():
+      for signature_name in signature_names:
+        tensor_name = self._search_tensor_by_signature_name(
+            signature_key, signature_name
+        )
+        tensor_names.extend(tensor_name)
+
+    # Find min and max values accross all tensors.
+    min_value, max_value = _find_overall_min_max(qsv, tensor_names)
+
+    # Overwrite the min and max values in the QSV.
+    for tensor_name in tensor_names:
+      qsv[tensor_name]["min"] = min_value
+      qsv[tensor_name]["max"] = max_value
+
+    return min_value, max_value
+
+  def update_quant_stats(
+      self,
+      qsv: dict[str, Any],
+      signature_data: _SignatureData,
+      min_value: np.ndarray,
+      max_value: np.ndarray,
+  ):
+    """Updates quantization statistics for a given signature data.
+
+    This function updates the quantization statistics with the provided min, max
+    values for the tensors specified in the signature data.
+
+    Args:
+      qsv: Quantization statistics.
+      signature_data: Signature data.
+      min_value: Minimum value to update.
+      max_value: Maximum value to update.
+
+    Returns:
+      Updated quantization statistics.
+    """
+    # Go over all signature info and find the corresponding tensor names.
+    tensor_names = []
+    for signature_key, signature_names in signature_data.items():
+      for signature_name in signature_names:
+        tensor_name = self._search_tensor_by_signature_name(
+            signature_key, signature_name
+        )
+        tensor_names.extend(tensor_name)
+
+    # Overwrite the min and max values in the QSV.
+    for tensor_name in tensor_names:
+      qsv[tensor_name]["min"] = min_value
+      qsv[tensor_name]["max"] = max_value