ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,312 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Util functions for TFL interpreter."""
|
17
|
+
|
18
|
+
from typing import Any, Optional, Union
|
19
|
+
|
20
|
+
import numpy as np
|
21
|
+
|
22
|
+
from ai_edge_quantizer import qtyping
|
23
|
+
from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
|
24
|
+
from ai_edge_litert import interpreter as tfl # pylint: disable=g-direct-tensorflow-import
|
25
|
+
from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import
|
26
|
+
|
27
|
+
DEFAULT_SIGNATURE_KEY = "serving_default"
|
28
|
+
|
29
|
+
|
30
|
+
def create_tfl_interpreter(
    tflite_model: Union[str, bytes],
    allocate_tensors: bool = True,
    use_xnnpack: bool = True,
    num_threads: int = 16,
) -> tfl.Interpreter:
  """Builds a TFLite interpreter from a model path or serialized model.

  Args:
    tflite_model: Path to the model file, or the serialized model bytes.
    allocate_tensors: If True, tensors are allocated before returning.
    use_xnnpack: Whether to use the XNNPACK delegate. When True the BUILTIN op
      resolver is selected; otherwise the resolver without default delegates
      is used.
    num_threads: The number of threads handed to the interpreter.

  Returns:
    The constructed TFLite interpreter, created with
    `experimental_preserve_all_tensors=True`.
  """
  model_content = tflite_model
  if isinstance(model_content, str):
    # A string is treated as a file path: load the serialized model from it.
    with gfile.GFile(model_content, "rb") as model_file:
      model_content = model_file.read()

  resolver_type = (
      tfl.OpResolverType.BUILTIN
      if use_xnnpack
      else tfl.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
  )
  interpreter = tfl.Interpreter(
      model_content=bytes(model_content),
      num_threads=num_threads,
      experimental_op_resolver_type=resolver_type,
      experimental_preserve_all_tensors=True,
  )
  if allocate_tensors:
    interpreter.allocate_tensors()
  return interpreter
|
64
|
+
|
65
|
+
|
66
|
+
def is_tensor_quantized(tensor_detail: dict[str, Any]) -> bool:
  """Reports whether a tensor carries quantization scales.

  Args:
    tensor_detail: A dictionary of tensor details. It must contain a
      "quantization_parameters" entry holding a "scales" sequence.

  Returns:
    True if the "scales" sequence is non-empty, False otherwise.
  """
  scales = tensor_detail["quantization_parameters"]["scales"]
  return len(scales) > 0
|
77
|
+
|
78
|
+
|
79
|
+
def invoke_interpreter_signature(
    tflite_interpreter: tfl.Interpreter,
    signature_input_data: dict[str, Any],
    signature_key: Optional[str] = None,
    quantize_input: bool = True,
) -> dict[str, np.ndarray]:
  """Runs one signature of the interpreter through its signature runner.

  Args:
    tflite_interpreter: A TFLite interpreter.
    signature_input_data: Input data for the signature, keyed by input name.
    signature_key: The signature key passed to `get_signature_runner`.
    quantize_input: Whether to quantize the data fed to quantized inputs.

  Returns:
    The output data returned by the signature runner.
  """
  # Shallow copy: quantized replacements are stored in the copy, so the
  # caller's dictionary keeps its original entries.
  runner_inputs = signature_input_data.copy()
  runner = tflite_interpreter.get_signature_runner(signature_key)
  for name, detail in runner.get_input_details().items():
    if not is_tensor_quantized(detail) or not quantize_input:
      continue
    params = qtyping.UniformQuantParams.from_tfl_tensor_details(detail)
    runner_inputs[name] = uniform_quantize_tensor.uniform_quantize(
        runner_inputs[name], params
    )
  return runner(**runner_inputs)
|
109
|
+
|
110
|
+
|
111
|
+
def invoke_interpreter_once(
    tflite_interpreter: tfl.Interpreter,
    input_data_list: list[Any],
    quantize_input: bool = True,
) -> None:
  """Sets the model inputs and invokes the TFLite interpreter once.

  Args:
    tflite_interpreter: A TFLite interpreter.
    input_data_list: A list of input data, ordered like the interpreter's
      input details.
    quantize_input: Whether to quantize the data fed to quantized inputs.

  Raises:
    ValueError: If the number of provided inputs differs from the number of
      inputs reported by the interpreter.
  """
  # Query the input details once instead of on every loop iteration.
  all_input_details = tflite_interpreter.get_input_details()
  if len(input_data_list) != len(all_input_details):
    raise ValueError(
        "Input data must be a list with each element match the input sequence"
        " defined in .tflite. If the model has only one input, wrap it with a"
        " list (e.g., [input_data])"
    )
  # The lengths were checked above, so zip pairs every input with its details.
  for input_details, input_data in zip(all_input_details, input_data_list):
    if is_tensor_quantized(input_details) and quantize_input:
      quant_params = qtyping.UniformQuantParams.from_tfl_tensor_details(
          input_details
      )
      input_data = uniform_quantize_tensor.uniform_quantize(
          input_data, quant_params
      )
    tflite_interpreter.set_tensor(input_details["index"], input_data)
  tflite_interpreter.invoke()
|
140
|
+
|
141
|
+
|
142
|
+
def get_tensor_data(
    tflite_interpreter: Any,
    tensor_detail: dict[str, Any],
    subgraph_index: int = 0,
    dequantize: bool = True,
) -> np.ndarray:
  """Reads one tensor from a TFLite interpreter.

  Args:
    tflite_interpreter: A TFLite interpreter.
    tensor_detail: A dictionary of tensor details; its "index" entry selects
      the tensor to read.
    subgraph_index: The index of the subgraph that the tensor belongs to.
    dequantize: Whether to dequantize the data of a quantized tensor.

  Returns:
    The tensor data, dequantized when the tensor is quantized and `dequantize`
    is True.
  """
  raw_data = tflite_interpreter.get_tensor(
      tensor_detail["index"], subgraph_index
  )
  if not (is_tensor_quantized(tensor_detail) and dequantize):
    return raw_data
  params = qtyping.UniformQuantParams.from_tfl_tensor_details(tensor_detail)
  return uniform_quantize_tensor.uniform_dequantize(raw_data, params)
|
171
|
+
|
172
|
+
|
173
|
+
def get_tensor_name_to_content_map(
    tflite_interpreter: Any, subgraph_index: int = 0, dequantize: bool = False
) -> dict[str, Any]:
  """Collects the content of every named tensor in a subgraph.

  Note the data will be copied to the returned dictionary, increasing the
  memory usage.

  Args:
    tflite_interpreter: A TFLite interpreter.
    subgraph_index: The index of the subgraph that the tensors belong to.
    dequantize: Whether to dequantize the tensor data.

  Returns:
    A dictionary mapping tensor names to tensor data. Temporary, unnamed
    tensors are left out.
  """
  return {
      detail["name"]: get_tensor_data(
          tflite_interpreter, detail, subgraph_index, dequantize
      )
      for detail in tflite_interpreter.get_tensor_details(subgraph_index)
      if detail["name"]
  }
|
198
|
+
|
199
|
+
|
200
|
+
def get_tensor_name_to_details_map(
    tflite_interpreter: Any, subgraph_index: int = 0
) -> dict[str, Any]:
  """Maps every named tensor in a subgraph to its details dictionary.

  Args:
    tflite_interpreter: A TFLite interpreter.
    subgraph_index: The index of the subgraph that the tensors belong to.

  Returns:
    A dictionary mapping tensor names to tensor details. Temporary, unnamed
    tensors are left out.
  """
  return {
      detail["name"]: detail
      for detail in tflite_interpreter.get_tensor_details(subgraph_index)
      if detail["name"]
  }
|
219
|
+
|
220
|
+
|
221
|
+
def get_input_tensor_names(
    tflite_model: Union[str, bytes], signature_name: Optional[str] = None
) -> list[str]:
  """Lists the input tensor names of one signature of a TFLite model.

  Args:
    tflite_model: Model file path or bytes.
    signature_name: The signature name that the input tensors belong to.

  Returns:
    A list of input tensor names.
  """
  # Only metadata is read, so the interpreter is built without allocation.
  interpreter = create_tfl_interpreter(tflite_model, allocate_tensors=False)
  runner = interpreter.get_signature_runner(signature_name)
  return [detail["name"] for detail in runner.get_input_details().values()]
|
240
|
+
|
241
|
+
|
242
|
+
def get_output_tensor_names(
    tflite_model: Union[str, bytes], signature_name: Optional[str] = None
) -> list[str]:
  """Lists the output tensor names of one signature of a TFLite model.

  Args:
    tflite_model: Model file path or bytes.
    signature_name: The signature name that the output tensors belong to.

  Returns:
    A list of output tensor names.
  """
  # Only metadata is read, so the interpreter is built without allocation.
  interpreter = create_tfl_interpreter(tflite_model, allocate_tensors=False)
  runner = interpreter.get_signature_runner(signature_name)
  return [detail["name"] for detail in runner.get_output_details().values()]
|
260
|
+
|
261
|
+
|
262
|
+
def get_constant_tensor_names(
    tflite_model: Union[str, bytes],
    subgraph_index: int = 0,
    min_constant_size: int = 1,
) -> list[str]:
  """Gets constant tensor names from a TFLite model for a subgraph.

  Note that this function acts on subgraph level, not signature level. This is
  because it is non-trivial to track constant tensors for a signature without
  running it.

  Args:
    tflite_model: Model file path or bytes.
    subgraph_index: The index of the subgraph that the tensors belong to.
    min_constant_size: The minimum number of elements a tensor must have to be
      reported.

  Returns:
    A list of names of the tensors whose data could be read and whose size is
    at least `min_constant_size`.
  """
  tfl_interpreter = create_tfl_interpreter(tflite_model, allocate_tensors=False)
  const_tensor_names = []
  for tensor_detail in tfl_interpreter.get_tensor_details(subgraph_index):
    if tensor_detail["dtype"] == np.object_:
      continue
    # Keep the try body limited to the read itself so that unrelated errors
    # are not swallowed. NOTE(review): presumably reading a non-constant tensor
    # from an unallocated interpreter raises ValueError, which is what
    # separates constants from the rest; confirm.
    try:
      tensor_data = get_tensor_data(
          tfl_interpreter, tensor_detail, subgraph_index
      )
    except ValueError:
      continue
    if tensor_data.size >= min_constant_size:
      const_tensor_names.append(tensor_detail["name"])
  return const_tensor_names
|
296
|
+
|
297
|
+
|
298
|
+
def get_signature_main_subgraph_index(
    tflite_interpreter: tfl.Interpreter,
    signature_key: Optional[str] = None,
) -> int:
  """Looks up the main subgraph index of a signature.

  Args:
    tflite_interpreter: A TFLite interpreter.
    signature_key: The signature key.

  Returns:
    The main subgraph index of the signature, read from the signature runner.
  """
  runner = tflite_interpreter.get_signature_runner(signature_key)
  # The index is only available through a private runner attribute.
  subgraph_index = runner._subgraph_index  # pylint:disable=protected-access
  return subgraph_index
|
@@ -0,0 +1,332 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import os
|
17
|
+
import numpy as np
|
18
|
+
from tensorflow.python.platform import googletest
|
19
|
+
from ai_edge_quantizer.utils import test_utils
|
20
|
+
from ai_edge_quantizer.utils import tfl_interpreter_utils
|
21
|
+
|
22
|
+
|
23
|
+
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../tests/models")
|
24
|
+
|
25
|
+
|
26
|
+
class TflUtilsSingleSignatureModelTest(googletest.TestCase):
  """Tests tfl_interpreter_utils with the conv_fc_mnist.tflite model."""

  def setUp(self):
    super().setUp()
    # Seed before drawing the input so it is reproducible across runs.
    np.random.seed(0)
    self._test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "conv_fc_mnist.tflite"
    )
    self._input_data = np.random.rand(1, 28, 28, 1).astype(np.float32)

  def _create_interpreter(self):
    """Returns a fresh interpreter for the test model."""
    return tfl_interpreter_utils.create_tfl_interpreter(self._test_model_path)

  def _create_invoked_interpreter(self):
    """Returns a fresh interpreter already invoked on the test input."""
    interpreter = self._create_interpreter()
    tfl_interpreter_utils.invoke_interpreter_once(
        interpreter, [self._input_data]
    )
    return interpreter

  def test_create_tfl_interpreter(self):
    self.assertIsNotNone(self._create_interpreter())

  def test_invoke_interpreter_once(self):
    interpreter = self._create_invoked_interpreter()
    output_index = interpreter.get_output_details()[0]["index"]
    output_data = interpreter.get_tensor(output_index)
    self.assertIsNotNone(output_data)
    self.assertEqual(tuple(output_data.shape), (1, 10))
    self.assertAlmostEqual(output_data[0][0], 0.0031010755)

  def test_get_tensor_data(self):
    interpreter = self._create_invoked_interpreter()
    first_output_details = interpreter.get_output_details()[0]
    output_data = tfl_interpreter_utils.get_tensor_data(
        interpreter, first_output_details
    )
    self.assertEqual(tuple(output_data.shape), (1, 10))

  def test_get_tensor_name_to_content_map(self):
    interpreter = self._create_invoked_interpreter()
    content_map = tfl_interpreter_utils.get_tensor_name_to_content_map(
        interpreter
    )

    # The input tensor must hold exactly the data that was fed in.
    input_content = content_map["serving_default_conv2d_input:0"]
    self.assertSequenceAlmostEqual(
        self._input_data.flatten(), input_content.flatten()
    )

    weight_content = content_map["sequential/conv2d/Conv2D"]
    self.assertEqual(tuple(weight_content.shape), (8, 3, 3, 1))

    avg_pool_name = "sequential/average_pooling2d/AvgPool"
    self.assertIn(avg_pool_name, content_map)
    self.assertEqual(tuple(content_map[avg_pool_name].shape), (1, 14, 14, 8))

  def test_is_tensor_quantized(self):
    first_input_details = self._create_interpreter().get_input_details()[0]
    self.assertFalse(
        tfl_interpreter_utils.is_tensor_quantized(first_input_details)
    )

  def test_get_input_tensor_names(self):
    names = tfl_interpreter_utils.get_input_tensor_names(self._test_model_path)
    self.assertEqual(names, ["serving_default_conv2d_input:0"])

  def test_get_output_tensor_names(self):
    names = tfl_interpreter_utils.get_output_tensor_names(
        self._test_model_path
    )
    self.assertEqual(names, ["StatefulPartitionedCall:0"])

  def test_get_constant_tensor_names(self):
    names = tfl_interpreter_utils.get_constant_tensor_names(
        self._test_model_path
    )
    expected_names = {
        "sequential/conv2d/Conv2D",
        "sequential/conv2d/Relu;sequential/conv2d/BiasAdd;sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp",
        "arith.constant",
        "arith.constant1",
        "arith.constant2",
        "arith.constant3",
    }
    self.assertEqual(set(names), expected_names)
|
134
|
+
|
135
|
+
|
136
|
+
class TflUtilsQuantizedModelTest(googletest.TestCase):
  """Tests tfl_interpreter_utils with the conv_fc_mnist_srq_a8w8.tflite model."""

  def setUp(self):
    super().setUp()
    # Seed before drawing the input so it is reproducible across runs.
    np.random.seed(0)
    self._test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "conv_fc_mnist_srq_a8w8.tflite"
    )
    self._signature_input_data = {
        "conv2d_input": np.random.rand(1, 28, 28, 1).astype(np.float32)
    }

  def test_is_tensor_quantized(self):
    tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
        self._test_model_path
    )
    input_details = tfl_interpreter.get_input_details()[0]
    self.assertTrue(tfl_interpreter_utils.is_tensor_quantized(input_details))

  def test_invoke_interpreter_signature(self):
    tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
        self._test_model_path
    )
    signature_output = tfl_interpreter_utils.invoke_interpreter_signature(
        tfl_interpreter, self._signature_input_data
    )
    self.assertEqual(tuple(signature_output["dense_1"].shape), (1, 10))

    # Assert the input data is not modified in-place b/353340272.
    self.assertEqual(
        self._signature_input_data["conv2d_input"].dtype, np.float32
    )
|
169
|
+
|
170
|
+
|
171
|
+
class TflUtilsMultiSignatureModelTest(googletest.TestCase):
  """Tests tfl_interpreter_utils with the two_signatures.tflite model.

  The tests use the model's "add" and "multiply" signatures, which they expect
  to live in subgraphs 0 and 1 respectively.
  """

  def setUp(self):
    super().setUp()
    np.random.seed(0)
    self._test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "two_signatures.tflite"
    )
    self._signature_input_data = {"x": np.array([2.0]).astype(np.float32)}

  def test_create_tfl_interpreter(self):
    tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
        self._test_model_path
    )
    self.assertIsNotNone(tfl_interpreter)

  def test_get_input_tensor_names(self):
    signature_name = "add"
    input_tensor_names = tfl_interpreter_utils.get_input_tensor_names(
        self._test_model_path, signature_name
    )
    self.assertEqual(
        input_tensor_names,
        ["add_x:0"],
    )

    signature_name = "multiply"
    input_tensor_names = tfl_interpreter_utils.get_input_tensor_names(
        self._test_model_path, signature_name
    )
    self.assertEqual(
        input_tensor_names,
        ["multiply_x:0"],
    )

  def test_get_output_tensor_names(self):
    signature_name = "add"
    output_tensor_names = tfl_interpreter_utils.get_output_tensor_names(
        self._test_model_path, signature_name
    )
    self.assertEqual(
        output_tensor_names,
        ["PartitionedCall:0"],
    )

    signature_name = "multiply"
    output_tensor_names = tfl_interpreter_utils.get_output_tensor_names(
        self._test_model_path, signature_name
    )
    self.assertEqual(
        output_tensor_names,
        ["PartitionedCall_1:0"],
    )

  def test_get_constant_tensor_names(self):
    subgraph0_const_tensor_names = (
        tfl_interpreter_utils.get_constant_tensor_names(
            self._test_model_path, 0
        )
    )
    self.assertEqual(subgraph0_const_tensor_names, ["Add/y"])

    subgraph1_const_tensor_names = (
        tfl_interpreter_utils.get_constant_tensor_names(
            self._test_model_path, 1
        )
    )
    self.assertEqual(subgraph1_const_tensor_names, ["Mul/y"])

  def test_get_signature_main_subgraph_index(self):
    tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
        self._test_model_path
    )
    add_subgraph_index = (
        tfl_interpreter_utils.get_signature_main_subgraph_index(
            tfl_interpreter, "add"
        )
    )
    self.assertEqual(add_subgraph_index, 0)
    multiply_subgraph_index = (
        tfl_interpreter_utils.get_signature_main_subgraph_index(
            tfl_interpreter, "multiply"
        )
    )
    self.assertEqual(multiply_subgraph_index, 1)

  def test_get_tensor_data(self):
    tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
        self._test_model_path
    )
    # Invoke the ADD signature.
    tfl_interpreter_utils.invoke_interpreter_signature(
        tfl_interpreter, self._signature_input_data, "add"
    )
    # Hand-built details: tensor index 2 with no scales, i.e. not quantized.
    # The same details are reused for both subgraphs below.
    output_details = {"index": 2, "quantization_parameters": {"scales": []}}
    output_data = tfl_interpreter_utils.get_tensor_data(
        tfl_interpreter, output_details, subgraph_index=0
    )  # The ADD signature is in the first subgraph.
    self.assertEqual(output_data, [12.0])  # 10 + 2

    # Invoke the MULTIPLY signature.
    tfl_interpreter_utils.invoke_interpreter_signature(
        tfl_interpreter, self._signature_input_data, "multiply"
    )
    output_data = tfl_interpreter_utils.get_tensor_data(
        tfl_interpreter, output_details, subgraph_index=1
    )  # The Multiply signature is in the second subgraph.
    self.assertEqual(output_data, [20.0])  # 10 * 2

  def test_get_tensor_name_to_content_map(self):
    tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
        self._test_model_path
    )
    # Invoke all signatures.
    tfl_interpreter_utils.invoke_interpreter_signature(
        tfl_interpreter, self._signature_input_data, "multiply"
    )
    tfl_interpreter_utils.invoke_interpreter_signature(
        tfl_interpreter, self._signature_input_data, "add"
    )

    # Test tensors belonging to the ADD signature.
    add_subgraph_index = (
        tfl_interpreter_utils.get_signature_main_subgraph_index(
            tfl_interpreter, "add"
        )
    )
    add_tensor_content = tfl_interpreter_utils.get_tensor_name_to_content_map(
        tfl_interpreter, add_subgraph_index
    )

    add_input_content = add_tensor_content["add_x:0"]
    self.assertSequenceAlmostEqual(
        self._signature_input_data["x"].flatten(), add_input_content.flatten()
    )
    weight_content = add_tensor_content["Add/y"]
    self.assertEqual(weight_content, 10)
    add_output_content = add_tensor_content["PartitionedCall:0"]
    self.assertEqual(add_output_content, [12.0])

    # Test tensors belonging to the MULTIPLY signature.
    multiply_subgraph_index = (
        tfl_interpreter_utils.get_signature_main_subgraph_index(
            tfl_interpreter, "multiply"
        )
    )
    mul_tensor_content = tfl_interpreter_utils.get_tensor_name_to_content_map(
        tfl_interpreter, multiply_subgraph_index
    )
    multiply_input_content = mul_tensor_content["multiply_x:0"]
    self.assertSequenceAlmostEqual(
        self._signature_input_data["x"].flatten(),
        multiply_input_content.flatten(),
    )
    weight_content = mul_tensor_content["Mul/y"]
    self.assertEqual(weight_content, 10)
    multiply_output_content = mul_tensor_content["PartitionedCall_1:0"]
    self.assertEqual(multiply_output_content, [20.0])
|
329
|
+
|
330
|
+
|
331
|
+
if __name__ == "__main__":
|
332
|
+
googletest.main()
|