ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,132 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Utility functions for graph transformations."""
|
17
|
+
|
18
|
+
import dataclasses
|
19
|
+
from typing import Union
|
20
|
+
|
21
|
+
import numpy as np
|
22
|
+
|
23
|
+
from ai_edge_quantizer import qtyping
|
24
|
+
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
|
25
|
+
|
26
|
+
|
27
|
+
@dataclasses.dataclass
class TransformationInput:
  """Standard input for a graph transformation.

  Attributes:
    tensor_id: The index of the tensor the transformation is applied to (e.g.,
      the tensor to insert a dequant op after).
    op_codes: List of operator codes in the model. If the op needed by the
      transformation (e.g., dequantize) doesn't exist, its op code needs to be
      inserted into this list.
    buffers: List of buffers in the original TFLite model, used for buffer
      quantization.
    subgraph: Flatbuffer subgraph object in which the tensor resides.
    producer: Op id of the producer of the tensor.
    consumers: Op ids of the consumers of the newly inserted op (e.g., the new
      dequant op).
    quant_params: Quantization parameters to be applied on the original tensor.
  """

  tensor_id: int
  op_codes: list[schema_py_generated.OperatorCodeT]
  buffers: list[schema_py_generated.BufferT]
  subgraph: schema_py_generated.SubGraphT
  producer: int
  consumers: list[int]
  quant_params: Union[qtyping.UniformQuantParams, qtyping.NonLinearQuantParams]
|
49
|
+
|
50
|
+
|
51
|
+
def add_op_code(
    op_code: int,
    model_op_codes: list[schema_py_generated.OperatorCodeT],
) -> int:
  """Return the index of an op code in the model, adding it if it is missing.

  The lookup compares `op_code` against the `builtinCode` field of every entry
  in `model_op_codes`. When no entry matches, a new `OperatorCodeT` carrying
  `op_code` as its `builtinCode` is appended, so `model_op_codes` is modified
  in place.

  Args:
    op_code: The builtin operator code to look up or add (a
      `schema_py_generated.BuiltinOperator` value, not an `OperatorCodeT`
      object).
    model_op_codes: The op codes of the model. May be extended by this call.

  Returns:
    The index of the op code in `model_op_codes`.
  """
  for index, existing_op_code in enumerate(model_op_codes):
    if existing_op_code.builtinCode == op_code:
      return index

  # Not present yet: register a fresh operator code at the end of the list.
  new_op_code = schema_py_generated.OperatorCodeT()
  new_op_code.builtinCode = op_code
  model_op_codes.append(new_op_code)
  return len(model_op_codes) - 1
|
70
|
+
|
71
|
+
|
72
|
+
def add_new_constant_tensor(
    tensor_name: str,
    data: np.ndarray,
    tensor_type: schema_py_generated.TensorType,
    subgraph: schema_py_generated.SubGraphT,
    buffers: list[schema_py_generated.BufferT],
) -> int:
  """Create a constant tensor backed by a new buffer and register both.

  A new buffer holding the raw bytes of `data` is appended to `buffers`, and a
  new tensor pointing at that buffer is appended to `subgraph.tensors`.

  Args:
    tensor_name: The name of the new tensor.
    data: The data of the new tensor.
    tensor_type: The type of the new tensor.
    subgraph: The subgraph where the new tensor is added.
    buffers: The buffers of the model.

  Returns:
    The index of the new tensor in the subgraph.
  """
  # The buffer payload is the array's raw bytes exposed as a flat uint8 array.
  raw_bytes = np.frombuffer(data.tobytes(), dtype=np.uint8).flatten()

  buffer_index = len(buffers)
  constant_buffer = schema_py_generated.BufferT()
  constant_buffer.data = raw_bytes
  constant_buffer.offset = 0
  constant_buffer.size = 0
  buffers.append(constant_buffer)

  tensor_index = len(subgraph.tensors)
  constant_tensor = schema_py_generated.TensorT()
  constant_tensor.name = tensor_name
  constant_tensor.type = tensor_type
  constant_tensor.shape = data.shape
  constant_tensor.buffer = buffer_index
  subgraph.tensors.append(constant_tensor)
  return tensor_index
|
106
|
+
|
107
|
+
|
108
|
+
def add_new_activation_tensor(
    tensor_name: str,
    shape: list[int],
    tensor_type: schema_py_generated.TensorType,
    subgraph: schema_py_generated.SubGraphT,
) -> int:
  """Append a new activation tensor to a subgraph.

  The tensor is created with its buffer index set to 0 and is appended to
  `subgraph.tensors`.

  Args:
    tensor_name: The name of the new tensor.
    shape: The shape of the new tensor.
    tensor_type: The type of the new tensor.
    subgraph: The subgraph where the new tensor is added.

  Returns:
    The index of the new tensor in the subgraph.
  """
  tensor_index = len(subgraph.tensors)

  activation_tensor = schema_py_generated.TensorT()
  activation_tensor.name = tensor_name
  activation_tensor.type = tensor_type
  activation_tensor.shape = shape
  activation_tensor.buffer = 0

  subgraph.tensors.append(activation_tensor)
  return tensor_index
|
@@ -0,0 +1,162 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Tests for transformation_utils."""
|
17
|
+
|
18
|
+
import os
|
19
|
+
import numpy as np
|
20
|
+
from tensorflow.python.platform import googletest
|
21
|
+
from absl.testing import parameterized
|
22
|
+
from ai_edge_quantizer.transformations import transformation_utils
|
23
|
+
from ai_edge_quantizer.utils import test_utils
|
24
|
+
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
25
|
+
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
|
26
|
+
|
27
|
+
# Location of the test models; the relative path is resolved by
# test_utils.get_path_to_datafile.
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../tests/models")
|
28
|
+
|
29
|
+
|
30
|
+
class TransformationUtilsTest(parameterized.TestCase):
  """Unit tests for the helpers in transformation_utils."""

  def setUp(self):
    """Loads the single_fc_bias test model before each test."""
    super().setUp()
    self.model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "single_fc_bias.tflite"
    )
    self.model = tfl_flatbuffer_utils.read_model(self.model_path)

  @parameterized.named_parameters(
      (
          "add_new_op_code",
          schema_py_generated.BuiltinOperator.LOGISTIC,
          1,
      ),
      (
          "add_existing_op_code",
          schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
          0,
      ),
  )
  def test_add_op_code(self, op_code, expected):
    """Checks the index returned for a new and for an existing op code."""
    op_code_index = transformation_utils.add_op_code(
        op_code=op_code, model_op_codes=self.model.operatorCodes
    )
    self.assertEqual(expected, op_code_index)

  @parameterized.named_parameters(
      dict(
          testcase_name="float32",
          tensor_data=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
          tensor_type=schema_py_generated.TensorType.FLOAT32,
          expected_type=schema_py_generated.TensorType.FLOAT32,
          expected_shape=(4,),
          expected_buffer_data=np.frombuffer(
              np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).tobytes(),
              dtype=np.uint8,
          ).flatten(),
      ),
      dict(
          testcase_name="int8",
          tensor_data=np.array([[1, 2], [3, 4]], dtype=np.int8),
          tensor_type=schema_py_generated.TensorType.INT8,
          expected_type=schema_py_generated.TensorType.INT8,
          expected_shape=(2, 2),
          expected_buffer_data=np.frombuffer(
              np.array([[1, 2], [3, 4]], dtype=np.int8).tobytes(),
              dtype=np.uint8,
          ).flatten(),
      ),
  )
  def test_add_new_constant_tensor(
      self,
      tensor_data,
      tensor_type,
      expected_type,
      expected_shape,
      expected_buffer_data,
  ):
    """Checks the tensor and buffer created for a new constant tensor."""
    subgraph = self.model.subgraphs[0]
    tensor_index = transformation_utils.add_new_constant_tensor(
        tensor_name="test_tensor",
        data=tensor_data,
        tensor_type=tensor_type,
        subgraph=subgraph,
        buffers=self.model.buffers,
    )
    # The new tensor is expected to be the last one in the subgraph.
    added_tensor = subgraph.tensors[-1]
    self.assertEqual(tensor_index, len(subgraph.tensors) - 1)
    self.assertEqual(str(added_tensor.name), "test_tensor")
    self.assertEqual(expected_type, added_tensor.type)
    self.assertEqual(expected_shape, added_tensor.shape)
    self.assertListEqual(
        expected_buffer_data.tolist(),
        self.model.buffers[added_tensor.buffer].data.tolist(),
    )

  @parameterized.named_parameters(
      dict(
          testcase_name="float32",
          tensor_type=schema_py_generated.TensorType.FLOAT32,
          tensor_shape=[1, 1, 1, 1],
          expected_shape=[1, 1, 1, 1],
          expected_type=schema_py_generated.TensorType.FLOAT32,
      ),
      dict(
          testcase_name="int8",
          tensor_type=schema_py_generated.TensorType.INT8,
          tensor_shape=[1, 224, 224, 1],
          expected_shape=[1, 224, 224, 1],
          expected_type=schema_py_generated.TensorType.INT8,
      ),
  )
  def test_add_new_activation_tensor_to_subgraph(
      self,
      tensor_type,
      tensor_shape,
      expected_shape,
      expected_type,
  ):
    """Checks the tensor created for a new activation tensor."""
    subgraph = self.model.subgraphs[0]
    tensor_index = transformation_utils.add_new_activation_tensor(
        tensor_name="test_tensor",
        shape=tensor_shape,
        tensor_type=tensor_type,
        subgraph=subgraph,
    )
    # The new tensor is expected to be the last one in the subgraph.
    added_tensor = subgraph.tensors[-1]
    self.assertEqual(tensor_index, len(subgraph.tensors) - 1)
    self.assertEqual(str(added_tensor.name), "test_tensor")
    self.assertEqual(expected_type, added_tensor.type)
    self.assertEqual(expected_shape, added_tensor.shape)
|
159
|
+
|
160
|
+
|
161
|
+
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  googletest.main()
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
@@ -0,0 +1,86 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Utilities for model calibration."""
|
17
|
+
|
18
|
+
from typing import Union
|
19
|
+
import numpy as np
|
20
|
+
from ai_edge_quantizer import qtyping
|
21
|
+
|
22
|
+
|
23
|
+
def _update_moving_average(
    smoothing_factor: Union[np.ndarray, float],
    w: np.ndarray,
    update: np.ndarray,
) -> np.ndarray:
  """Blends the current value `w` with `update` using a moving average.

  Args:
    smoothing_factor: Weight given to the current value `w`; the remaining
      `1.0 - smoothing_factor` is given to `update`.
    w: Weights to be updated.
    update: Value used for update.

  Returns:
    Weighted sum of w and update.
  """
  retained = smoothing_factor * w
  incoming = (1.0 - smoothing_factor) * update
  return retained + incoming
|
39
|
+
|
40
|
+
|
41
|
+
def moving_average_update(
    qsv: qtyping.QSV, new_qsv: qtyping.QSV, smoothing_factor: float = 0.95
) -> qtyping.QSV:
  """Updates the QSV (i.e., min/max) with a moving average.

  When `qsv` is empty there is nothing to average with, so `new_qsv` is
  returned as is.

  Args:
    qsv: The quantization statistical value of the tensor (min/max) that needs
      to be updated.
    new_qsv: The new QSVs (e.g., from a new round of calibration).
    smoothing_factor: The weight of the moving average, i.e., the weight kept
      on the existing `qsv` values.

  Returns:
    The updated QSV for the tensor.
  """
  if qsv:
    # Average "min" first, then "max"; each key is blended independently.
    return {
        stat: _update_moving_average(
            smoothing_factor, qsv[stat], new_qsv[stat]
        )
        for stat in ("min", "max")
    }
  return new_qsv
|
67
|
+
|
68
|
+
|
69
|
+
def min_max_update(qsv: qtyping.QSV, new_qsv: qtyping.QSV) -> qtyping.QSV:
  """Widens the QSV to the smallest min and the largest max seen so far.

  When `qsv` is empty there is nothing to compare with, so `new_qsv` is
  returned as is.

  Args:
    qsv: The quantization statistical value of the tensor (min/max) that needs
      to be updated.
    new_qsv: The new QSVs (e.g., from a new round of calibration).

  Returns:
    The updated QSV for the tensor.
  """
  if qsv:
    # Element-wise comparison, so array-valued statistics are supported.
    return {
        "min": np.minimum(qsv["min"], new_qsv["min"]),
        "max": np.maximum(qsv["max"], new_qsv["max"]),
    }
  return new_qsv
|
@@ -0,0 +1,77 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from absl.testing import parameterized
|
17
|
+
from tensorflow.python.platform import googletest
|
18
|
+
from ai_edge_quantizer.utils import calibration_utils
|
19
|
+
|
20
|
+
|
21
|
+
class CalibrationUtilsTest(parameterized.TestCase):
  """Tests for the QSV update helpers in calibration_utils."""

  @parameterized.named_parameters(
      dict(
          testcase_name="zero_smoothing_factor",
          smoothing_factor=0,
          expected_vals={"min": -1000, "max": 800},
      ),
      dict(
          testcase_name="one_smoothing_factor",
          smoothing_factor=1,
          expected_vals={"min": -10, "max": 8},
      ),
      dict(
          testcase_name="normal_smoothing_factor",
          smoothing_factor=0.99,
          expected_vals={"min": -19.9, "max": 15.92},
      ),
  )
  def test_update_tensor_qsv_moving_average(
      self, smoothing_factor, expected_vals
  ):
    """Checks the moving-average update for several smoothing factors."""
    old_qsv = {"min": -10, "max": 8}
    # Large values to mimic outlier.
    new_qsv = {"min": -1000, "max": 800}
    updated_qsv = calibration_utils.moving_average_update(
        old_qsv, new_qsv, smoothing_factor=smoothing_factor
    )
    self.assertAlmostEqual(updated_qsv["min"], expected_vals["min"])
    self.assertAlmostEqual(updated_qsv["max"], expected_vals["max"])

  @parameterized.named_parameters(
      dict(
          testcase_name="scalar",
          old_qsv={"min": -10, "max": 8},
          new_qsv={"min": -1000, "max": 1},
          expected_qsv={"min": -1000, "max": 8},
      ),
      dict(
          testcase_name="2darray",
          old_qsv={"min": [[-19], [20]], "max": [[21], [250]]},
          new_qsv={"min": [[-1000], [25]], "max": [[33], [100]]},
          expected_qsv={"min": [[-1000], [20]], "max": [[33], [250]]},
      ),
  )
  def test_update_tensor_qsv_min_max(self, old_qsv, new_qsv, expected_qsv):
    """Checks the min/max update for scalar and array-valued QSVs."""
    updated_qsv = calibration_utils.min_max_update(old_qsv, new_qsv)
    if isinstance(expected_qsv["min"], list):
      # Convert the whole array to nested Python lists before comparing.
      # Wrapping it in list() would leave a list of ndarrays, whose equality
      # check only works while every row holds a single element.
      self.assertListEqual(updated_qsv["min"].tolist(), expected_qsv["min"])
      self.assertListEqual(updated_qsv["max"].tolist(), expected_qsv["max"])
    else:
      self.assertEqual(updated_qsv["min"], expected_qsv["min"])
      self.assertEqual(updated_qsv["max"], expected_qsv["max"])
|
74
|
+
|
75
|
+
|
76
|
+
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  googletest.main()
|
@@ -0,0 +1,107 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Utils for tests."""
|
17
|
+
|
18
|
+
import inspect as _inspect
|
19
|
+
import os.path as _os_path
|
20
|
+
import sys as _sys
|
21
|
+
from typing import Any, Union
|
22
|
+
|
23
|
+
import numpy as np
|
24
|
+
|
25
|
+
from ai_edge_quantizer.utils import tfl_interpreter_utils
|
26
|
+
|
27
|
+
|
28
|
+
def get_path_to_datafile(path):
  """Resolves a resource path relative to the file of the calling code.

  The directory of the caller's source file is used as the base directory,
  and the joined result is normalized.

  Args:
    path: a string resource path relative to the calling file.

  Returns:
    The normalized path built from the calling file's directory and `path`.
  """
  # Frame 1 is the direct caller of this function; this lookup must stay in
  # this function body so that the depth keeps pointing at the caller.
  caller_frame = _sys._getframe(1)  # pylint: disable=protected-access
  caller_dir = _os_path.dirname(_inspect.getfile(caller_frame))
  return _os_path.normpath(_os_path.join(caller_dir, path))
|
47
|
+
|
48
|
+
|
49
|
+
def create_random_normal_dataset(
|
50
|
+
input_details: dict[str, Any],
|
51
|
+
num_samples: int,
|
52
|
+
random_seed: Union[int, np._typing.ArrayLike],
|
53
|
+
) -> list[dict[str, Any]]:
|
54
|
+
"""create random dataset following random distribution.
|
55
|
+
|
56
|
+
Args:
|
57
|
+
input_details: list of dict created by
|
58
|
+
tensorflow.lite.interpreter.get_input_details() for generating dataset
|
59
|
+
num_samples: number of input samples to be generated
|
60
|
+
random_seed: random seed to be used for function
|
61
|
+
|
62
|
+
Returns:
|
63
|
+
a list of inputs to the given interpreter, for a single interpreter we may
|
64
|
+
have multiple input tensors so each set of inputs is also represented as
|
65
|
+
list
|
66
|
+
"""
|
67
|
+
rng = np.random.default_rng(random_seed)
|
68
|
+
dataset = []
|
69
|
+
for _ in range(num_samples):
|
70
|
+
input_data = {}
|
71
|
+
for arg_name, input_tensor in input_details.items():
|
72
|
+
new_data = rng.normal(size=input_tensor['shape']).astype(
|
73
|
+
input_tensor['dtype']
|
74
|
+
)
|
75
|
+
input_data[arg_name] = new_data
|
76
|
+
dataset.append(input_data)
|
77
|
+
return dataset
|
78
|
+
|
79
|
+
|
80
|
+
def create_random_normal_input_data(
    tflite_model: Union[str, bytes],
    num_samples: int = 4,
    random_seed: int = 666,
) -> dict[str, list[dict[str, Any]]]:
  """Creates random normal input data for every signature of a model.

  Args:
    tflite_model: TFLite model path or bytearray
    num_samples: number of input samples to be generated
    random_seed: random seed to be used for function; the same seed is used
      for every signature.

  Returns:
    a dict mapping each signature key of the model to the list of random
    input samples generated for that signature.
  """
  interpreter = tfl_interpreter_utils.create_tfl_interpreter(tflite_model)
  signature_keys = list(interpreter.get_signature_list().keys())

  data_by_signature = {}
  for key in signature_keys:
    runner = interpreter.get_signature_runner(key)
    data_by_signature[key] = create_random_normal_dataset(
        runner.get_input_details(), num_samples, random_seed
    )
  return data_by_signature
|