ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,132 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utility functions for graph transformations."""
17
+
18
+ import dataclasses
19
+ from typing import Union
20
+
21
+ import numpy as np
22
+
23
+ from ai_edge_quantizer import qtyping
24
+ from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
25
+
26
+
27
@dataclasses.dataclass
class TransformationInput:
  """Bundle of arguments handed to a graph transformation.

  Attributes:
    tensor_id: Index of the tensor the transformation acts on (e.g. the tensor
      a dequant op is inserted after).
    op_codes: Operator codes of the model. When the op to insert (e.g.
      dequantize) is not listed yet, its code has to be added to this list.
    buffers: Buffers of the original TFLite model, used for buffer
      quantization.
    subgraph: Flatbuffer subgraph object in which the tensor resides.
    producer: Op id of the producer of the tensor.
    consumers: Op ids of the consumers of the newly inserted op.
    quant_params: Quantization parameters to be applied to the original
      tensor.
  """

  # Field order defines the positional order of the generated __init__.
  tensor_id: int
  op_codes: list[schema_py_generated.OperatorCodeT]
  buffers: list[schema_py_generated.BufferT]
  subgraph: schema_py_generated.SubGraphT
  producer: int
  consumers: list[int]
  quant_params: Union[
      qtyping.UniformQuantParams,
      qtyping.NonLinearQuantParams,
  ]
49
+
50
+
51
def add_op_code(
    op_code: int,
    model_op_codes: list[schema_py_generated.OperatorCodeT],
) -> int:
  """Return the index of a builtin op code, registering it if it is absent.

  Args:
    op_code: The builtin operator code to look up, i.e. a
      `schema_py_generated.BuiltinOperator` value. It is compared against (and,
      when missing, stored into) `OperatorCodeT.builtinCode`, so it is the code
      value itself rather than an `OperatorCodeT` object.
    model_op_codes: The op codes of the model. Mutated in place: a new
      `OperatorCodeT` is appended when no existing entry matches `op_code`.

  Returns:
    The index of the matching (or newly appended) entry in `model_op_codes`.
  """
  for index, existing_op_code in enumerate(model_op_codes):
    if existing_op_code.builtinCode == op_code:
      return index

  # Not registered yet: append a fresh entry; its index is the last position.
  new_op_code = schema_py_generated.OperatorCodeT()
  new_op_code.builtinCode = op_code
  model_op_codes.append(new_op_code)
  return len(model_op_codes) - 1
70
+
71
+
72
def add_new_constant_tensor(
    tensor_name: str,
    data: np.ndarray,
    tensor_type: schema_py_generated.TensorType,
    subgraph: schema_py_generated.SubGraphT,
    buffers: list[schema_py_generated.BufferT],
) -> int:
  """Append a constant tensor, together with its data buffer, to the model.

  Both `buffers` and `subgraph.tensors` are extended in place.

  Args:
    tensor_name: The name of the new tensor.
    data: The data of the new tensor; its raw bytes become the buffer content
      and its shape becomes the tensor shape.
    tensor_type: The type of the new tensor.
    subgraph: The subgraph where the new tensor is added.
    buffers: The buffers of the model.

  Returns:
    The index of the new tensor in the subgraph.
  """
  # The new buffer lands at the end of the list, so its id is the old length.
  buffer_index = len(buffers)
  constant_buffer = schema_py_generated.BufferT()
  # Store the array as a flat vector of raw bytes.
  constant_buffer.data = np.frombuffer(
      data.tobytes(), dtype=np.uint8
  ).flatten()
  constant_buffer.offset = 0
  constant_buffer.size = 0
  buffers.append(constant_buffer)

  tensor_index = len(subgraph.tensors)
  constant_tensor = schema_py_generated.TensorT()
  constant_tensor.name = tensor_name
  constant_tensor.type = tensor_type
  constant_tensor.shape = data.shape
  constant_tensor.buffer = buffer_index
  subgraph.tensors.append(constant_tensor)
  return tensor_index
106
+
107
+
108
def add_new_activation_tensor(
    tensor_name: str,
    shape: list[int],
    tensor_type: schema_py_generated.TensorType,
    subgraph: schema_py_generated.SubGraphT,
) -> int:
  """Append a new activation tensor to a subgraph.

  `subgraph.tensors` is extended in place. No buffer is created for the
  tensor; it points at buffer index 0.

  Args:
    tensor_name: The name of the new tensor.
    shape: The shape of the new tensor.
    tensor_type: The type of the new tensor.
    subgraph: The subgraph where the new tensor is added.

  Returns:
    The index of the new tensor in the subgraph.
  """
  # The tensor is appended last, so its index equals the current tensor count.
  tensor_index = len(subgraph.tensors)

  activation_tensor = schema_py_generated.TensorT()
  activation_tensor.name = tensor_name
  activation_tensor.type = tensor_type
  activation_tensor.shape = shape
  # NOTE(review): buffer 0 is presumably the model's empty placeholder buffer;
  # confirm against the TFLite schema.
  activation_tensor.buffer = 0

  subgraph.tensors.append(activation_tensor)
  return tensor_index
@@ -0,0 +1,162 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Tests for transformation_utils."""
17
+
18
+ import os
19
+ import numpy as np
20
+ from tensorflow.python.platform import googletest
21
+ from absl.testing import parameterized
22
+ from ai_edge_quantizer.transformations import transformation_utils
23
+ from ai_edge_quantizer.utils import test_utils
24
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
25
+ from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
26
+
27
+ TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../tests/models")
28
+
29
+
30
class TransformationUtilsTest(parameterized.TestCase):
  """Tests for the helpers in transformation_utils."""

  def setUp(self):
    """Reads the `single_fc_bias.tflite` test model before every test."""
    super().setUp()
    self.model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "single_fc_bias.tflite"
    )
    self.model = tfl_flatbuffer_utils.read_model(self.model_path)

  @parameterized.named_parameters(
      {
          "testcase_name": "add_new_op_code",
          "op_code": schema_py_generated.BuiltinOperator.LOGISTIC,
          "expected": 1,
      },
      {
          "testcase_name": "add_existing_op_code",
          "op_code": schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
          "expected": 0,
      },
  )
  def test_add_op_code(self, op_code, expected):
    """Checks the index returned for a new and for an existing op code."""
    op_code_index = transformation_utils.add_op_code(
        op_code=op_code, model_op_codes=self.model.operatorCodes
    )
    self.assertEqual(expected, op_code_index)

  @parameterized.named_parameters(
      {
          "testcase_name": "float32",
          "tensor_data": np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
          "tensor_type": schema_py_generated.TensorType.FLOAT32,
          "expected_type": schema_py_generated.TensorType.FLOAT32,
          "expected_shape": (4,),
          "expected_buffer_data": np.frombuffer(
              np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).tobytes(),
              dtype=np.uint8,
          ).flatten(),
      },
      {
          "testcase_name": "int8",
          "tensor_data": np.array([[1, 2], [3, 4]], dtype=np.int8),
          "tensor_type": schema_py_generated.TensorType.INT8,
          "expected_type": schema_py_generated.TensorType.INT8,
          "expected_shape": (2, 2),
          "expected_buffer_data": np.frombuffer(
              np.array([[1, 2], [3, 4]], dtype=np.int8).tobytes(),
              dtype=np.uint8,
          ).flatten(),
      },
  )
  def test_add_new_constant_tensor(
      self,
      tensor_data,
      tensor_type,
      expected_type,
      expected_shape,
      expected_buffer_data,
  ):
    """Checks the tensor and the buffer created for a constant tensor."""
    subgraph = self.model.subgraphs[0]
    new_tensor_id = transformation_utils.add_new_constant_tensor(
        tensor_name="test_tensor",
        data=tensor_data,
        tensor_type=tensor_type,
        subgraph=subgraph,
        buffers=self.model.buffers,
    )
    # The new tensor must be the last one in the subgraph.
    added_tensor = subgraph.tensors[-1]
    self.assertEqual(new_tensor_id, len(subgraph.tensors) - 1)
    self.assertEqual(str(added_tensor.name), "test_tensor")
    self.assertEqual(expected_type, added_tensor.type)
    self.assertEqual(expected_shape, added_tensor.shape)
    # The buffer referenced by the tensor holds the raw bytes of the data.
    added_buffer = self.model.buffers[added_tensor.buffer]
    self.assertListEqual(
        expected_buffer_data.tolist(),
        added_buffer.data.tolist(),
    )

  @parameterized.named_parameters(
      {
          "testcase_name": "float32",
          "tensor_type": schema_py_generated.TensorType.FLOAT32,
          "tensor_shape": [1, 1, 1, 1],
          "expected_shape": [1, 1, 1, 1],
          "expected_type": schema_py_generated.TensorType.FLOAT32,
      },
      {
          "testcase_name": "int8",
          "tensor_type": schema_py_generated.TensorType.INT8,
          "tensor_shape": [1, 224, 224, 1],
          "expected_shape": [1, 224, 224, 1],
          "expected_type": schema_py_generated.TensorType.INT8,
      },
  )
  def test_add_new_activation_tensor_to_subgraph(
      self,
      tensor_type,
      tensor_shape,
      expected_shape,
      expected_type,
  ):
    """Checks the tensor created for an activation tensor."""
    subgraph = self.model.subgraphs[0]
    new_tensor_id = transformation_utils.add_new_activation_tensor(
        tensor_name="test_tensor",
        shape=tensor_shape,
        tensor_type=tensor_type,
        subgraph=subgraph,
    )
    # The new tensor must be the last one in the subgraph.
    added_tensor = subgraph.tensors[-1]
    self.assertEqual(new_tensor_id, len(subgraph.tensors) - 1)
    self.assertEqual(str(added_tensor.name), "test_tensor")
    self.assertEqual(expected_type, added_tensor.type)
    self.assertEqual(expected_shape, added_tensor.shape)
159
+
160
+
161
+ if __name__ == "__main__":
162
+ googletest.main()
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
@@ -0,0 +1,86 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utilities for model calibration."""
17
+
18
+ from typing import Union
19
+ import numpy as np
20
+ from ai_edge_quantizer import qtyping
21
+
22
+
23
def _update_moving_average(
    smoothing_factor: Union[np.ndarray, float],
    w: np.ndarray,
    update: np.ndarray,
) -> np.ndarray:
  """Blend `w` with `update` using a moving-average weighting.

  Args:
    smoothing_factor: Weight given to the current value `w`; the remaining
      `1 - smoothing_factor` is given to `update`.
    w: Current value to be updated.
    update: New value blended into `w`.

  Returns:
    `smoothing_factor * w + (1 - smoothing_factor) * update`.
  """
  retained = smoothing_factor * w
  incoming = (1.0 - smoothing_factor) * update
  return retained + incoming
39
+
40
+
41
def moving_average_update(
    qsv: qtyping.QSV, new_qsv: qtyping.QSV, smoothing_factor: float = 0.95
) -> qtyping.QSV:
  """Blend new min/max statistics into a QSV using a moving average.

  Args:
    qsv: The quantization statistical value of the tensor (min/max) that needs
      to be updated. When it is empty, `new_qsv` is returned as is.
    new_qsv: The new QSV (e.g., from a new round of calibration).
    smoothing_factor: The weight kept for the existing statistics; the new
      statistics receive `1 - smoothing_factor`.

  Returns:
    The updated QSV for the tensor, holding the "min" and "max" entries.
  """
  if qsv:
    return {
        stat: _update_moving_average(
            smoothing_factor, qsv[stat], new_qsv[stat]
        )
        for stat in ("min", "max")
    }
  # Nothing recorded so far: the new statistics become the QSV.
  return new_qsv
67
+
68
+
69
def min_max_update(qsv: qtyping.QSV, new_qsv: qtyping.QSV) -> qtyping.QSV:
  """Widen a QSV so that it covers both the old and the new min/max.

  Args:
    qsv: The quantization statistical value of the tensor (min/max) that needs
      to be updated. When it is empty, `new_qsv` is returned as is.
    new_qsv: The new QSV (e.g., from a new round of calibration).

  Returns:
    The updated QSV for the tensor: the element-wise minimum of the "min"
    entries and the element-wise maximum of the "max" entries.
  """
  if qsv:
    reducers = (("min", np.minimum), ("max", np.maximum))
    return {
        stat: reduce_fn(qsv[stat], new_qsv[stat])
        for stat, reduce_fn in reducers
    }
  # Nothing recorded so far: the new statistics become the QSV.
  return new_qsv
@@ -0,0 +1,77 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from absl.testing import parameterized
17
+ from tensorflow.python.platform import googletest
18
+ from ai_edge_quantizer.utils import calibration_utils
19
+
20
+
21
class CalibrationUtilsTest(parameterized.TestCase):
  """Tests for the QSV update helpers in calibration_utils."""

  @parameterized.named_parameters(
      dict(
          testcase_name="zero_smoothing_factor",
          smoothing_factor=0,
          expected_vals={"min": -1000, "max": 800},
      ),
      dict(
          testcase_name="one_smoothing_factor",
          smoothing_factor=1,
          expected_vals={"min": -10, "max": 8},
      ),
      dict(
          testcase_name="normal_smoothing_factor",
          smoothing_factor=0.99,
          expected_vals={"min": -19.9, "max": 15.92},
      ),
  )
  def test_update_tensor_qsv_moving_average(
      self, smoothing_factor, expected_vals
  ):
    """Checks the moving-average update for several smoothing factors."""
    old_qsv = {"min": -10, "max": 8}
    # Large values to mimic outlier.
    new_qsv = {"min": -1000, "max": 800}
    updated_qsv = calibration_utils.moving_average_update(
        old_qsv, new_qsv, smoothing_factor=smoothing_factor
    )
    self.assertAlmostEqual(updated_qsv["min"], expected_vals["min"])
    self.assertAlmostEqual(updated_qsv["max"], expected_vals["max"])

  @parameterized.named_parameters(
      dict(
          testcase_name="scalar",
          old_qsv={"min": -10, "max": 8},
          new_qsv={"min": -1000, "max": 1},
          expected_qsv={"min": -1000, "max": 8},
      ),
      dict(
          testcase_name="2darray",
          old_qsv={"min": [[-19], [20]], "max": [[21], [250]]},
          new_qsv={"min": [[-1000], [25]], "max": [[33], [100]]},
          expected_qsv={"min": [[-1000], [20]], "max": [[33], [250]]},
      ),
  )
  def test_update_tensor_qsv_min_max(self, old_qsv, new_qsv, expected_qsv):
    """Checks the min/max update for scalar and array statistics."""
    updated_qsv = calibration_utils.min_max_update(old_qsv, new_qsv)
    if isinstance(expected_qsv["min"], list):
      # Convert the arrays to nested Python lists before comparing.
      # `list(array)` would yield a list of row arrays, and comparing those
      # with lists only works while every row has exactly one element.
      self.assertListEqual(updated_qsv["min"].tolist(), expected_qsv["min"])
      self.assertListEqual(updated_qsv["max"].tolist(), expected_qsv["max"])
    else:
      self.assertEqual(updated_qsv["min"], expected_qsv["min"])
      self.assertEqual(updated_qsv["max"], expected_qsv["max"])
74
+
75
+
76
+ if __name__ == "__main__":
77
+ googletest.main()
@@ -0,0 +1,107 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utils for tests."""
17
+
18
+ import inspect as _inspect
19
+ import os.path as _os_path
20
+ import sys as _sys
21
+ from typing import Any, Union
22
+
23
+ import numpy as np
24
+
25
+ from ai_edge_quantizer.utils import tfl_interpreter_utils
26
+
27
+
28
def get_path_to_datafile(path):
  """Resolve a path relative to the directory of the calling file.

  Args:
    path: A string resource path relative to the calling file.

  Returns:
    The normalized path obtained by joining the directory of the caller's
    source file with `path`. The result is not checked for existence.
  """
  # Frame 1 is the direct caller of this function.
  caller_frame = _sys._getframe(1)  # pylint: disable=protected-access
  caller_dir = _os_path.dirname(_inspect.getfile(caller_frame))
  return _os_path.normpath(_os_path.join(caller_dir, path))
47
+
48
+
49
+ def create_random_normal_dataset(
50
+ input_details: dict[str, Any],
51
+ num_samples: int,
52
+ random_seed: Union[int, np._typing.ArrayLike],
53
+ ) -> list[dict[str, Any]]:
54
+ """create random dataset following random distribution.
55
+
56
+ Args:
57
+ input_details: list of dict created by
58
+ tensorflow.lite.interpreter.get_input_details() for generating dataset
59
+ num_samples: number of input samples to be generated
60
+ random_seed: random seed to be used for function
61
+
62
+ Returns:
63
+ a list of inputs to the given interpreter, for a single interpreter we may
64
+ have multiple input tensors so each set of inputs is also represented as
65
+ list
66
+ """
67
+ rng = np.random.default_rng(random_seed)
68
+ dataset = []
69
+ for _ in range(num_samples):
70
+ input_data = {}
71
+ for arg_name, input_tensor in input_details.items():
72
+ new_data = rng.normal(size=input_tensor['shape']).astype(
73
+ input_tensor['dtype']
74
+ )
75
+ input_data[arg_name] = new_data
76
+ dataset.append(input_data)
77
+ return dataset
78
+
79
+
80
def create_random_normal_input_data(
    tflite_model: Union[str, bytes],
    num_samples: int = 4,
    random_seed: int = 666,
) -> dict[str, list[dict[str, Any]]]:
  """Create random normal input data for every signature of a model.

  Args:
    tflite_model: TFLite model path or the model content as bytes.
    num_samples: Number of input samples to be generated per signature.
    random_seed: Seed of the random number generator; the same seed is used
      for each signature.

  Returns:
    A dict mapping each signature key of the model to its dataset, i.e. the
    list of input samples created by `create_random_normal_dataset`.
  """
  interpreter = tfl_interpreter_utils.create_tfl_interpreter(tflite_model)
  signature_keys = list(interpreter.get_signature_list().keys())
  return {
      key: create_random_normal_dataset(
          interpreter.get_signature_runner(key).get_input_details(),
          num_samples,
          random_seed,
      )
      for key in signature_keys
  }