ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
ai_edge_quantizer/utils/tfl_interpreter_utils.py
@@ -0,0 +1,312 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Util functions for TFL interpreter."""
+
+ from typing import Any, Optional, Union
+
+ import numpy as np
+
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
+ from ai_edge_litert import interpreter as tfl  # pylint: disable=g-direct-tensorflow-import
+ from tensorflow.python.platform import gfile  # pylint: disable=g-direct-tensorflow-import
+
+ DEFAULT_SIGNATURE_KEY = "serving_default"
+
+
+ def create_tfl_interpreter(
+     tflite_model: Union[str, bytes],
+     allocate_tensors: bool = True,
+     use_xnnpack: bool = True,
+     num_threads: int = 16,
+ ) -> tfl.Interpreter:
+   """Creates a TFLite interpreter from a model file.
+
+   Args:
+     tflite_model: Model file path or bytes.
+     allocate_tensors: Whether to allocate tensors.
+     use_xnnpack: Whether to use the XNNPACK delegate for the interpreter.
+     num_threads: The number of threads to use for the interpreter.
+
+   Returns:
+     A TFLite interpreter.
+   """
+   if isinstance(tflite_model, str):
+     with gfile.GFile(tflite_model, "rb") as f:
+       tflite_model = f.read()
+
+   if use_xnnpack:
+     op_resolver = tfl.OpResolverType.BUILTIN
+   else:
+     op_resolver = tfl.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
+   tflite_interpreter = tfl.Interpreter(
+       model_content=bytes(tflite_model),
+       num_threads=num_threads,
+       experimental_op_resolver_type=op_resolver,
+       experimental_preserve_all_tensors=True,
+   )
+   if allocate_tensors:
+     tflite_interpreter.allocate_tensors()
+   return tflite_interpreter
+
+
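For orientation, a minimal usage sketch of the factory above; the model path is a placeholder, not a file shipped in this wheel:

```python
# A hypothetical usage sketch of create_tfl_interpreter; "my_model.tflite"
# is a placeholder path.
from ai_edge_quantizer.utils import tfl_interpreter_utils

interpreter = tfl_interpreter_utils.create_tfl_interpreter(
    "my_model.tflite", use_xnnpack=True, num_threads=4
)
# allocate_tensors defaults to True, so the interpreter is ready to invoke.
print([d["name"] for d in interpreter.get_input_details()])
```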
+ def is_tensor_quantized(tensor_detail: dict[str, Any]) -> bool:
+   """Checks if a tensor is quantized.
+
+   Args:
+     tensor_detail: A dictionary of tensor details.
+
+   Returns:
+     True if the tensor is quantized.
+   """
+   quant_params = tensor_detail["quantization_parameters"]
+   return bool(len(quant_params["scales"]))
+
+
+ def invoke_interpreter_signature(
+     tflite_interpreter: tfl.Interpreter,
+     signature_input_data: dict[str, Any],
+     signature_key: Optional[str] = None,
+     quantize_input: bool = True,
+ ) -> dict[str, np.ndarray]:
+   """Invokes the TFLite interpreter through signature runner.
+
+   Args:
+     tflite_interpreter: A TFLite interpreter.
+     signature_input_data: The input data for the signature.
+     signature_key: The signature key.
+     quantize_input: Whether to quantize the input data.
+
+   Returns:
+     The output data of the signature.
+   """
+   # Make a copy to avoid in-place modification.
+   signature_input = signature_input_data.copy()
+   signature_runner = tflite_interpreter.get_signature_runner(signature_key)
+   for input_name, input_detail in signature_runner.get_input_details().items():
+     if is_tensor_quantized(input_detail) and quantize_input:
+       input_data = signature_input[input_name]
+       quant_params = qtyping.UniformQuantParams.from_tfl_tensor_details(
+           input_detail
+       )
+       signature_input[input_name] = uniform_quantize_tensor.uniform_quantize(
+           input_data, quant_params
+       )
+   return signature_runner(**signature_input)
+
+
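A sketch of driving a model through its signature runner with the helper above; the model path, signature key, and input name are assumptions for illustration, not values defined by this package:

```python
import numpy as np

from ai_edge_quantizer.utils import tfl_interpreter_utils

# Hypothetical model path, signature key, and input name.
interpreter = tfl_interpreter_utils.create_tfl_interpreter("my_model.tflite")
outputs = tfl_interpreter_utils.invoke_interpreter_signature(
    interpreter,
    {"x": np.ones((1, 8), dtype=np.float32)},
    signature_key="serving_default",
)
# Quantized inputs are quantized automatically when quantize_input=True.
for name, value in outputs.items():
  print(name, value.shape)
```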
+ def invoke_interpreter_once(
+     tflite_interpreter: tfl.Interpreter,
+     input_data_list: list[Any],
+     quantize_input: bool = True,
+ ):
+   """Invokes the TFLite interpreter once.
+
+   Args:
+     tflite_interpreter: A TFLite interpreter.
+     input_data_list: A list of input data.
+     quantize_input: Whether to quantize the input data.
+   """
+   if len(input_data_list) != len(tflite_interpreter.get_input_details()):
+     raise ValueError(
+         "Input data must be a list with each element matching the input"
+         " sequence defined in the .tflite model. If the model has only one"
+         " input, wrap it in a list (e.g., [input_data])."
+     )
+   for i, input_data in enumerate(input_data_list):
+     input_details = tflite_interpreter.get_input_details()[i]
+     if is_tensor_quantized(input_details) and quantize_input:
+       quant_params = qtyping.UniformQuantParams.from_tfl_tensor_details(
+           input_details
+       )
+       input_data = uniform_quantize_tensor.uniform_quantize(
+           input_data, quant_params
+       )
+     tflite_interpreter.set_tensor(input_details["index"], input_data)
+   tflite_interpreter.invoke()
+
+
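A sketch of the positional entry point above, showing that even a single input must be wrapped in a list (the model path and input shape are assumptions):

```python
import numpy as np

from ai_edge_quantizer.utils import tfl_interpreter_utils

# Hypothetical model path and input shape.
interpreter = tfl_interpreter_utils.create_tfl_interpreter("my_model.tflite")
# A single input still has to be passed as a one-element list.
tfl_interpreter_utils.invoke_interpreter_once(
    interpreter, [np.zeros((1, 28, 28, 1), dtype=np.float32)]
)
output_detail = interpreter.get_output_details()[0]
print(interpreter.get_tensor(output_detail["index"]).shape)
```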
+ def get_tensor_data(
+     tflite_interpreter: Any,
+     tensor_detail: dict[str, Any],
+     subgraph_index: int = 0,
+     dequantize: bool = True,
+ ) -> np.ndarray:
+   """Gets the tensor data from a TFLite interpreter.
+
+   Args:
+     tflite_interpreter: A TFLite interpreter.
+     tensor_detail: A dictionary of tensor details.
+     subgraph_index: The index of the subgraph that the tensor belongs to.
+     dequantize: Whether to dequantize the quantized tensor data.
+
+   Returns:
+     The tensor data.
+   """
+   tensor_data = tflite_interpreter.get_tensor(
+       tensor_detail["index"], subgraph_index
+   )
+   if is_tensor_quantized(tensor_detail) and dequantize:
+     quant_params = qtyping.UniformQuantParams.from_tfl_tensor_details(
+         tensor_detail
+     )
+     tensor_data = uniform_quantize_tensor.uniform_dequantize(
+         tensor_data,
+         quant_params,
+     )
+   return tensor_data
+
+
+ def get_tensor_name_to_content_map(
+     tflite_interpreter: Any, subgraph_index: int = 0, dequantize: bool = False
+ ) -> dict[str, Any]:
+   """Gets internal tensors from a TFLite interpreter for a given subgraph.
+
+   Note the data will be copied to the returned dictionary, increasing the
+   memory usage.
+
+   Args:
+     tflite_interpreter: A TFLite interpreter.
+     subgraph_index: The index of the subgraph that the tensor belongs to.
+     dequantize: Whether to dequantize the tensor data.
+
+   Returns:
+     A dictionary of internal tensors.
+   """
+   tensors = {}
+   for tensor_detail in tflite_interpreter.get_tensor_details(subgraph_index):
+     # Don't return temporary, unnamed tensors.
+     if not tensor_detail["name"]:
+       continue
+     tensors[tensor_detail["name"]] = get_tensor_data(
+         tflite_interpreter, tensor_detail, subgraph_index, dequantize
+     )
+   return tensors
+
+
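Because create_tfl_interpreter sets experimental_preserve_all_tensors=True, the helper above can dump every named tensor after an invocation; a sketch under the same placeholder assumptions as the earlier examples:

```python
import numpy as np

from ai_edge_quantizer.utils import tfl_interpreter_utils

# Hypothetical model path and input shape.
interpreter = tfl_interpreter_utils.create_tfl_interpreter("my_model.tflite")
tfl_interpreter_utils.invoke_interpreter_once(
    interpreter, [np.zeros((1, 28, 28, 1), dtype=np.float32)]
)
# Copies every named tensor in subgraph 0; this can be memory-hungry.
name_to_content = tfl_interpreter_utils.get_tensor_name_to_content_map(
    interpreter, subgraph_index=0, dequantize=True
)
for name, data in name_to_content.items():
  print(name, data.shape, data.dtype)
```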
+ def get_tensor_name_to_details_map(
+     tflite_interpreter: Any, subgraph_index: int = 0
+ ) -> dict[str, Any]:
+   """Gets tensor details from a TFLite interpreter for a given subgraph.
+
+   Args:
+     tflite_interpreter: A TFLite interpreter.
+     subgraph_index: The index of the subgraph that the tensor belongs to.
+
+   Returns:
+     A dictionary mapping tensor names to their tensor details.
+   """
+   tensor_name_to_detail = {}
+   for tensor_detail in tflite_interpreter.get_tensor_details(subgraph_index):
+     # Don't return temporary, unnamed tensors.
+     if not tensor_detail["name"]:
+       continue
+     tensor_name_to_detail[tensor_detail["name"]] = tensor_detail
+   return tensor_name_to_detail
+
+
+ def get_input_tensor_names(
+     tflite_model: Union[str, bytes], signature_name: Optional[str] = None
+ ) -> list[str]:
+   """Gets input tensor names from a TFLite model for a signature.
+
+   Args:
+     tflite_model: Model file path or bytes.
+     signature_name: The signature name that the input tensors belong to.
+
+   Returns:
+     A list of input tensor names.
+   """
+
+   tfl_interpreter = create_tfl_interpreter(tflite_model, allocate_tensors=False)
+   signature_runner = tfl_interpreter.get_signature_runner(signature_name)
+   input_tensor_names = []
+   for _, input_detail in signature_runner.get_input_details().items():
+     input_tensor_names.append(input_detail["name"])
+   return input_tensor_names
+
+
+ def get_output_tensor_names(
+     tflite_model: Union[str, bytes], signature_name: Optional[str] = None
+ ) -> list[str]:
+   """Gets output tensor names from a TFLite model for a signature.
+
+   Args:
+     tflite_model: Model file path or bytes.
+     signature_name: The signature name that the output tensors belong to.
+
+   Returns:
+     A list of output tensor names.
+   """
+   tfl_interpreter = create_tfl_interpreter(tflite_model, allocate_tensors=False)
+   signature_runner = tfl_interpreter.get_signature_runner(signature_name)
+   output_tensor_names = []
+   for _, output_detail in signature_runner.get_output_details().items():
+     output_tensor_names.append(output_detail["name"])
+   return output_tensor_names
+
+
+ def get_constant_tensor_names(
+     tflite_model: Union[str, bytes],
+     subgraph_index: int = 0,
+     min_constant_size: int = 1,
+ ) -> list[str]:
+   """Gets constant tensor names from a TFLite model for a subgraph.
+
+   Note that this function acts on the subgraph level, not the signature level,
+   because it is non-trivial to track constant tensors for a signature without
+   running it.
+
+   Args:
+     tflite_model: Model file path or bytes.
+     subgraph_index: The index of the subgraph that the tensor belongs to.
+     min_constant_size: The minimum number of elements a constant tensor must
+       have to be included.
+
+   Returns:
+     A list of names of constant tensors whose size is at least
+     min_constant_size.
+   """
+   tfl_interpreter = create_tfl_interpreter(tflite_model, allocate_tensors=False)
+   const_tensor_names = []
+   for tensor_detail in tfl_interpreter.get_tensor_details(subgraph_index):
+     if tensor_detail["dtype"] == np.object_:
+       continue
+     try:
+       tensor_data = get_tensor_data(
+           tfl_interpreter, tensor_detail, subgraph_index
+       )
+       if tensor_data.size >= min_constant_size:
+         const_tensor_names.append(tensor_detail["name"])
+     except ValueError:
+       continue
+   return const_tensor_names
+
+
+ def get_signature_main_subgraph_index(
+     tflite_interpreter: tfl.Interpreter,
+     signature_key: Optional[str] = None,
+ ) -> int:
+   """Gets the main subgraph index of a signature.
+
+   Args:
+     tflite_interpreter: A TFLite interpreter.
+     signature_key: The signature key.
+
+   Returns:
+     The main subgraph index of the signature.
+   """
+   signature_runner = tflite_interpreter.get_signature_runner(signature_key)
+   return signature_runner._subgraph_index  # pylint:disable=protected-access
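Taken together, the subgraph-level helpers above let a caller map a signature to its subgraph and then inspect that subgraph's constant tensors; a sketch with a placeholder model path and signature key:

```python
from ai_edge_quantizer.utils import tfl_interpreter_utils

model_path = "my_model.tflite"  # Hypothetical path; signature key below too.
interpreter = tfl_interpreter_utils.create_tfl_interpreter(
    model_path, allocate_tensors=False
)
subgraph_index = tfl_interpreter_utils.get_signature_main_subgraph_index(
    interpreter, "serving_default"
)
# Only constant tensors with at least 1024 elements are reported here.
const_names = tfl_interpreter_utils.get_constant_tensor_names(
    model_path, subgraph_index, min_constant_size=1024
)
print(const_names)
```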
ai_edge_quantizer/utils/tfl_interpreter_utils_test.py
@@ -0,0 +1,332 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import os
+ import numpy as np
+ from tensorflow.python.platform import googletest
+ from ai_edge_quantizer.utils import test_utils
+ from ai_edge_quantizer.utils import tfl_interpreter_utils
+
+
+ TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../tests/models")
+
+
+ class TflUtilsSingleSignatureModelTest(googletest.TestCase):
+
+   def setUp(self):
+     super().setUp()
+     np.random.seed(0)
+     self._test_model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "conv_fc_mnist.tflite"
+     )
+     self._input_data = np.random.rand(1, 28, 28, 1).astype(np.float32)
+
+   def test_create_tfl_interpreter(self):
+     tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+         self._test_model_path
+     )
+     self.assertIsNotNone(tfl_interpreter)
+
+   def test_invoke_interpreter_once(self):
+     tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+         self._test_model_path
+     )
+     tfl_interpreter_utils.invoke_interpreter_once(
+         tfl_interpreter, [self._input_data]
+     )
+     output_details = tfl_interpreter.get_output_details()[0]
+     output_data = tfl_interpreter.get_tensor(output_details["index"])
+     self.assertIsNotNone(output_data)
+     self.assertEqual(tuple(output_data.shape), (1, 10))
+     self.assertAlmostEqual(output_data[0][0], 0.0031010755)
+
+   def test_get_tensor_data(self):
+     tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+         self._test_model_path
+     )
+     tfl_interpreter_utils.invoke_interpreter_once(
+         tfl_interpreter, [self._input_data]
+     )
+     output_details = tfl_interpreter.get_output_details()[0]
+     output_data = tfl_interpreter_utils.get_tensor_data(
+         tfl_interpreter, output_details
+     )
+     self.assertEqual(tuple(output_data.shape), (1, 10))
+
+   def test_get_tensor_name_to_content_map(self):
+     tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+         self._test_model_path
+     )
+     tfl_interpreter_utils.invoke_interpreter_once(
+         tfl_interpreter, [self._input_data]
+     )
+
+     tensor_name_to_content_map = (
+         tfl_interpreter_utils.get_tensor_name_to_content_map(tfl_interpreter)
+     )
+     input_content = tensor_name_to_content_map["serving_default_conv2d_input:0"]
+     self.assertSequenceAlmostEqual(
+         self._input_data.flatten(), input_content.flatten()
+     )
+     weight_content = tensor_name_to_content_map["sequential/conv2d/Conv2D"]
+     self.assertEqual(tuple(weight_content.shape), (8, 3, 3, 1))
+
+     self.assertIn(
+         "sequential/average_pooling2d/AvgPool", tensor_name_to_content_map
+     )
+     average_pool_res = tensor_name_to_content_map[
+         "sequential/average_pooling2d/AvgPool"
+     ]
+     self.assertEqual(tuple(average_pool_res.shape), (1, 14, 14, 8))
+
+   def test_is_tensor_quantized(self):
+     tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+         self._test_model_path
+     )
+     input_details = tfl_interpreter.get_input_details()[0]
+     self.assertFalse(tfl_interpreter_utils.is_tensor_quantized(input_details))
+
+   def test_get_input_tensor_names(self):
+     input_tensor_names = tfl_interpreter_utils.get_input_tensor_names(
+         self._test_model_path
+     )
+     self.assertEqual(
+         input_tensor_names,
+         ["serving_default_conv2d_input:0"],
+     )
+
+   def test_get_output_tensor_names(self):
+     output_tensor_names = tfl_interpreter_utils.get_output_tensor_names(
+         self._test_model_path
+     )
+     self.assertEqual(
+         output_tensor_names,
+         ["StatefulPartitionedCall:0"],
+     )
+
+   def test_get_constant_tensor_names(self):
+     const_tensor_names = tfl_interpreter_utils.get_constant_tensor_names(
+         self._test_model_path
+     )
+     self.assertEqual(
+         set(const_tensor_names),
+         set([
+             "sequential/conv2d/Conv2D",
+             "sequential/conv2d/Relu;sequential/conv2d/BiasAdd;sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp",
+             "arith.constant",
+             "arith.constant1",
+             "arith.constant2",
+             "arith.constant3",
+         ]),
+     )
+
+
+ class TflUtilsQuantizedModelTest(googletest.TestCase):
+
+   def setUp(self):
+     super().setUp()
+     np.random.seed(0)
+     self._test_model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "conv_fc_mnist_srq_a8w8.tflite"
+     )
+     self._signature_input_data = {
+         "conv2d_input": np.random.rand(1, 28, 28, 1).astype(np.float32)
+     }
+
+   def test_is_tensor_quantized(self):
+     tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+         self._test_model_path
+     )
+     input_details = tfl_interpreter.get_input_details()[0]
+     self.assertTrue(tfl_interpreter_utils.is_tensor_quantized(input_details))
+
+   def test_invoke_interpreter_signature(self):
+     tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+         self._test_model_path
+     )
+     signature_output = tfl_interpreter_utils.invoke_interpreter_signature(
+         tfl_interpreter, self._signature_input_data
+     )
+     print(signature_output)
+     self.assertEqual(tuple(signature_output["dense_1"].shape), (1, 10))
+
+     # Assert the input data is not modified in-place b/353340272.
+     self.assertEqual(
+         self._signature_input_data["conv2d_input"].dtype, np.float32
+     )
+
+
+ class TflUtilsMultiSignatureModelTest(googletest.TestCase):
+
+   def setUp(self):
+     super().setUp()
+     np.random.seed(0)
+     self._test_model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "two_signatures.tflite"
+     )
+     self._signature_input_data = {"x": np.array([2.0]).astype(np.float32)}
+
+   def test_create_tfl_interpreter(self):
+     tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+         self._test_model_path
+     )
+     self.assertIsNotNone(tfl_interpreter)
+
+   def test_get_input_tensor_names(self):
+     signature_name = "add"
+     input_tensor_names = tfl_interpreter_utils.get_input_tensor_names(
+         self._test_model_path, signature_name
+     )
+     self.assertEqual(
+         input_tensor_names,
+         ["add_x:0"],
+     )
+
+     signature_name = "multiply"
+     input_tensor_names = tfl_interpreter_utils.get_input_tensor_names(
+         self._test_model_path, signature_name
+     )
+     self.assertEqual(
+         input_tensor_names,
+         ["multiply_x:0"],
+     )
+
+   def test_get_output_tensor_names(self):
+     signature_name = "add"
+     input_tensor_names = tfl_interpreter_utils.get_output_tensor_names(
+         self._test_model_path, signature_name
+     )
+     self.assertEqual(
+         input_tensor_names,
+         ["PartitionedCall:0"],
+     )
+
+     signature_name = "multiply"
+     input_tensor_names = tfl_interpreter_utils.get_output_tensor_names(
+         self._test_model_path, signature_name
+     )
+     self.assertEqual(
+         input_tensor_names,
+         ["PartitionedCall_1:0"],
+     )
+
+   def test_get_constant_tensor_names(self):
+     subgraph0_const_tensor_names = (
+         tfl_interpreter_utils.get_constant_tensor_names(
+             self._test_model_path, 0
+         )
+     )
+     self.assertEqual(subgraph0_const_tensor_names, ["Add/y"])
+
+     subgraph1_const_tensor_names = (
+         tfl_interpreter_utils.get_constant_tensor_names(
+             self._test_model_path, 1
+         )
+     )
+     self.assertEqual(subgraph1_const_tensor_names, ["Mul/y"])
+
+   def test_get_signature_main_subgraph_index(self):
+     tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+         self._test_model_path
+     )
+     add_subgraph_index = (
+         tfl_interpreter_utils.get_signature_main_subgraph_index(
+             tfl_interpreter, "add"
+         )
+     )
+     self.assertEqual(add_subgraph_index, 0)
+     multiply_subgraph_index = (
+         tfl_interpreter_utils.get_signature_main_subgraph_index(
+             tfl_interpreter, "multiply"
+         )
+     )
+     self.assertEqual(multiply_subgraph_index, 1)
+
+   def test_get_tensor_data(self):
+     tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+         self._test_model_path
+     )
+     # Invoke the ADD signature.
+     tfl_interpreter_utils.invoke_interpreter_signature(
+         tfl_interpreter, self._signature_input_data, "add"
+     )
+     output_details = {"index": 2, "quantization_parameters": {"scales": []}}
+     output_data = tfl_interpreter_utils.get_tensor_data(
+         tfl_interpreter, output_details, subgraph_index=0
+     )  # The ADD signature is in the first subgraph.
+     self.assertEqual(output_data, [12.0])  # 10 + 2
+
+     # Invoke the MULTIPLY signature.
+     tfl_interpreter_utils.invoke_interpreter_signature(
+         tfl_interpreter, self._signature_input_data, "multiply"
+     )
+     output_data = tfl_interpreter_utils.get_tensor_data(
+         tfl_interpreter, output_details, subgraph_index=1
+     )  # The Multiply signature is in the second subgraph.
+     self.assertEqual(output_data, [20.0])  # 10 * 2
+
+   def test_get_tensor_name_to_content_map(self):
+     tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+         self._test_model_path
+     )
+     # Invoke all signatures.
+     tfl_interpreter_utils.invoke_interpreter_signature(
+         tfl_interpreter, self._signature_input_data, "multiply"
+     )
+     tfl_interpreter_utils.invoke_interpreter_signature(
+         tfl_interpreter, self._signature_input_data, "add"
+     )
+
+     # Test tensors belonging to the ADD signature.
+     add_subgraph_index = (
+         tfl_interpreter_utils.get_signature_main_subgraph_index(
+             tfl_interpreter, "add"
+         )
+     )
+     add_tensor_content = tfl_interpreter_utils.get_tensor_name_to_content_map(
+         tfl_interpreter, add_subgraph_index
+     )
+
+     add_input_content = add_tensor_content["add_x:0"]
+     self.assertSequenceAlmostEqual(
+         self._signature_input_data["x"].flatten(), add_input_content.flatten()
+     )
+     weight_content = add_tensor_content["Add/y"]
+     self.assertEqual(weight_content, 10)
+     add_output_content = add_tensor_content["PartitionedCall:0"]
+     self.assertEqual(add_output_content, [12.0])
+
+     # Test tensors belonging to the MULTIPLY signature.
+     multiply_subgraph_index = (
+         tfl_interpreter_utils.get_signature_main_subgraph_index(
+             tfl_interpreter, "multiply"
+         )
+     )
+     mul_tensor_content = tfl_interpreter_utils.get_tensor_name_to_content_map(
+         tfl_interpreter, multiply_subgraph_index
+     )
+     multiply_input_content = mul_tensor_content["multiply_x:0"]
+     self.assertSequenceAlmostEqual(
+         self._signature_input_data["x"].flatten(),
+         multiply_input_content.flatten(),
+     )
+     weight_content = mul_tensor_content["Mul/y"]
+     self.assertEqual(weight_content, 10)
+     multiply_output_content = mul_tensor_content["PartitionedCall_1:0"]
+     self.assertEqual(multiply_output_content, [20.0])
+
+
+ if __name__ == "__main__":
+   googletest.main()