ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
ai_edge_quantizer/utils/tfl_flatbuffer_utils.py
@@ -0,0 +1,317 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Flatbuffer utils for the Quantizer."""
+
+ from typing import Any, Optional, Union
+
+ import immutabledict
+ import numpy as np
+
+ from ai_edge_quantizer import qtyping
+ from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+ from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import
+ from tensorflow.python.platform import gfile  # pylint: disable=g-direct-tensorflow-import
+
+ _TFLOpName = qtyping.TFLOperationName
+
+ TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({
+     _TFLOpName.FULLY_CONNECTED: (
+         schema_py_generated.BuiltinOperator.FULLY_CONNECTED
+     ),
+     _TFLOpName.BATCH_MATMUL: schema_py_generated.BuiltinOperator.BATCH_MATMUL,
+     _TFLOpName.CONV_2D: schema_py_generated.BuiltinOperator.CONV_2D,
+     _TFLOpName.DEPTHWISE_CONV_2D: (
+         schema_py_generated.BuiltinOperator.DEPTHWISE_CONV_2D
+     ),
+     _TFLOpName.CONV_2D_TRANSPOSE: (
+         schema_py_generated.BuiltinOperator.TRANSPOSE_CONV
+     ),
+     _TFLOpName.EMBEDDING_LOOKUP: (
+         schema_py_generated.BuiltinOperator.EMBEDDING_LOOKUP
+     ),
+     _TFLOpName.SOFTMAX: schema_py_generated.BuiltinOperator.SOFTMAX,
+     _TFLOpName.AVERAGE_POOL_2D: (
+         schema_py_generated.BuiltinOperator.AVERAGE_POOL_2D
+     ),
+     _TFLOpName.RESHAPE: schema_py_generated.BuiltinOperator.RESHAPE,
+     _TFLOpName.TANH: schema_py_generated.BuiltinOperator.TANH,
+     _TFLOpName.TRANSPOSE: schema_py_generated.BuiltinOperator.TRANSPOSE,
+     _TFLOpName.GELU: schema_py_generated.BuiltinOperator.GELU,
+     _TFLOpName.ADD: schema_py_generated.BuiltinOperator.ADD,
+     _TFLOpName.SUB: schema_py_generated.BuiltinOperator.SUB,
+     _TFLOpName.MUL: schema_py_generated.BuiltinOperator.MUL,
+     _TFLOpName.MEAN: schema_py_generated.BuiltinOperator.MEAN,
+     _TFLOpName.RSQRT: schema_py_generated.BuiltinOperator.RSQRT,
+     _TFLOpName.CONCATENATION: schema_py_generated.BuiltinOperator.CONCATENATION,
+     _TFLOpName.STRIDED_SLICE: schema_py_generated.BuiltinOperator.STRIDED_SLICE,
+     _TFLOpName.SPLIT: schema_py_generated.BuiltinOperator.SPLIT,
+     _TFLOpName.LOGISTIC: schema_py_generated.BuiltinOperator.LOGISTIC,
+     _TFLOpName.SLICE: schema_py_generated.BuiltinOperator.SLICE,
+     _TFLOpName.SUM: schema_py_generated.BuiltinOperator.SUM,
+     _TFLOpName.SELECT_V2: schema_py_generated.BuiltinOperator.SELECT_V2,
+ })
+
+ TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
+     dict((reversed(item) for item in TFL_OP_NAME_TO_CODE.items()))
+ )
+
+ # Quantized dimension for per-channel quantization.
+ # See https://www.tensorflow.org/lite/performance/quantization_spec.
+ TFL_OP_TO_WEIGHT_QUANTIZED_DIM = immutabledict.immutabledict({
+     _TFLOpName.FULLY_CONNECTED: 0,
+     _TFLOpName.DEPTHWISE_CONV_2D: 3,
+     _TFLOpName.CONV_2D: 0,
+     _TFLOpName.EMBEDDING_LOOKUP: 0,
+     _TFLOpName.CONV_2D_TRANSPOSE: 0,
+ })
+
+ NUM_TFL_DATATYPES = 18
+ TENSOR_CODE_TO_TYPE = {}
+ for dtype_code in range(NUM_TFL_DATATYPES):
+   TENSOR_CODE_TO_TYPE[dtype_code] = flatbuffer_utils.type_to_name(dtype_code)
+ TENSOR_CODE_TO_TYPE = immutabledict.immutabledict(TENSOR_CODE_TO_TYPE)
+ TENSOR_TYPE_TO_CODE = immutabledict.immutabledict(
+     (reversed(item) for item in TENSOR_CODE_TO_TYPE.items())
+ )
+
+ # Expose functions in tensorflow.lite.tools.flatbuffer_utils.
+ write_model = flatbuffer_utils.write_model
+
+
+ def read_model(tflite_model: Union[str, bytes, bytearray]) -> Any:
+   """Reads and converts a TFLite model into a flatbuffer object.
+
+   Args:
+     tflite_model: TFLite model path or serialized model bytes.
+
+   Raises:
+     ValueError: Unsupported tflite_model type.
+
+   Returns:
+     The flatbuffer model object.
+   """
+   if isinstance(tflite_model, str):
+     return flatbuffer_utils.read_model(tflite_model)
+   elif isinstance(tflite_model, (bytes, bytearray)):
+     return flatbuffer_utils.read_model_from_bytearray(tflite_model)
+   else:
+     raise ValueError(
+         "Unsupported tflite_model type: %s" % type(tflite_model).__name__
+     )
+
+
+ def get_model_content(tflite_path: str) -> bytes:
+   """Gets the model content (bytes) from the path.
+
+   Args:
+     tflite_path: Path to the .tflite file.
+
+   Returns:
+     The model bytes.
+   """
+   with gfile.Open(tflite_path, "rb") as tflite_file:
+     return tflite_file.read()
+
+
+ def get_model_buffer(tflite_path: str) -> bytearray:
+   """Gets the model buffer from the path.
+
+   Args:
+     tflite_path: Path to the .tflite file.
+
+   Returns:
+     model_buffer: the model buffer.
+   """
+   with gfile.Open(tflite_path, "rb") as tflite_file:
+     return bytearray(tflite_file.read())
+
+
+ def parse_op_tensors(op: Any, subgraph_tensors: list[Any]) -> list[Any]:
+   """Parses the tensors of an op.
+
+   Args:
+     op: the op that needs to be parsed.
+     subgraph_tensors: list of tensors in the subgraph.
+
+   Returns:
+     tensors: list of tensors that are associated with the op.
+   """
+
+   tensors = []
+   for tensor_idx in list(op.outputs) + list(op.inputs):
+     if tensor_idx != -1:
+       tensors.append(subgraph_tensors[tensor_idx])
+   return tensors
+
+
+ def parse_fc_bmm_conv_tensors(
+     op: Any,
+     subgraph_tensors: list[Any],
+     input_index: int = 0,
+     weight_index: int = 1,
+     bias_index: int = 2,
+     output_index: int = 0,
+ ) -> tuple[Any, Any, Any, Any]:
+   """Parses tensors in FullyConnected, BatchMatmul, and Convolution ops.
+
+   Args:
+     op: the TFLite op; must be fully_connected, batch_matmul, or a convolution.
+     subgraph_tensors: tensors in the subgraph.
+     input_index: index of the input tensor.
+     weight_index: index of the weight tensor.
+     bias_index: index of the bias tensor.
+     output_index: index of the output tensor.
+
+   Returns:
+     input_tensor, weight_tensor, bias_tensor, output_tensor
+   """
+
+   input_tensor = subgraph_tensors[op.inputs[input_index]]
+   weight_tensor = subgraph_tensors[op.inputs[weight_index]]
+   bias_tensor = None
+   if bias_index < len(op.inputs) and op.inputs[bias_index] != -1:
+     bias_tensor = subgraph_tensors[op.inputs[bias_index]]
+   output_tensor = subgraph_tensors[op.outputs[output_index]]
+   return input_tensor, weight_tensor, bias_tensor, output_tensor
+
+
+ # flatbuffer_model has Any type since tensorflow.lite.tools.flatbuffer_utils
+ # is not type annotated.
+ def buffer_to_tensors(flatbuffer_model: Any) -> dict[int, list[Any]]:
+   """Gets the buffer-to-tensor map for a TFLite model.
+
+   Args:
+     flatbuffer_model: the flatbuffer model.
+
+   Returns:
+     buffer_to_tensor_map: key is the buffer index, value is the list of
+       tensors sharing the buffer.
+   """
+   buffer_to_tensor_map = {}
+   for subgraph in flatbuffer_model.subgraphs:
+     for op in subgraph.operators:
+       for tensor in parse_op_tensors(op, subgraph.tensors):
+         if tensor.buffer not in buffer_to_tensor_map:
+           buffer_to_tensor_map[tensor.buffer] = []
+         buffer_to_tensor_map[tensor.buffer].append(tensor)
+   return buffer_to_tensor_map
+
+
+ def get_tensor_name(tensor: Any) -> str:
+   """Gets the name of a flatbuffer tensor.
+
+   Args:
+     tensor: tensor in the flatbuffer.
+
+   Returns:
+     tensor_name: name of the tensor.
+   """
+   return tensor.name.decode("utf-8")
+
+
+ def get_tensor_data(tensor: Any, buffers: list[Any]) -> Optional[np.ndarray]:
+   """Gets the tensor data.
+
+   Args:
+     tensor: tensor in the flatbuffer.
+     buffers: list of buffers in the flatbuffer model.
+
+   Returns:
+     tensor_data: data inside the tensor, or None if the tensor holds no
+       constant data.
+   """
+   tensor_buffer = buffers[tensor.buffer]
+   buffer_data = tensor_buffer.data
+   if buffer_data is None:
+     return None
+   data = np.frombuffer(
+       buffer_data, dtype=TENSOR_CODE_TO_TYPE[tensor.type].lower()
+   )
+   data = np.reshape(data, tensor.shape)
+   return data
+
+
+ def has_same_quantization(tensor1: Any, tensor2: Any) -> bool:
+   """Checks if two tensors have the same quantization.
+
+   Args:
+     tensor1: tensor in the flatbuffer.
+     tensor2: tensor in the flatbuffer.
+
+   Returns:
+     True if the two tensors have the same quantization.
+   """
+
+   def to_tuple(val):
+     if val is None:
+       val = []
+     return tuple(val)
+
+   same_type = tensor1.type == tensor2.type
+
+   # Return True if both tensors are not quantized.
+   if tensor1.quantization is None and tensor2.quantization is None:
+     return True
+   if tensor1.quantization.scale is None and tensor2.quantization.scale is None:
+     return True
+
+   same_scale = to_tuple(tensor1.quantization.scale) == to_tuple(
+       tensor2.quantization.scale
+   )
+   same_zero_point = to_tuple(tensor1.quantization.zeroPoint) == to_tuple(
+       tensor2.quantization.zeroPoint
+   )
+   same_quantized_dimension = (
+       tensor1.quantization.quantizedDimension
+       == tensor2.quantization.quantizedDimension
+   )
+   return (
+       same_type and same_scale and same_zero_point and same_quantized_dimension
+   )
+
+
+ def is_float_model(flatbuffer_model: Any) -> bool:
+   """Checks that the model is float and not already quantized."""
+   for subgraph in flatbuffer_model.subgraphs:
+     for tensor in subgraph.tensors:
+       if tensor.quantization is None:
+         continue
+       if tensor.quantization.scale is not None:
+         return False
+   return True
+
+
+ def get_subgraph_input_output_operators(
+     subgraph: Any,
+ ) -> list[qtyping.IOOperator]:
+   """Gets the input/output operators for the subgraph.
+
+   Args:
+     subgraph: The subgraph object.
+
+   Returns:
+     Input and output operators for the subgraph.
+   """
+   input_op = qtyping.IOOperator(
+       inputs=[],
+       outputs=subgraph.inputs,
+       op_key=qtyping.TFLOperationName.INPUT,
+   )
+   output_op = qtyping.IOOperator(
+       inputs=subgraph.outputs,
+       outputs=[],
+       op_key=qtyping.TFLOperationName.OUTPUT,
+   )
+   return [input_op, output_op]
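
Taken together, these helpers cover a read-inspect-write round trip on a flatbuffer model. A minimal usage sketch follows; the model paths ("model.tflite", "model_out.tflite") are illustrative placeholders, not files shipped with the package:

  from ai_edge_quantizer.utils import tfl_flatbuffer_utils

  # read_model accepts a path, bytes, or bytearray.
  fb_model = tfl_flatbuffer_utils.read_model("model.tflite")
  assert tfl_flatbuffer_utils.is_float_model(fb_model)

  # Map each buffer index to the tensors sharing it, then pull constant data.
  buffer_map = tfl_flatbuffer_utils.buffer_to_tensors(fb_model)
  for buffer_idx, tensors in buffer_map.items():
    for tensor in tensors:
      name = tfl_flatbuffer_utils.get_tensor_name(tensor)
      data = tfl_flatbuffer_utils.get_tensor_data(tensor, fb_model.buffers)
      if data is not None:  # None for tensors without constant data.
        print(buffer_idx, name, data.shape)

  # write_model is re-exported from tensorflow.lite.tools.flatbuffer_utils.
  tfl_flatbuffer_utils.write_model(fb_model, "model_out.tflite")
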
ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py
@@ -0,0 +1,200 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Tests for tfl_flatbuffer_utils.py."""
+
+ import os
+ import numpy as np
+ from tensorflow.python.platform import googletest
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.utils import test_utils
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+
+ TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../tests/models")
+
+
+ # TODO: b/328830092 - Add test cases for models that require buffer offset.
+ class FlatbufferUtilsTest(googletest.TestCase):
+
+   def setUp(self):
+     super().setUp()
+     self._test_model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "conv_fc_mnist.tflite"
+     )
+
+     self._test_model = tfl_flatbuffer_utils.read_model(self._test_model_path)
+
+   def test_get_model_buffer(self):
+     model_buffer = tfl_flatbuffer_utils.get_model_buffer(self._test_model_path)
+     file_stats = os.stat(self._test_model_path)
+     self.assertLen(model_buffer, file_stats.st_size)
+
+   def test_parse_op_tensors(self):
+     subgraph0 = self._test_model.subgraphs[0]
+     conv2d_op = subgraph0.operators[0]
+     op_tensors = tfl_flatbuffer_utils.parse_op_tensors(
+         conv2d_op, subgraph0.tensors
+     )
+     # conv2d has three inputs and one output.
+     self.assertLen(op_tensors, 4)
+
+     average_pool_op = subgraph0.operators[1]
+     op_tensors = tfl_flatbuffer_utils.parse_op_tensors(
+         average_pool_op, subgraph0.tensors
+     )
+     # average_pool_2d has one input and one output.
+     self.assertLen(op_tensors, 2)
+
+   def test_parse_fc_bmm_conv_tensors(self):
+     subgraph0 = self._test_model.subgraphs[0]
+     conv2d_op = subgraph0.operators[0]
+     inputs, weight, bias, output = (
+         tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+             conv2d_op, subgraph0.tensors
+         )
+     )
+     self.assertEqual(tuple(inputs.shape), (1, 28, 28, 1))
+     self.assertEqual(tuple(weight.shape), (8, 3, 3, 1))
+     self.assertEqual(tuple(bias.shape), (8,))
+     self.assertEqual(tuple(output.shape), (1, 28, 28, 8))
+
+     fc_with_bias = subgraph0.operators[3]
+     inputs, weight, bias, output = (
+         tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+             fc_with_bias,
+             subgraph0.tensors,
+         )
+     )
+     self.assertEqual(tuple(inputs.shape), (1, 1568))
+     self.assertEqual(tuple(weight.shape), (32, 1568))
+     self.assertEqual(tuple(bias.shape), (32,))
+     self.assertEqual(tuple(output.shape), (1, 32))
+
+     fc_no_bias = subgraph0.operators[4]
+     inputs, weight, bias, output = (
+         tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+             fc_no_bias,
+             subgraph0.tensors,
+         )
+     )
+     self.assertEqual(tuple(inputs.shape), (1, 32))
+     self.assertEqual(tuple(weight.shape), (10, 32))
+     self.assertIsNone(bias)
+     self.assertEqual(tuple(output.shape), (1, 10))
+
+   def test_buffer_to_tensors(self):
+     buffer_to_tensor_map = tfl_flatbuffer_utils.buffer_to_tensors(
+         self._test_model
+     )
+     # Buffer index 6 read from Netron/Model Explorer.
+     tensors = buffer_to_tensor_map[6]
+     self.assertLen(tensors, 1)
+     conv2d_filter_tensor = tensors[0]
+     self.assertEqual(tuple(conv2d_filter_tensor.shape), (8, 3, 3, 1))
+
+   def test_get_tensor_name(self):
+     subgraph0 = self._test_model.subgraphs[0]
+     subgraph_tensors = subgraph0.tensors
+     conv2d_op = subgraph0.operators[0]
+     weight_tensor = subgraph_tensors[conv2d_op.inputs[1]]
+     weight_tensor_name = tfl_flatbuffer_utils.get_tensor_name(weight_tensor)
+     self.assertEqual(weight_tensor_name, "sequential/conv2d/Conv2D")
+
+   # TODO: b/325123193 - Test tensor with data outside of the flatbuffer.
+   def test_get_tensor_data(self):
+     subgraph0 = self._test_model.subgraphs[0]
+     subgraph_tensors = subgraph0.tensors
+     conv2d_op = subgraph0.operators[0]
+     # Check tensor with data.
+     weight_tensor = subgraph_tensors[conv2d_op.inputs[1]]
+     weight_tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+         weight_tensor, self._test_model.buffers
+     )
+     self.assertEqual(
+         tuple(weight_tensor.shape), tuple(weight_tensor_data.shape)  # pytype: disable=attribute-error
+     )
+     self.assertAlmostEqual(weight_tensor_data[0][0][0][0], -0.12941549718379974)
+
+     # Check tensor with no data.
+     input_tensor = subgraph_tensors[conv2d_op.inputs[0]]
+     input_tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+         input_tensor, self._test_model.buffers
+     )
+     self.assertIsNone(input_tensor_data)
+
+   def test_has_same_quantization_succeeds(self):
+     tensor0, tensor1 = self._test_model.subgraphs[0].tensors[:2]
+     tensor0.quantization.scale = np.array([1, 2, 3]).astype(np.float32)
+     tensor0.quantization.zeroPoint = np.array([3, 2, 1]).astype(np.int32)
+     tensor1.quantization.scale = np.array([1, 2, 3]).astype(np.float32)
+     tensor1.quantization.zeroPoint = np.array([3, 2, 1]).astype(np.int32)
+     self.assertTrue(
+         tfl_flatbuffer_utils.has_same_quantization(tensor0, tensor1)
+     )
+
+   def test_has_same_quantization_succeeds_not_quantized(self):
+     tensor0, tensor1 = self._test_model.subgraphs[0].tensors[:2]
+     tensor0.type = 10
+     self.assertTrue(
+         tfl_flatbuffer_utils.has_same_quantization(tensor0, tensor1)
+     )
+
+   def test_has_same_quantization_fails_different_scale(self):
+     tensor0, tensor1 = self._test_model.subgraphs[0].tensors[:2]
+     tensor1.quantization.scale = np.array([1, 2, 3]).astype(np.float32)
+     self.assertFalse(
+         tfl_flatbuffer_utils.has_same_quantization(tensor0, tensor1)
+     )
+
+   def test_has_same_quantization_fails_different_zp(self):
+     tensor0, tensor1 = self._test_model.subgraphs[0].tensors[:2]
+     tensor0.quantization.scale = np.array([1, 2, 3]).astype(np.float32)
+     tensor0.quantization.zeroPoint = np.array([3, 2, 1]).astype(np.int32)
+     tensor1.quantization.scale = np.array([1, 2, 3]).astype(np.float32)
+     tensor1.quantization.zeroPoint = np.array([1, 2, 3]).astype(np.int32)
+     self.assertFalse(
+         tfl_flatbuffer_utils.has_same_quantization(tensor0, tensor1)
+     )
+
+   def test_check_is_float_model_true_when_model_is_float(self):
+     test_model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "conv_fc_mnist.tflite"
+     )
+     model = tfl_flatbuffer_utils.read_model(test_model_path)
+     self.assertTrue(tfl_flatbuffer_utils.is_float_model(model))
+
+   def test_check_is_float_model_false_when_model_is_quantized(self):
+     test_model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "mnist_quantized.tflite"
+     )
+     model = tfl_flatbuffer_utils.read_model(test_model_path)
+     self.assertFalse(tfl_flatbuffer_utils.is_float_model(model))
+
+   def test_get_subgraph_input_output_operators(self):
+     subgraph = self._test_model.subgraphs[0]
+     input_op, output_op = (
+         tfl_flatbuffer_utils.get_subgraph_input_output_operators(subgraph)
+     )
+     self.assertEqual(input_op.op_key, qtyping.TFLOperationName.INPUT)
+     self.assertEmpty(input_op.inputs)
+     self.assertListEqual(list(input_op.outputs), [0])
+     self.assertEqual(output_op.op_key, qtyping.TFLOperationName.OUTPUT)
+     self.assertListEqual(list(output_op.inputs), [12])
+     self.assertEmpty(output_op.outputs)
+
+
+ if __name__ == "__main__":
+   googletest.main()
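
The test fixtures (conv_fc_mnist.tflite, mnist_quantized.tflite) are resolved from ../tests/models relative to the utils package; note the wheel's RECORD does not list them, so assuming a source checkout that includes the test models, the suite runs standalone:

  python ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py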