ai-edge-litert-nightly 2.2.0.dev20260102__cp312-cp312-manylinux_2_27_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_litert/__init__.py +1 -0
- ai_edge_litert/_pywrap_analyzer_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_compiled_model_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_interpreter_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_tensor_buffer_wrapper.so +0 -0
- ai_edge_litert/_pywrap_modify_model_interface.so +0 -0
- ai_edge_litert/_pywrap_string_util.so +0 -0
- ai_edge_litert/_pywrap_tensorflow_lite_calibration_wrapper.so +0 -0
- ai_edge_litert/_pywrap_tensorflow_lite_metrics_wrapper.so +0 -0
- ai_edge_litert/any_pb2.py +37 -0
- ai_edge_litert/aot/__init__.py +0 -0
- ai_edge_litert/aot/ai_pack/__init__.py +0 -0
- ai_edge_litert/aot/ai_pack/export_lib.py +300 -0
- ai_edge_litert/aot/aot_compile.py +153 -0
- ai_edge_litert/aot/core/__init__.py +0 -0
- ai_edge_litert/aot/core/apply_plugin.py +148 -0
- ai_edge_litert/aot/core/common.py +97 -0
- ai_edge_litert/aot/core/components.py +93 -0
- ai_edge_litert/aot/core/mlir_transforms.py +36 -0
- ai_edge_litert/aot/core/tflxx_util.py +30 -0
- ai_edge_litert/aot/core/types.py +374 -0
- ai_edge_litert/aot/prepare_for_npu.py +152 -0
- ai_edge_litert/aot/vendors/__init__.py +22 -0
- ai_edge_litert/aot/vendors/example/__init__.py +0 -0
- ai_edge_litert/aot/vendors/example/example_backend.py +157 -0
- ai_edge_litert/aot/vendors/fallback_backend.py +128 -0
- ai_edge_litert/aot/vendors/google_tensor/__init__.py +0 -0
- ai_edge_litert/aot/vendors/google_tensor/google_tensor_backend.py +168 -0
- ai_edge_litert/aot/vendors/google_tensor/target.py +84 -0
- ai_edge_litert/aot/vendors/import_vendor.py +132 -0
- ai_edge_litert/aot/vendors/mediatek/__init__.py +0 -0
- ai_edge_litert/aot/vendors/mediatek/mediatek_backend.py +196 -0
- ai_edge_litert/aot/vendors/mediatek/target.py +94 -0
- ai_edge_litert/aot/vendors/qualcomm/__init__.py +0 -0
- ai_edge_litert/aot/vendors/qualcomm/qualcomm_backend.py +161 -0
- ai_edge_litert/aot/vendors/qualcomm/target.py +75 -0
- ai_edge_litert/api_pb2.py +43 -0
- ai_edge_litert/compiled_model.py +250 -0
- ai_edge_litert/descriptor_pb2.py +3361 -0
- ai_edge_litert/duration_pb2.py +37 -0
- ai_edge_litert/empty_pb2.py +37 -0
- ai_edge_litert/field_mask_pb2.py +37 -0
- ai_edge_litert/format_converter_wrapper_pybind11.so +0 -0
- ai_edge_litert/hardware_accelerator.py +22 -0
- ai_edge_litert/internal/__init__.py +0 -0
- ai_edge_litert/internal/litertlm_builder.py +584 -0
- ai_edge_litert/internal/litertlm_core.py +58 -0
- ai_edge_litert/internal/litertlm_header_schema_py_generated.py +1596 -0
- ai_edge_litert/internal/llm_metadata_pb2.py +45 -0
- ai_edge_litert/internal/llm_model_type_pb2.py +51 -0
- ai_edge_litert/internal/sampler_params_pb2.py +39 -0
- ai_edge_litert/internal/token_pb2.py +38 -0
- ai_edge_litert/interpreter.py +1039 -0
- ai_edge_litert/libLiteRt.so +0 -0
- ai_edge_litert/libpywrap_litert_common.so +0 -0
- ai_edge_litert/metrics_interface.py +48 -0
- ai_edge_litert/metrics_portable.py +70 -0
- ai_edge_litert/model_runtime_info_pb2.py +66 -0
- ai_edge_litert/plugin_pb2.py +46 -0
- ai_edge_litert/profiling_info_pb2.py +47 -0
- ai_edge_litert/pywrap_genai_ops.so +0 -0
- ai_edge_litert/schema_py_generated.py +19640 -0
- ai_edge_litert/source_context_pb2.py +37 -0
- ai_edge_litert/struct_pb2.py +47 -0
- ai_edge_litert/tensor_buffer.py +167 -0
- ai_edge_litert/timestamp_pb2.py +37 -0
- ai_edge_litert/tools/__init__.py +0 -0
- ai_edge_litert/tools/apply_plugin_main +0 -0
- ai_edge_litert/tools/flatbuffer_utils.py +534 -0
- ai_edge_litert/type_pb2.py +53 -0
- ai_edge_litert/vendors/google_tensor/compiler/libLiteRtCompilerPlugin_google_tensor.so +0 -0
- ai_edge_litert/vendors/mediatek/compiler/libLiteRtCompilerPlugin_MediaTek.so +0 -0
- ai_edge_litert/vendors/qualcomm/compiler/libLiteRtCompilerPlugin_Qualcomm.so +0 -0
- ai_edge_litert/wrappers_pb2.py +53 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/METADATA +52 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/RECORD +78 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/WHEEL +5 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,534 @@
|
|
|
1
|
+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Utility functions for FlatBuffers.
|
|
16
|
+
|
|
17
|
+
Functions commonly used to read, write, inspect, and modify TFLite models stored as FlatBuffers.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import copy
|
|
21
|
+
import random
|
|
22
|
+
import re
|
|
23
|
+
import struct
|
|
24
|
+
import sys
|
|
25
|
+
from typing import Optional, Type, TypeVar, Union
|
|
26
|
+
|
|
27
|
+
import flatbuffers
|
|
28
|
+
|
|
29
|
+
import os # import gfile
|
|
30
|
+
from ai_edge_litert import schema_py_generated as schema_fb # pylint:disable=g-direct-tensorflow-import
|
|
31
|
+
|
|
32
|
+
# FlatBuffers file identifier stamped into every model serialized by
# convert_object_to_bytearray.
_TFLITE_FILE_IDENTIFIER = b'TFL3'
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def get_builtin_code_from_operator_code(opcode):
  """Returns the effective builtin code of an operator code.

  The schema ran short of builtin codes, so newer builtin operators are stored
  in the extended builtin code field rather than the deprecated one. The
  effective code is therefore the larger of the two fields; this helper hides
  that detail from callers.

  Args:
    opcode: Operator code. This is either a flatbuffer accessor exposing
      callable `BuiltinCode()` / `DeprecatedBuiltinCode()` methods, or an
      object-API instance exposing `builtinCode` / `deprecatedBuiltinCode`
      attributes.

  Returns:
    The builtin code of the given operator code.
  """
  builtin_code_fn = getattr(opcode, 'BuiltinCode', None)
  if callable(builtin_code_fn):
    # Accessor API: field values are read through methods.
    return max(builtin_code_fn(), opcode.DeprecatedBuiltinCode())
  # Object API: field values are plain attributes.
  return max(opcode.builtinCode, opcode.deprecatedBuiltinCode)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def convert_bytearray_to_object(model_bytearray):
  """Parses serialized TFLite model bytes into an object-API model.

  Args:
    model_bytearray: Serialized TFLite flatbuffer.

  Returns:
    A `schema_fb.ModelT` built from the root `Model` table of the buffer.
  """
  return schema_fb.ModelT.InitFromObj(
      schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def read_model(input_tflite_file):
  """Loads a TFLite model file and parses it into a python object.

  Args:
    input_tflite_file: Full path name to the input tflite file

  Raises:
    RuntimeError: If input_tflite_file path is invalid.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A python object corresponding to the input tflite file.
  """
  if os.path.exists(input_tflite_file):
    with open(input_tflite_file, 'rb') as model_file:
      raw_bytes = bytearray(model_file.read())
    return read_model_from_bytearray(raw_bytes)
  raise RuntimeError('Input file not found at %r\n' % input_tflite_file)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def read_model_from_bytearray(model_bytearray):
  """Reads a tflite model as a python object.

  Models larger than 2GB keep buffer contents and large custom options outside
  the flatbuffer, addressed by offsets from the start of the file. Those
  payloads are copied back into the object, and their offset/size fields are
  cleared, so the result can be handled like any other model.

  On big-endian hosts the constant buffers are converted to host byte order.
  This happens only after the external payloads have been inlined, so they are
  converted as well.

  Args:
    model_bytearray: TFLite model in bytearray format.

  Returns:
    A python object corresponding to the input tflite file.
  """
  model = convert_bytearray_to_object(model_bytearray)

  # Offset handling for models > 2GB. The TFLite schema defines these offset
  # fields as valid only when > 1; 0 means the data lives in the flatbuffer.
  # Absent vectors come back as None from the object API, hence `or []`.
  for buffer in model.buffers or []:
    if buffer.offset > 1:
      start = buffer.offset
      buffer.data = model_bytearray[start : start + buffer.size]
      buffer.offset = 0
      buffer.size = 0
  for subgraph in model.subgraphs or []:
    for op in subgraph.operators or []:
      if op.largeCustomOptionsOffset > 1:
        start = op.largeCustomOptionsOffset
        op.customOptions = model_bytearray[
            start : start + op.largeCustomOptionsSize
        ]
        op.largeCustomOptionsOffset = 0
        op.largeCustomOptionsSize = 0

  # Swap last: buffers inlined above were still None before this point and
  # would otherwise be left in little-endian order on big-endian hosts.
  if sys.byteorder == 'big':
    byte_swap_tflite_model_obj(model, 'little', 'big')

  return model
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def read_model_with_mutable_tensors(input_tflite_file):
  """Reads a tflite model as a python object whose tensors can be modified.

  Works like read_model(), but returns a deep copy of the parsed model so
  that its tensors are mutable (read_model() returns immutable tensors).

  NOTE: This API only works for TFLite generated with
  _experimental_use_buffer_offset=false

  Args:
    input_tflite_file: Full path name to the input tflite file

  Raises:
    RuntimeError: If input_tflite_file path is invalid.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A mutable python object corresponding to the input tflite file.
  """
  parsed_model = read_model(input_tflite_file)
  return copy.deepcopy(parsed_model)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def convert_object_to_bytearray(model_object, extra_buffer=b''):
  """Serializes an object-API TFLite model into immutable bytes.

  Args:
    model_object: Object-API model to serialize (anything with a
      flatbuffers `Pack(builder)` method returning the root offset).
    extra_buffer: Bytes appended verbatim after the serialized flatbuffer.

  Returns:
    The serialized model as `bytes`, finished with the TFLite file identifier
    and followed by `extra_buffer`.
  """
  # 1024 is only the starting capacity; the builder grows as required.
  fb_builder = flatbuffers.Builder(1024)
  root_offset = model_object.Pack(fb_builder)
  fb_builder.Finish(root_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(fb_builder.Output()) + extra_buffer
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def write_model(model_object, output_tflite_file):
  """Serializes a tflite model object and writes it to a file.

  NOTE: This API only works for TFLite generated with
  _experimental_use_buffer_offset=false

  Args:
    model_object: A tflite model as a python object
    output_tflite_file: Full path name to the output tflite file.

  Raises:
    IOError: If output_tflite_file path is invalid or cannot be opened.
  """
  to_serialize = model_object
  if sys.byteorder == 'big':
    # Swap a private copy so the caller's object keeps host byte order.
    to_serialize = copy.deepcopy(model_object)
    byte_swap_tflite_model_obj(to_serialize, 'big', 'little')
  serialized = convert_object_to_bytearray(to_serialize)
  with open(output_tflite_file, 'wb') as out_file:
    out_file.write(serialized)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def strip_strings(model):
  """Removes nonessential strings from a model, in place, to shrink it.

  The following strings are cleared (they can be found by searching
  ":string" in the tensorflow lite flatbuffer schema):
    1. Model description
    2. SubGraph name
    3. Tensor names
  OperatorCode custom_code and Metadata name are kept.

  Args:
    model: The model from which to remove nonessential strings.
  """
  model.description = None
  for graph in model.subgraphs:
    graph.name = None
    for named_tensor in graph.tensors:
      named_tensor.name = None
  # Signature defs refer to tensors by name, so they are useless once the
  # names are gone; drop them entirely.
  model.signatureDefs = None
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def type_to_name(tensor_type):
  """Maps a numeric `TensorType` value to its enum member name.

  Args:
    tensor_type: Integer value of a `schema_fb.TensorType` member.

  Returns:
    The name of the first `schema_fb.TensorType` attribute equal to
    `tensor_type` (e.g. 'FLOAT32'), or None if there is no match.
  """
  matching_names = (
      name
      for name, value in vars(schema_fb.TensorType).items()
      if value == tensor_type
  )
  return next(matching_names, None)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def randomize_weights(model, random_seed=0, buffers_to_skip=None):
  """Randomize weights in a model.

  Float buffers (FLOAT16/FLOAT32/FLOAT64) are filled with valid floats drawn
  uniformly from [-0.5, 0.5]. Every other buffer is filled with random bytes.
  A buffer's type is taken from a tensor that uses it as an operator input.
  Buffers that no operator input refers to are treated as INT8.

  Args:
    model: The model in which to randomize weights.
    random_seed: The input to the random number generator (default value is 0).
    buffers_to_skip: The list of buffer indices to skip. The weights in these
      buffers are left unmodified.
  """
  # Seeds the module-level generator so the output is reproducible.
  random.seed(random_seed)

  # Parse model buffers which store the model weights.
  buffers = model.buffers
  skipped = set(buffers_to_skip) if buffers_to_skip is not None else set()
  # Index 0 is ignored as it's always None.
  buffer_ids = [idx for idx in range(1, len(buffers)) if idx not in skipped]

  buffer_types = {}
  for graph in model.subgraphs:
    for op in graph.operators or []:
      # An operator without inputs tells us nothing, but later operators in
      # the same subgraph still might: skip it rather than leaving the loop.
      if op.inputs is None:
        continue
      for input_idx in op.inputs:
        # TFLite encodes an omitted optional input as -1; indexing with it
        # would silently pick the last tensor of the subgraph.
        if input_idx < 0:
          continue
        tensor = graph.tensors[input_idx]
        buffer_types[tensor.buffer] = type_to_name(tensor.type)

  # struct format per float type; FLOAT32 (and any other FLOAT*) uses 'f'.
  float_formats = {'FLOAT16': 'e', 'FLOAT64': 'd'}

  for i in buffer_ids:
    buffer_i_data = buffers[i].data
    # len() works both for the array-like data of the object API and for the
    # bytearray slices read_model_from_bytearray stores for >2GB models.
    buffer_i_size = 0 if buffer_i_data is None else len(buffer_i_data)
    if buffer_i_size == 0:
      continue

    # Raw data buffers are of type ubyte (or uint8) whose values lie in the
    # range [0, 255]. Those ubytes are the underlying representation of each
    # datatype; e.g. an int32 bias tensor appears as a buffer 4 times its
    # length. For floats, generate a valid float and pack it into the raw
    # bytes in place so no NaN/Inf bit patterns are produced.
    buffer_type = buffer_types.get(i, 'INT8')
    if buffer_type and buffer_type.startswith('FLOAT'):
      format_code = float_formats.get(buffer_type, 'f')
      step = struct.calcsize(format_code)
      for offset in range(0, buffer_i_size, step):
        value = random.uniform(-0.5, 0.5)  # See http://b/152324470#comment2
        struct.pack_into(format_code, buffer_i_data, offset, value)
    else:
      for j in range(buffer_i_size):
        buffer_i_data[j] = random.randint(0, 255)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def rename_custom_ops(model, map_custom_op_renames):
  """Renames custom ops in place so they match the builtin op naming style.

  Args:
    model: The input tflite model.
    map_custom_op_renames: A mapping from old to new custom op names (str).
      Custom codes not present in the mapping are left unchanged.
  """
  for op_code in model.operatorCodes:
    if not op_code.customCode:
      continue
    old_name = op_code.customCode.decode('ascii')
    if old_name not in map_custom_op_renames:
      continue
    op_code.customCode = map_custom_op_renames[old_name].encode('ascii')
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def opcode_to_name(model, op_code):
  """Converts a TFLite op_code to the human readable name.

  Args:
    model: The input tflite model.
    op_code: Index into `model.operatorCodes` of the operator code to resolve.

  Returns:
    A string containing the human readable op name (the matching
    `schema_fb.BuiltinOperator` member name), or None if not resolvable.
  """
  # Reuse the shared helper so the deprecated/extended builtin-code rule is
  # implemented in exactly one place.
  code = get_builtin_code_from_operator_code(model.operatorCodes[op_code])
  for name, value in vars(schema_fb.BuiltinOperator).items():
    if value == code:
      return name
  return None
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def xxd_output_to_bytes(input_cc_file):
  """Converts xxd output C++ source file to bytes (immutable).

  Only lines that begin (after optional non-word characters) with a `0x` hex
  literal are parsed. Other lines, such as the array declaration or the
  trailing length variable, are ignored.

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd

  Raises:
    OSError: If input_cc_file does not exist or cannot be opened.
    ValueError: If a matched value is not valid hex or does not fit in a byte.

  Returns:
    A bytes object holding the values of the hex array in the input cc file.
  """
  # Match hex values in the string with comma as separator
  pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*')

  model_bytearray = bytearray()

  with open(input_cc_file) as file_handle:
    for line in file_handle:
      values_match = pattern.match(line)
      if values_match is None:
        continue

      # Hex array portion of the line, e.g. "0x1c, 0x00, 0x54, 0x46,"
      list_text = values_match.group(1)

      # Strip each comma-separated piece and drop empty ones. A trailing ", "
      # yields a whitespace-only piece, which int() would reject.
      values_text = (piece.strip() for piece in list_text.split(','))
      model_bytearray.extend(int(text, base=16) for text in values_text if text)

  return bytes(model_bytearray)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def xxd_output_to_object(input_cc_file):
  """Parses an xxd-generated C++ source file into a model object.

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd.

  Raises:
    OSError: If input_cc_file cannot be opened (e.g. it does not exist).

  Returns:
    A python object corresponding to the input tflite file.
  """
  return convert_bytearray_to_object(xxd_output_to_bytes(input_cc_file))
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def byte_swap_buffer_content(buffer, chunksize, from_endiness, to_endiness):
  """Re-encodes `buffer.data` element by element into another byte order.

  Args:
    buffer: Object whose `data` attribute holds the raw bytes. It is replaced
      in place by the converted `bytes`.
    chunksize: Width in bytes of each element.
    from_endiness: Byte order of the current data ('little' or 'big').
    to_endiness: Byte order to convert to ('little' or 'big').
  """
  raw = buffer.data
  swapped_chunks = []
  for start in range(0, len(raw), chunksize):
    element = int.from_bytes(raw[start : start + chunksize], from_endiness)
    swapped_chunks.append(element.to_bytes(chunksize, to_endiness))
  buffer.data = b''.join(swapped_chunks)
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def byte_swap_string_content(buffer, from_endiness, to_endiness):
  """Converts the integer header of a TFLite string buffer to another byte order.

  The buffer starts with a 4-byte string count followed by `count + 1` 4-byte
  offsets. Only those header words are converted; the character data that
  follows is copied unchanged.

  Args:
    buffer: TFLite string buffer of from_endiness format.
    from_endiness: The original endianness format of the string buffer.
    to_endiness: The destined endianness format of the string buffer.
  """
  raw = buffer.data
  string_count = int.from_bytes(raw[0:4], from_endiness)
  header_words = string_count + 2
  payload = bytearray(raw[4 * header_words :])
  swapped_header = b''.join(
      int.from_bytes(raw[4 * k : 4 * k + 4], from_endiness).to_bytes(
          4, to_endiness
      )
      for k in range(header_words)
  )
  buffer.data = swapped_header + payload
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
def byte_swap_tflite_model_obj(model, from_endiness, to_endiness):
  """Byte swaps the constant buffers of a TFLite model object in place.

  Each buffer is converted at most once. Its element type comes from the
  first tensor that references it:
    * STRING buffers have their integer header converted.
    * 16-, 32- and 64-bit numeric buffers are converted element-wise.
    * Buffers of any other type are left untouched.
  Empty buffers are skipped.

  Args:
    model: TFLite model object of from_endiness format. None is a no-op.
    from_endiness: The original endianness format of the buffers in model.
    to_endiness: The destined endianness format of the buffers in model.
  """
  if model is None or not model.buffers:
    return
  types_of_16_bits = [
      schema_fb.TensorType.FLOAT16,
      schema_fb.TensorType.INT16,
      schema_fb.TensorType.UINT16,
  ]
  # BFLOAT16 elements are 2 bytes wide as well; older generated schemas may
  # not define it.
  if hasattr(schema_fb.TensorType, 'BFLOAT16'):
    types_of_16_bits.append(schema_fb.TensorType.BFLOAT16)
  types_of_32_bits = [
      schema_fb.TensorType.FLOAT32,
      schema_fb.TensorType.INT32,
      schema_fb.TensorType.COMPLEX64,
      schema_fb.TensorType.UINT32,
  ]
  types_of_64_bits = [
      schema_fb.TensorType.INT64,
      schema_fb.TensorType.FLOAT64,
      schema_fb.TensorType.COMPLEX128,
      schema_fb.TensorType.UINT64,
  ]

  # A set keeps the "already swapped" check O(1) per tensor.
  swapped_buffers = set()
  for subgraph in model.subgraphs or []:
    for tensor in subgraph.tensors or []:
      buffer_idx = tensor.buffer
      # Buffer 0 is the empty sentinel; also skip out-of-range indices,
      # buffers already handled, and buffers without (or with empty) data.
      # Swapping an empty STRING buffer would fabricate an 8-byte header.
      if (
          buffer_idx <= 0
          or buffer_idx >= len(model.buffers)
          or buffer_idx in swapped_buffers
          or model.buffers[buffer_idx].data is None
          or len(model.buffers[buffer_idx].data) == 0
      ):
        continue
      buffer = model.buffers[buffer_idx]
      if tensor.type == schema_fb.TensorType.STRING:
        byte_swap_string_content(buffer, from_endiness, to_endiness)
      elif tensor.type in types_of_16_bits:
        byte_swap_buffer_content(buffer, 2, from_endiness, to_endiness)
      elif tensor.type in types_of_32_bits:
        byte_swap_buffer_content(buffer, 4, from_endiness, to_endiness)
      elif tensor.type in types_of_64_bits:
        byte_swap_buffer_content(buffer, 8, from_endiness, to_endiness)
      else:
        # Single-byte or opaque types need no swapping.
        continue
      swapped_buffers.add(buffer_idx)
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def byte_swap_tflite_buffer(tflite_model, from_endiness, to_endiness):
  """Returns a copy of a serialized model with its buffers byte swapped.

  Args:
    tflite_model: TFLite flatbuffer in a byte array, or None.
    from_endiness: The original endianness format of the buffers in
      tflite_model.
    to_endiness: The destined endianness format of the buffers in tflite_model.

  Returns:
    TFLite flatbuffer in a byte array, after being byte swapped to to_endiness
    format, or None when tflite_model is None.
  """
  if tflite_model is not None:
    # Parse, swap the constant buffers by data type, then re-serialize.
    model_obj = convert_bytearray_to_object(tflite_model)
    byte_swap_tflite_model_obj(model_obj, from_endiness, to_endiness)
    return convert_object_to_bytearray(model_obj)
  return None
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
def count_resource_variables(model):
  """Counts the distinct resource variables declared in a model.

  A resource variable is identified by the `sharedName` option of a
  VAR_HANDLE operator; duplicates across subgraphs are counted once.

  Args:
    model: the input tflite model, either as bytearray or object.

  Returns:
    An integer number representing the number of unique resource variables.
  """
  if not isinstance(model, schema_fb.ModelT):
    model = convert_bytearray_to_object(model)
  shared_names = set()
  for subgraph in model.subgraphs:
    for op in subgraph.operators or []:
      op_code = model.operatorCodes[op.opcodeIndex]
      if (
          get_builtin_code_from_operator_code(op_code)
          != schema_fb.BuiltinOperator.VAR_HANDLE
      ):
        continue
      shared_names.add(op.builtinOptions.sharedName)
  return len(shared_names)
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
# Type variable for get_options_as: links the requested options class to the
# type of the value returned.
OptsT = TypeVar('OptsT')
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
def get_options_as(
    op: Union[schema_fb.Operator, schema_fb.OperatorT], opts_type: Type[OptsT]
) -> Optional[OptsT]:
  """Returns an operator's builtin options as the requested object-API type.

  The requested type must be an object-API options class, whose name ends in
  'T'. The name without that suffix must be a member of either
  `schema_fb.BuiltinOptions` or `schema_fb.BuiltinOptions2`.

  Args:
    op: The operator to get the options from, as a flatbuffer accessor
      (`Operator`) or an object-API instance (`OperatorT`).
    opts_type: The type of the options to get.

  Returns:
    The options as the specified type, or None if the operator's options are
    absent or of a different type (or `op` is neither supported kind).

  Raises:
    ValueError: If the specified type is not a valid options type.
  """
  unsupported = ValueError(f'Unsupported options type: {opts_type}')
  type_name: str = opts_type.__name__
  if not type_name.endswith('T'):
    raise unsupported
  base_name = type_name.removesuffix('T')
  in_first_union = hasattr(schema_fb.BuiltinOptions, base_name)
  if not (in_first_union or hasattr(schema_fb.BuiltinOptions2, base_name)):
    raise unsupported

  if isinstance(op, schema_fb.Operator):
    # Accessor API: compare the stored union tag, then unpack the table.
    if in_first_union:
      expected_tag = getattr(schema_fb.BuiltinOptions, base_name)
      unpack = schema_fb.BuiltinOptionsCreator
      table = op.BuiltinOptions()
      stored_tag = op.BuiltinOptionsType()
    else:
      expected_tag = getattr(schema_fb.BuiltinOptions2, base_name)
      unpack = schema_fb.BuiltinOptions2Creator
      table = op.BuiltinOptions2()
      stored_tag = op.BuiltinOptions2Type()
    if table is None or stored_tag != expected_tag:
      return None
    return unpack(expected_tag, table)

  if isinstance(op, schema_fb.OperatorT):
    # Object API: options are already unpacked; just check their class.
    stored = op.builtinOptions if in_first_union else op.builtinOptions2
    return stored if isinstance(stored, opts_type) else None

  return None
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
4
|
+
# source: google/protobuf/type.proto
|
|
5
|
+
# Protobuf Python Version: 6.31.1
|
|
6
|
+
"""Generated protocol buffer code."""
|
|
7
|
+
from google.protobuf import descriptor as _descriptor
|
|
8
|
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
10
|
+
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
# Validates the installed protobuf runtime against the version this module was
# generated for (protobuf 6.31.1, public domain build).
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    6,
    31,
    1,
    '',
    'google/protobuf/type.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
|
|
26
|
+
from google.protobuf import source_context_pb2 as google_dot_protobuf_dot_source__context__pb2
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Registers the serialized google/protobuf/type.proto file descriptor with the
# default descriptor pool.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1agoogle/protobuf/type.proto\x12\x0fgoogle.protobuf\x1a\x19google/protobuf/any.proto\x1a$google/protobuf/source_context.proto\"\xe8\x01\n\x04Type\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x06\x66ields\x18\x02 \x03(\x0b\x32\x16.google.protobuf.Field\x12\x0e\n\x06oneofs\x18\x03 \x03(\t\x12(\n\x07options\x18\x04 \x03(\x0b\x32\x17.google.protobuf.Option\x12\x36\n\x0esource_context\x18\x05 \x01(\x0b\x32\x1e.google.protobuf.SourceContext\x12\'\n\x06syntax\x18\x06 \x01(\x0e\x32\x17.google.protobuf.Syntax\x12\x0f\n\x07\x65\x64ition\x18\x07 \x01(\t\"\xd5\x05\n\x05\x46ield\x12)\n\x04kind\x18\x01 \x01(\x0e\x32\x1b.google.protobuf.Field.Kind\x12\x37\n\x0b\x63\x61rdinality\x18\x02 \x01(\x0e\x32\".google.protobuf.Field.Cardinality\x12\x0e\n\x06number\x18\x03 \x01(\x05\x12\x0c\n\x04name\x18\x04 \x01(\t\x12\x10\n\x08type_url\x18\x06 \x01(\t\x12\x13\n\x0boneof_index\x18\x07 \x01(\x05\x12\x0e\n\x06packed\x18\x08 \x01(\x08\x12(\n\x07options\x18\t \x03(\x0b\x32\x17.google.protobuf.Option\x12\x11\n\tjson_name\x18\n \x01(\t\x12\x15\n\rdefault_value\x18\x0b \x01(\t\"\xc8\x02\n\x04Kind\x12\x10\n\x0cTYPE_UNKNOWN\x10\x00\x12\x0f\n\x0bTYPE_DOUBLE\x10\x01\x12\x0e\n\nTYPE_FLOAT\x10\x02\x12\x0e\n\nTYPE_INT64\x10\x03\x12\x0f\n\x0bTYPE_UINT64\x10\x04\x12\x0e\n\nTYPE_INT32\x10\x05\x12\x10\n\x0cTYPE_FIXED64\x10\x06\x12\x10\n\x0cTYPE_FIXED32\x10\x07\x12\r\n\tTYPE_BOOL\x10\x08\x12\x0f\n\x0bTYPE_STRING\x10\t\x12\x0e\n\nTYPE_GROUP\x10\n\x12\x10\n\x0cTYPE_MESSAGE\x10\x0b\x12\x0e\n\nTYPE_BYTES\x10\x0c\x12\x0f\n\x0bTYPE_UINT32\x10\r\x12\r\n\tTYPE_ENUM\x10\x0e\x12\x11\n\rTYPE_SFIXED32\x10\x0f\x12\x11\n\rTYPE_SFIXED64\x10\x10\x12\x0f\n\x0bTYPE_SINT32\x10\x11\x12\x0f\n\x0bTYPE_SINT64\x10\x12\"t\n\x0b\x43\x61rdinality\x12\x17\n\x13\x43\x41RDINALITY_UNKNOWN\x10\x00\x12\x18\n\x14\x43\x41RDINALITY_OPTIONAL\x10\x01\x12\x18\n\x14\x43\x41RDINALITY_REQUIRED\x10\x02\x12\x18\n\x14\x43\x41RDINALITY_REPEATED\x10\x03\"\xdf\x01\n\x04\x45num\x12\x0c\n\x04name\x18\x01 \x01(\t\x12-\n\tenumvalue\x18\x02 \x03(\x0b\x32\x1a.google.protobuf.EnumValue\x12(\n\x07options\x18\x03 \x03(\x0b\x32\x17.google.protobuf.Option\x12\x36\n\x0esource_context\x18\x04 \x01(\x0b\x32\x1e.google.protobuf.SourceContext\x12\'\n\x06syntax\x18\x05 \x01(\x0e\x32\x17.google.protobuf.Syntax\x12\x0f\n\x07\x65\x64ition\x18\x06 \x01(\t\"S\n\tEnumValue\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06number\x18\x02 \x01(\x05\x12(\n\x07options\x18\x03 \x03(\x0b\x32\x17.google.protobuf.Option\";\n\x06Option\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.google.protobuf.Any*C\n\x06Syntax\x12\x11\n\rSYNTAX_PROTO2\x10\x00\x12\x11\n\rSYNTAX_PROTO3\x10\x01\x12\x13\n\x0fSYNTAX_EDITIONS\x10\x02\x42{\n\x13\x63om.google.protobufB\tTypeProtoP\x01Z-google.golang.org/protobuf/types/known/typepb\xf8\x01\x01\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')

# Populates this module's globals with the message/enum classes built from
# DESCRIPTOR.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.type_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  # Pure-Python descriptor path: attach file options and what are presumably
  # the byte ranges of each descriptor within the serialized file.
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\tTypeProtoP\001Z-google.golang.org/protobuf/types/known/typepb\370\001\001\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
  _globals['_SYNTAX']._serialized_start=1447
  _globals['_SYNTAX']._serialized_end=1514
  _globals['_TYPE']._serialized_start=113
  _globals['_TYPE']._serialized_end=345
  _globals['_FIELD']._serialized_start=348
  _globals['_FIELD']._serialized_end=1073
  _globals['_FIELD_KIND']._serialized_start=627
  _globals['_FIELD_KIND']._serialized_end=955
  _globals['_FIELD_CARDINALITY']._serialized_start=957
  _globals['_FIELD_CARDINALITY']._serialized_end=1073
  _globals['_ENUM']._serialized_start=1076
  _globals['_ENUM']._serialized_end=1299
  _globals['_ENUMVALUE']._serialized_start=1301
  _globals['_ENUMVALUE']._serialized_end=1384
  _globals['_OPTION']._serialized_start=1386
  _globals['_OPTION']._serialized_end=1445
# @@protoc_insertion_point(module_scope)
|
|
Binary file
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
4
|
+
# source: google/protobuf/wrappers.proto
|
|
5
|
+
# Protobuf Python Version: 6.31.1
|
|
6
|
+
"""Generated protocol buffer code."""
|
|
7
|
+
from google.protobuf import descriptor as _descriptor
|
|
8
|
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
10
|
+
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
# Validates the installed protobuf runtime against the version this module was
# generated for (protobuf 6.31.1, public domain build).
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    6,
    31,
    1,
    '',
    'google/protobuf/wrappers.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Registers the serialized google/protobuf/wrappers.proto file descriptor with
# the default descriptor pool.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1egoogle/protobuf/wrappers.proto\x12\x0fgoogle.protobuf\"\x1c\n\x0b\x44oubleValue\x12\r\n\x05value\x18\x01 \x01(\x01\"\x1b\n\nFloatValue\x12\r\n\x05value\x18\x01 \x01(\x02\"\x1b\n\nInt64Value\x12\r\n\x05value\x18\x01 \x01(\x03\"\x1c\n\x0bUInt64Value\x12\r\n\x05value\x18\x01 \x01(\x04\"\x1b\n\nInt32Value\x12\r\n\x05value\x18\x01 \x01(\x05\"\x1c\n\x0bUInt32Value\x12\r\n\x05value\x18\x01 \x01(\r\"\x1a\n\tBoolValue\x12\r\n\x05value\x18\x01 \x01(\x08\"\x1c\n\x0bStringValue\x12\r\n\x05value\x18\x01 \x01(\t\"\x1b\n\nBytesValue\x12\r\n\x05value\x18\x01 \x01(\x0c\x42\x83\x01\n\x13\x63om.google.protobufB\rWrappersProtoP\x01Z1google.golang.org/protobuf/types/known/wrapperspb\xf8\x01\x01\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')

# Populates this module's globals with the message classes built from
# DESCRIPTOR.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.wrappers_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  # Pure-Python descriptor path: attach file options and what are presumably
  # the byte ranges of each descriptor within the serialized file.
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\rWrappersProtoP\001Z1google.golang.org/protobuf/types/known/wrapperspb\370\001\001\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
  _globals['_DOUBLEVALUE']._serialized_start=51
  _globals['_DOUBLEVALUE']._serialized_end=79
  _globals['_FLOATVALUE']._serialized_start=81
  _globals['_FLOATVALUE']._serialized_end=108
  _globals['_INT64VALUE']._serialized_start=110
  _globals['_INT64VALUE']._serialized_end=137
  _globals['_UINT64VALUE']._serialized_start=139
  _globals['_UINT64VALUE']._serialized_end=167
  _globals['_INT32VALUE']._serialized_start=169
  _globals['_INT32VALUE']._serialized_end=196
  _globals['_UINT32VALUE']._serialized_start=198
  _globals['_UINT32VALUE']._serialized_end=226
  _globals['_BOOLVALUE']._serialized_start=228
  _globals['_BOOLVALUE']._serialized_end=254
  _globals['_STRINGVALUE']._serialized_start=256
  _globals['_STRINGVALUE']._serialized_end=284
  _globals['_BYTESVALUE']._serialized_start=286
  _globals['_BYTESVALUE']._serialized_end=313
# @@protoc_insertion_point(module_scope)
|