ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,317 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""flatbuffer utils for the Quantizer."""
|
17
|
+
|
18
|
+
from typing import Any, Optional, Union
|
19
|
+
|
20
|
+
import immutabledict
|
21
|
+
import numpy as np
|
22
|
+
|
23
|
+
from ai_edge_quantizer import qtyping
|
24
|
+
from ai_edge_litert import schema_py_generated # pylint:disable=g-direct-tensorflow-import
|
25
|
+
from tensorflow.lite.tools import flatbuffer_utils # pylint: disable=g-direct-tensorflow-import
|
26
|
+
from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import
|
27
|
+
|
28
|
+
_TFLOpName = qtyping.TFLOperationName

# Quantizer op name -> TFLite builtin operator code. The mapping is built from
# an ordered sequence of (name, code) pairs, so the order below is the
# iteration order of the resulting immutabledict. Note that the quantizer's
# CONV_2D_TRANSPOSE corresponds to TFLite's TRANSPOSE_CONV builtin.
TFL_OP_NAME_TO_CODE = immutabledict.immutabledict([
    (
        _TFLOpName.FULLY_CONNECTED,
        schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
    ),
    (_TFLOpName.BATCH_MATMUL, schema_py_generated.BuiltinOperator.BATCH_MATMUL),
    (_TFLOpName.CONV_2D, schema_py_generated.BuiltinOperator.CONV_2D),
    (
        _TFLOpName.DEPTHWISE_CONV_2D,
        schema_py_generated.BuiltinOperator.DEPTHWISE_CONV_2D,
    ),
    (
        _TFLOpName.CONV_2D_TRANSPOSE,
        schema_py_generated.BuiltinOperator.TRANSPOSE_CONV,
    ),
    (
        _TFLOpName.EMBEDDING_LOOKUP,
        schema_py_generated.BuiltinOperator.EMBEDDING_LOOKUP,
    ),
    (_TFLOpName.SOFTMAX, schema_py_generated.BuiltinOperator.SOFTMAX),
    (
        _TFLOpName.AVERAGE_POOL_2D,
        schema_py_generated.BuiltinOperator.AVERAGE_POOL_2D,
    ),
    (_TFLOpName.RESHAPE, schema_py_generated.BuiltinOperator.RESHAPE),
    (_TFLOpName.TANH, schema_py_generated.BuiltinOperator.TANH),
    (_TFLOpName.TRANSPOSE, schema_py_generated.BuiltinOperator.TRANSPOSE),
    (_TFLOpName.GELU, schema_py_generated.BuiltinOperator.GELU),
    (_TFLOpName.ADD, schema_py_generated.BuiltinOperator.ADD),
    (_TFLOpName.SUB, schema_py_generated.BuiltinOperator.SUB),
    (_TFLOpName.MUL, schema_py_generated.BuiltinOperator.MUL),
    (_TFLOpName.MEAN, schema_py_generated.BuiltinOperator.MEAN),
    (_TFLOpName.RSQRT, schema_py_generated.BuiltinOperator.RSQRT),
    (
        _TFLOpName.CONCATENATION,
        schema_py_generated.BuiltinOperator.CONCATENATION,
    ),
    (
        _TFLOpName.STRIDED_SLICE,
        schema_py_generated.BuiltinOperator.STRIDED_SLICE,
    ),
    (_TFLOpName.SPLIT, schema_py_generated.BuiltinOperator.SPLIT),
    (_TFLOpName.LOGISTIC, schema_py_generated.BuiltinOperator.LOGISTIC),
    (_TFLOpName.SLICE, schema_py_generated.BuiltinOperator.SLICE),
    (_TFLOpName.SUM, schema_py_generated.BuiltinOperator.SUM),
    (_TFLOpName.SELECT_V2, schema_py_generated.BuiltinOperator.SELECT_V2),
])
|
66
|
+
|
67
|
+
# Reverse lookup of TFL_OP_NAME_TO_CODE: TFLite builtin operator code ->
# quantizer op name. This relies on every code in TFL_OP_NAME_TO_CODE being
# unique; if a code were listed twice, only the name listed last would be kept.
TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
    {code: name for name, code in TFL_OP_NAME_TO_CODE.items()}
)
|
70
|
+
|
71
|
+
# Quantized dimension for per-channel quantization: for each op, the axis of
# its weight tensor along which separate quantization parameters are kept.
# See https://www.tensorflow.org/lite/performance/quantization_spec.
# Built from ordered (op name, axis) pairs; the order is the iteration order.
TFL_OP_TO_WEIGHT_QUANTIZED_DIM = immutabledict.immutabledict([
    # Leading axis of the weight tensor.
    (_TFLOpName.FULLY_CONNECTED, 0),
    # Last axis of the filter (per the quantization spec linked above).
    (_TFLOpName.DEPTHWISE_CONV_2D, 3),
    # Leading axis of the filter.
    (_TFLOpName.CONV_2D, 0),
    # Leading axis of the lookup table.
    (_TFLOpName.EMBEDDING_LOOKUP, 0),
    # Leading axis of the filter.
    (_TFLOpName.CONV_2D_TRANSPOSE, 0),
])
|
80
|
+
|
81
|
+
# Number of TFLite tensor type codes covered by the lookup tables below; codes
# 0 .. NUM_TFL_DATATYPES - 1 are mapped.
NUM_TFL_DATATYPES = 18

# Tensor type code -> type name as returned by `flatbuffer_utils.type_to_name`.
# `get_tensor_data` lower-cases these names to obtain numpy dtype strings.
# Built with a comprehension so no loop variable leaks into the module
# namespace and the table is never exposed as a mutable dict.
TENSOR_CODE_TO_TYPE = immutabledict.immutabledict({
    dtype_code: flatbuffer_utils.type_to_name(dtype_code)
    for dtype_code in range(NUM_TFL_DATATYPES)
})

# Reverse lookup: tensor type name -> tensor type code.
TENSOR_TYPE_TO_CODE = immutabledict.immutabledict({
    type_name: dtype_code
    for dtype_code, type_name in TENSOR_CODE_TO_TYPE.items()
})

# Expose functions in tensorflow.lite.tools.flatbuffer_utils.
write_model = flatbuffer_utils.write_model
|
92
|
+
|
93
|
+
|
94
|
+
def read_model(tflite_model: Union[str, bytes, bytearray]) -> Any:
  """Read and convert the TFLite model into a flatbuffer object.

  Args:
    tflite_model: TFLite model path (str), or the serialized model as bytes
      or bytearray.

  Raises:
    ValueError: Unsupported tflite_model type.

  Returns:
    flatbuffer_model: the flatbuffer_model.
  """
  if isinstance(tflite_model, str):
    return flatbuffer_utils.read_model(tflite_model)
  if isinstance(tflite_model, (bytes, bytearray)):
    return flatbuffer_utils.read_model_from_bytearray(tflite_model)
  raise ValueError(
      "Unsupported tflite_model type: %s" % type(tflite_model).__name__
  )
|
114
|
+
|
115
|
+
|
116
|
+
def get_model_content(tflite_path: str) -> bytes:
  """Reads a .tflite file and returns its raw contents.

  Args:
    tflite_path: Path to the .tflite file, opened with `gfile.Open` in binary
      mode.

  Returns:
    The model bytes.
  """
  with gfile.Open(tflite_path, "rb") as model_file:
    model_bytes = model_file.read()
  return model_bytes
|
127
|
+
|
128
|
+
|
129
|
+
def get_model_buffer(tflite_path: str) -> bytearray:
  """Reads a .tflite file into a mutable buffer.

  Args:
    tflite_path: path to the .tflite file.

  Returns:
    model_buffer: the file contents as a bytearray.
  """
  # Same binary read as get_model_content, wrapped in a mutable bytearray.
  return bytearray(get_model_content(tflite_path))
|
140
|
+
|
141
|
+
|
142
|
+
def parse_op_tensors(op: Any, subgraph_tensors: list[Any]) -> list[Any]:
  """Parse the op tensors.

  Args:
    op: the op that need to be parsed. Its `outputs` / `inputs` are sequences
      of tensor indices; either may be None (e.g. when the vector is absent
      from the flatbuffer), which is treated as empty.
    subgraph_tensors: list of tensors in the subgraph.

  Returns:
    tensors: list of tensors that are associated with the op, outputs first
      and then inputs. Indices equal to -1 (omitted tensors) are skipped.
  """
  tensor_indices = []
  for indices in (op.outputs, op.inputs):
    # Explicit None check: the indices may be numpy arrays, whose truth value
    # is ambiguous, so `indices or ()` cannot be used.
    if indices is not None:
      tensor_indices.extend(indices)
  return [subgraph_tensors[idx] for idx in tensor_indices if idx != -1]
|
158
|
+
|
159
|
+
|
160
|
+
def parse_fc_bmm_conv_tensors(
    op: Any,
    subgraph_tensors: list[Any],
    input_index: int = 0,
    weight_index: int = 1,
    bias_index: int = 2,
    output_index: int = 0,
) -> tuple[Any, Any, Any, Any]:
  """Splits the tensors of a FullyConnected, BatchMatmul or convolution op.

  Args:
    op: the TFLite op, must be fully_connected, batch_matmul, or convolution.
    subgraph_tensors: tensors in the subgraph.
    input_index: position of the input tensor in `op.inputs`.
    weight_index: position of the weight tensor in `op.inputs`.
    bias_index: position of the bias tensor in `op.inputs`.
    output_index: position of the output tensor in `op.outputs`.

  Returns:
    (input_tensor, weight_tensor, bias_tensor, output_tensor). bias_tensor is
    None when `op.inputs` is too short to hold a bias or the bias slot is -1.
  """
  op_inputs = op.inputs
  input_tensor = subgraph_tensors[op_inputs[input_index]]
  weight_tensor = subgraph_tensors[op_inputs[weight_index]]

  has_bias = bias_index < len(op_inputs) and op_inputs[bias_index] != -1
  bias_tensor = subgraph_tensors[op_inputs[bias_index]] if has_bias else None

  output_tensor = subgraph_tensors[op.outputs[output_index]]
  return input_tensor, weight_tensor, bias_tensor, output_tensor
|
189
|
+
|
190
|
+
|
191
|
+
# flatbuffer_model has Any type since tensorflow.lite.tools.flatbuffer_utils
# is not type annotated.
def buffer_to_tensors(flatbuffer_model: Any) -> dict[int, list[Any]]:
  """Get the buffer to tensor map for a tflite model.

  Only tensors that are an input or output of some operator are visited.

  Args:
    flatbuffer_model: the flatbuffer_model.

  Returns:
    buffer_to_tensor_map: key as buffer index, value as list of the distinct
      tensors that share the buffer, in order of first use by an operator.
  """
  buffer_to_tensor_map = {}
  # A tensor consumed or produced by several ops is reported once; compare by
  # identity because flatbuffer tensor objects do not define equality.
  seen_tensor_ids = set()
  for subgraph in flatbuffer_model.subgraphs:
    for op in subgraph.operators:
      for tensor in parse_op_tensors(op, subgraph.tensors):
        if id(tensor) in seen_tensor_ids:
          continue
        seen_tensor_ids.add(id(tensor))
        buffer_to_tensor_map.setdefault(tensor.buffer, []).append(tensor)
  return buffer_to_tensor_map
|
211
|
+
|
212
|
+
|
213
|
+
def get_tensor_name(tensor: Any) -> str:
  """Returns the name of a flatbuffer tensor as a Python string.

  Args:
    tensor: tensor in flatbuffer; its `name` is UTF-8 encoded bytes.

  Returns:
    tensor_name: the decoded tensor name.
  """
  encoded_name = tensor.name
  return encoded_name.decode("utf-8")
|
223
|
+
|
224
|
+
|
225
|
+
def get_tensor_data(tensor: Any, buffers: list[Any]) -> Optional[np.ndarray]:
  """Returns the constant data stored for a flatbuffer tensor.

  Args:
    tensor: tensor in flatbuffer.
    buffers: list of buffers of the model; `tensor.buffer` indexes into it.

  Returns:
    tensor_data: None when the tensor's buffer holds no data; otherwise the
      buffer contents reinterpreted with the tensor's dtype and reshaped to
      `tensor.shape`.
  """
  raw_data = buffers[tensor.buffer].data
  if raw_data is None:
    return None
  # The numpy dtype string is the lower-cased TFLite type name.
  numpy_dtype = TENSOR_CODE_TO_TYPE[tensor.type].lower()
  flat_data = np.frombuffer(raw_data, dtype=numpy_dtype)
  return np.reshape(flat_data, tensor.shape)
|
244
|
+
|
245
|
+
|
246
|
+
def has_same_quantization(tensor1: Any, tensor2: Any) -> bool:
  """Check if two tensors have the same quantization.

  A tensor is treated as not quantized when its `quantization` is None or its
  quantization scale is None.

  - If neither tensor is quantized, the result is True; tensor types are not
    compared in that case.
  - If only one tensor has a `quantization` table at all, the result is False.
  - Otherwise the tensor types, scales, zero points and quantized dimensions
    must all match.

  Args:
    tensor1: tensor in flatbuffer.
    tensor2: tensor in flatbuffer.

  Returns:
    True if two tensors have the same quantization.
  """

  def to_tuple(val):
    # A missing (None) array compares like an empty one.
    return () if val is None else tuple(val)

  quant1 = tensor1.quantization
  quant2 = tensor2.quantization
  scale1 = None if quant1 is None else quant1.scale
  scale2 = None if quant2 is None else quant2.scale

  # Return True if both tensors are not quantized.
  if scale1 is None and scale2 is None:
    return True
  # Exactly one side has no quantization table: they cannot match. This also
  # avoids dereferencing attributes on a None quantization below.
  if quant1 is None or quant2 is None:
    return False

  same_type = tensor1.type == tensor2.type
  same_scale = to_tuple(quant1.scale) == to_tuple(quant2.scale)
  same_zero_point = to_tuple(quant1.zeroPoint) == to_tuple(quant2.zeroPoint)
  same_quantized_dimension = (
      quant1.quantizedDimension == quant2.quantizedDimension
  )
  return (
      same_type and same_scale and same_zero_point and same_quantized_dimension
  )
|
283
|
+
|
284
|
+
|
285
|
+
def is_float_model(flatbuffer_model: Any) -> bool:
  """Checks that the model is float and not already quantized.

  Args:
    flatbuffer_model: the flatbuffer_model.

  Returns:
    False if any tensor in any subgraph carries a quantization scale, True
    otherwise.
  """
  all_tensors = (
      tensor
      for subgraph in flatbuffer_model.subgraphs
      for tensor in subgraph.tensors
  )
  return not any(
      tensor.quantization is not None and tensor.quantization.scale is not None
      for tensor in all_tensors
  )
|
294
|
+
|
295
|
+
|
296
|
+
def get_subgraph_input_output_operators(
    subgraph: Any,
) -> list[qtyping.IOOperator]:
  """Builds pseudo operators for the subgraph's inputs and outputs.

  The INPUT operator has no inputs and produces the subgraph inputs; the
  OUTPUT operator consumes the subgraph outputs and produces nothing.

  Args:
    subgraph: The subgraph object.

  Returns:
    [input_op, output_op] for the subgraph.
  """
  return [
      qtyping.IOOperator(
          inputs=[],
          outputs=subgraph.inputs,
          op_key=qtyping.TFLOperationName.INPUT,
      ),
      qtyping.IOOperator(
          inputs=subgraph.outputs,
          outputs=[],
          op_key=qtyping.TFLOperationName.OUTPUT,
      ),
  ]
|
@@ -0,0 +1,200 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Tests for tfl_flatbuffer_utils.py."""
|
17
|
+
|
18
|
+
import os
|
19
|
+
import numpy as np
|
20
|
+
from tensorflow.python.platform import googletest
|
21
|
+
from ai_edge_quantizer import qtyping
|
22
|
+
from ai_edge_quantizer.utils import test_utils
|
23
|
+
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
24
|
+
|
25
|
+
|
26
|
+
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../tests/models")


# TODO: b/328830092 - Add test cases for model require buffer offset.
class FlatbufferUtilsTest(googletest.TestCase):
  """Tests for tfl_flatbuffer_utils using the conv_fc_mnist test model."""

  def setUp(self):
    super().setUp()
    self._test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "conv_fc_mnist.tflite"
    )
    # Re-read for every test: several tests mutate tensor quantization.
    self._test_model = tfl_flatbuffer_utils.read_model(self._test_model_path)

  def test_get_model_buffer(self):
    model_buffer = tfl_flatbuffer_utils.get_model_buffer(self._test_model_path)
    file_stats = os.stat(self._test_model_path)
    self.assertIsInstance(model_buffer, bytearray)
    self.assertLen(model_buffer, file_stats.st_size)

  def test_get_model_content(self):
    model_content = tfl_flatbuffer_utils.get_model_content(
        self._test_model_path
    )
    file_stats = os.stat(self._test_model_path)
    self.assertIsInstance(model_content, bytes)
    self.assertLen(model_content, file_stats.st_size)

  def test_read_model_from_bytes_and_bytearray(self):
    model_content = tfl_flatbuffer_utils.get_model_content(
        self._test_model_path
    )
    expected_subgraph0 = self._test_model.subgraphs[0]
    for model_data in (model_content, bytearray(model_content)):
      with self.subTest(input_type=type(model_data).__name__):
        model = tfl_flatbuffer_utils.read_model(model_data)
        self.assertLen(model.subgraphs, len(self._test_model.subgraphs))
        self.assertLen(
            model.subgraphs[0].operators, len(expected_subgraph0.operators)
        )
        self.assertLen(
            model.subgraphs[0].tensors, len(expected_subgraph0.tensors)
        )

  def test_read_model_raises_on_unsupported_type(self):
    with self.assertRaisesRegex(ValueError, "Unsupported tflite_model type"):
      tfl_flatbuffer_utils.read_model(123)

  def test_parse_op_tensors(self):
    subgraph0 = self._test_model.subgraphs[0]
    conv2d_op = subgraph0.operators[0]
    op_tensors = tfl_flatbuffer_utils.parse_op_tensors(
        conv2d_op, subgraph0.tensors
    )
    # conv2d has three inputs and one output.
    self.assertLen(op_tensors, 4)

    average_pool_op = subgraph0.operators[1]
    op_tensors = tfl_flatbuffer_utils.parse_op_tensors(
        average_pool_op, subgraph0.tensors
    )
    # average_pool has one input and one output.
    self.assertLen(op_tensors, 2)

  def test_parse_fc_bmm_conv_tensors(self):
    subgraph0 = self._test_model.subgraphs[0]
    conv2d_op = subgraph0.operators[0]
    inputs, weight, bias, output = (
        tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
            conv2d_op, subgraph0.tensors
        )
    )
    self.assertEqual(tuple(inputs.shape), (1, 28, 28, 1))
    self.assertEqual(tuple(weight.shape), (8, 3, 3, 1))
    self.assertEqual(tuple(bias.shape), (8,))
    self.assertEqual(tuple(output.shape), (1, 28, 28, 8))

    fc_with_bias = subgraph0.operators[3]
    inputs, weight, bias, output = (
        tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
            fc_with_bias,
            subgraph0.tensors,
        )
    )
    self.assertEqual(tuple(inputs.shape), (1, 1568))
    self.assertEqual(tuple(weight.shape), (32, 1568))
    self.assertEqual(tuple(bias.shape), (32,))
    self.assertEqual(tuple(output.shape), (1, 32))

    fc_no_bias = subgraph0.operators[4]
    inputs, weight, bias, output = (
        tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
            fc_no_bias,
            subgraph0.tensors,
        )
    )
    self.assertEqual(tuple(inputs.shape), (1, 32))
    self.assertEqual(tuple(weight.shape), (10, 32))
    self.assertIsNone(bias)
    self.assertEqual(tuple(output.shape), (1, 10))

  def test_buffer_to_tensors(self):
    buffer_to_tensor_map = tfl_flatbuffer_utils.buffer_to_tensors(
        self._test_model
    )
    # Buffer index read from Netron/Model Explorer.
    tensors = buffer_to_tensor_map[6]
    self.assertLen(tensors, 1)
    conv2d_filter_tensor = tensors[0]
    self.assertEqual(tuple(conv2d_filter_tensor.shape), (8, 3, 3, 1))

  def test_get_tensor_name(self):
    subgraph0 = self._test_model.subgraphs[0]
    subgraph_tensors = subgraph0.tensors
    conv2d_op = subgraph0.operators[0]
    weight_tensor = subgraph_tensors[conv2d_op.inputs[1]]
    weight_tensor_name = tfl_flatbuffer_utils.get_tensor_name(weight_tensor)
    self.assertEqual(weight_tensor_name, "sequential/conv2d/Conv2D")

  # TODO: b/325123193 - test tensor with data outside of flatbuffer.
  def test_get_tensor_data(self):
    subgraph0 = self._test_model.subgraphs[0]
    subgraph_tensors = subgraph0.tensors
    conv2d_op = subgraph0.operators[0]
    # Check tensor with data.
    weight_tensor = subgraph_tensors[conv2d_op.inputs[1]]
    weight_tensor_data = tfl_flatbuffer_utils.get_tensor_data(
        weight_tensor, self._test_model.buffers
    )
    self.assertEqual(
        tuple(weight_tensor.shape), tuple(weight_tensor_data.shape)  # pytype: disable=attribute-error
    )
    self.assertAlmostEqual(weight_tensor_data[0][0][0][0], -0.12941549718379974)

    # Check tensor with no data.
    input_tensor = subgraph_tensors[conv2d_op.inputs[0]]
    input_tensor_data = tfl_flatbuffer_utils.get_tensor_data(
        input_tensor, self._test_model.buffers
    )
    self.assertIsNone(input_tensor_data)

  def test_has_same_quantization_succeeds(self):
    tensor0, tensor1 = self._test_model.subgraphs[0].tensors[:2]
    tensor0.quantization.scale = np.array([1, 2, 3]).astype(np.float32)
    tensor0.quantization.zeroPoint = np.array([3, 2, 1]).astype(np.int32)
    tensor1.quantization.scale = np.array([1, 2, 3]).astype(np.float32)
    tensor1.quantization.zeroPoint = np.array([3, 2, 1]).astype(np.int32)
    self.assertTrue(
        tfl_flatbuffer_utils.has_same_quantization(tensor0, tensor1)
    )

  def test_has_same_quantization_succeeds_not_quantized(self):
    tensor0, tensor1 = self._test_model.subgraphs[0].tensors[:2]
    # Tensor types are not compared when neither tensor is quantized.
    tensor0.type = 10
    self.assertTrue(
        tfl_flatbuffer_utils.has_same_quantization(tensor0, tensor1)
    )

  def test_has_same_quantization_fails_different_scale(self):
    tensor0, tensor1 = self._test_model.subgraphs[0].tensors[:2]
    tensor1.quantization.scale = np.array([1, 2, 3]).astype(np.float32)
    self.assertFalse(
        tfl_flatbuffer_utils.has_same_quantization(tensor0, tensor1)
    )

  def test_has_same_quantization_fails_different_zp(self):
    tensor0, tensor1 = self._test_model.subgraphs[0].tensors[:2]
    tensor0.quantization.scale = np.array([1, 2, 3]).astype(np.float32)
    tensor0.quantization.zeroPoint = np.array([3, 2, 1]).astype(np.int32)
    tensor1.quantization.scale = np.array([1, 2, 3]).astype(np.float32)
    tensor1.quantization.zeroPoint = np.array([1, 2, 3]).astype(np.int32)
    self.assertFalse(
        tfl_flatbuffer_utils.has_same_quantization(tensor0, tensor1)
    )

  def test_check_is_float_model_true_when_model_is_float(self):
    # setUp already loaded the float conv_fc_mnist model.
    self.assertTrue(tfl_flatbuffer_utils.is_float_model(self._test_model))

  def test_check_is_float_model_false_when_model_is_quantized(self):
    test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "mnist_quantized.tflite"
    )
    model = tfl_flatbuffer_utils.read_model(test_model_path)
    self.assertFalse(tfl_flatbuffer_utils.is_float_model(model))

  def test_get_subgraph_input_output_operators(self):
    subgraph = self._test_model.subgraphs[0]
    input_op, output_op = (
        tfl_flatbuffer_utils.get_subgraph_input_output_operators(subgraph)
    )
    self.assertEqual(input_op.op_key, qtyping.TFLOperationName.INPUT)
    self.assertEmpty(input_op.inputs)
    self.assertListEqual(list(input_op.outputs), [0])
    self.assertEqual(output_op.op_key, qtyping.TFLOperationName.OUTPUT)
    self.assertListEqual(list(output_op.inputs), [12])
    self.assertEmpty(output_op.outputs)


if __name__ == "__main__":
  googletest.main()
|