ai-edge-litert-nightly 2.2.0.dev20260102__cp312-cp312-manylinux_2_27_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. ai_edge_litert/__init__.py +1 -0
  2. ai_edge_litert/_pywrap_analyzer_wrapper.so +0 -0
  3. ai_edge_litert/_pywrap_litert_compiled_model_wrapper.so +0 -0
  4. ai_edge_litert/_pywrap_litert_interpreter_wrapper.so +0 -0
  5. ai_edge_litert/_pywrap_litert_tensor_buffer_wrapper.so +0 -0
  6. ai_edge_litert/_pywrap_modify_model_interface.so +0 -0
  7. ai_edge_litert/_pywrap_string_util.so +0 -0
  8. ai_edge_litert/_pywrap_tensorflow_lite_calibration_wrapper.so +0 -0
  9. ai_edge_litert/_pywrap_tensorflow_lite_metrics_wrapper.so +0 -0
  10. ai_edge_litert/any_pb2.py +37 -0
  11. ai_edge_litert/aot/__init__.py +0 -0
  12. ai_edge_litert/aot/ai_pack/__init__.py +0 -0
  13. ai_edge_litert/aot/ai_pack/export_lib.py +300 -0
  14. ai_edge_litert/aot/aot_compile.py +153 -0
  15. ai_edge_litert/aot/core/__init__.py +0 -0
  16. ai_edge_litert/aot/core/apply_plugin.py +148 -0
  17. ai_edge_litert/aot/core/common.py +97 -0
  18. ai_edge_litert/aot/core/components.py +93 -0
  19. ai_edge_litert/aot/core/mlir_transforms.py +36 -0
  20. ai_edge_litert/aot/core/tflxx_util.py +30 -0
  21. ai_edge_litert/aot/core/types.py +374 -0
  22. ai_edge_litert/aot/prepare_for_npu.py +152 -0
  23. ai_edge_litert/aot/vendors/__init__.py +22 -0
  24. ai_edge_litert/aot/vendors/example/__init__.py +0 -0
  25. ai_edge_litert/aot/vendors/example/example_backend.py +157 -0
  26. ai_edge_litert/aot/vendors/fallback_backend.py +128 -0
  27. ai_edge_litert/aot/vendors/google_tensor/__init__.py +0 -0
  28. ai_edge_litert/aot/vendors/google_tensor/google_tensor_backend.py +168 -0
  29. ai_edge_litert/aot/vendors/google_tensor/target.py +84 -0
  30. ai_edge_litert/aot/vendors/import_vendor.py +132 -0
  31. ai_edge_litert/aot/vendors/mediatek/__init__.py +0 -0
  32. ai_edge_litert/aot/vendors/mediatek/mediatek_backend.py +196 -0
  33. ai_edge_litert/aot/vendors/mediatek/target.py +94 -0
  34. ai_edge_litert/aot/vendors/qualcomm/__init__.py +0 -0
  35. ai_edge_litert/aot/vendors/qualcomm/qualcomm_backend.py +161 -0
  36. ai_edge_litert/aot/vendors/qualcomm/target.py +75 -0
  37. ai_edge_litert/api_pb2.py +43 -0
  38. ai_edge_litert/compiled_model.py +250 -0
  39. ai_edge_litert/descriptor_pb2.py +3361 -0
  40. ai_edge_litert/duration_pb2.py +37 -0
  41. ai_edge_litert/empty_pb2.py +37 -0
  42. ai_edge_litert/field_mask_pb2.py +37 -0
  43. ai_edge_litert/format_converter_wrapper_pybind11.so +0 -0
  44. ai_edge_litert/hardware_accelerator.py +22 -0
  45. ai_edge_litert/internal/__init__.py +0 -0
  46. ai_edge_litert/internal/litertlm_builder.py +584 -0
  47. ai_edge_litert/internal/litertlm_core.py +58 -0
  48. ai_edge_litert/internal/litertlm_header_schema_py_generated.py +1596 -0
  49. ai_edge_litert/internal/llm_metadata_pb2.py +45 -0
  50. ai_edge_litert/internal/llm_model_type_pb2.py +51 -0
  51. ai_edge_litert/internal/sampler_params_pb2.py +39 -0
  52. ai_edge_litert/internal/token_pb2.py +38 -0
  53. ai_edge_litert/interpreter.py +1039 -0
  54. ai_edge_litert/libLiteRt.so +0 -0
  55. ai_edge_litert/libpywrap_litert_common.so +0 -0
  56. ai_edge_litert/metrics_interface.py +48 -0
  57. ai_edge_litert/metrics_portable.py +70 -0
  58. ai_edge_litert/model_runtime_info_pb2.py +66 -0
  59. ai_edge_litert/plugin_pb2.py +46 -0
  60. ai_edge_litert/profiling_info_pb2.py +47 -0
  61. ai_edge_litert/pywrap_genai_ops.so +0 -0
  62. ai_edge_litert/schema_py_generated.py +19640 -0
  63. ai_edge_litert/source_context_pb2.py +37 -0
  64. ai_edge_litert/struct_pb2.py +47 -0
  65. ai_edge_litert/tensor_buffer.py +167 -0
  66. ai_edge_litert/timestamp_pb2.py +37 -0
  67. ai_edge_litert/tools/__init__.py +0 -0
  68. ai_edge_litert/tools/apply_plugin_main +0 -0
  69. ai_edge_litert/tools/flatbuffer_utils.py +534 -0
  70. ai_edge_litert/type_pb2.py +53 -0
  71. ai_edge_litert/vendors/google_tensor/compiler/libLiteRtCompilerPlugin_google_tensor.so +0 -0
  72. ai_edge_litert/vendors/mediatek/compiler/libLiteRtCompilerPlugin_MediaTek.so +0 -0
  73. ai_edge_litert/vendors/qualcomm/compiler/libLiteRtCompilerPlugin_Qualcomm.so +0 -0
  74. ai_edge_litert/wrappers_pb2.py +53 -0
  75. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/METADATA +52 -0
  76. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/RECORD +78 -0
  77. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/WHEEL +5 -0
  78. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/top_level.txt +1 -0
@@ -0,0 +1,534 @@
1
+ # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Utility functions for FlatBuffers.
16
+
17
+ All functions that are commonly used to work with FlatBuffers.
18
+ """
19
+
20
+ import copy
21
+ import random
22
+ import re
23
+ import struct
24
+ import sys
25
+ from typing import Optional, Type, TypeVar, Union
26
+
27
+ import flatbuffers
28
+
29
+ import os # import gfile
30
+ from ai_edge_litert import schema_py_generated as schema_fb # pylint:disable=g-direct-tensorflow-import
31
+
32
# FlatBuffers file identifier stamped into serialized TFLite models; it is
# passed as `file_identifier` to `builder.Finish` when a model object is
# serialized in convert_object_to_bytearray.
_TFLITE_FILE_IDENTIFIER = b'TFL3'
33
+
34
+
35
def get_builtin_code_from_operator_code(opcode):
  """Returns the effective builtin code of an operator code.

  The original (now deprecated) builtin code field ran out of values, so newer
  builtin operators are stored in the extended builtin code field instead.
  Taking the larger of the two fields hides that split from callers.

  Both flavors of generated classes are accepted: table-backed operator codes
  expose accessor methods (`BuiltinCode()`), object-API ones expose plain
  attributes (`builtinCode`).

  Args:
    opcode: Operator code.

  Returns:
    The builtin code of the given operator code.
  """
  builtin_code_getter = getattr(opcode, 'BuiltinCode', None)
  if callable(builtin_code_getter):
    # Table-backed operator code: fields are read through accessor methods.
    return max(builtin_code_getter(), opcode.DeprecatedBuiltinCode())
  # Object-API operator code: fields are plain attributes.
  return max(opcode.builtinCode, opcode.deprecatedBuiltinCode)
54
+
55
+
56
def convert_bytearray_to_object(model_bytearray):
  """Converts a tflite model from a bytearray to an object for parsing.

  Args:
    model_bytearray: Serialized TFLite flatbuffer (bytes-like).

  Returns:
    A `schema_fb.ModelT` built from the flatbuffer's root `Model` table.
  """
  root_table = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  unpacked_model = schema_fb.ModelT.InitFromObj(root_table)
  return unpacked_model
60
+
61
+
62
def read_model(input_tflite_file):
  """Reads a tflite model file from disk as a python object.

  Args:
    input_tflite_file: Full path name to the input tflite file

  Raises:
    RuntimeError: If input_tflite_file does not exist.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A python object corresponding to the input tflite file.
  """
  if os.path.exists(input_tflite_file):
    with open(input_tflite_file, 'rb') as model_file:
      file_contents = model_file.read()
    return read_model_from_bytearray(bytearray(file_contents))
  raise RuntimeError('Input file not found at %r\n' % input_tflite_file)
80
+
81
+
82
def read_model_from_bytearray(model_bytearray):
  """Reads a tflite model as a python object.

  Buffers and custom options whose payload lives outside the flatbuffer
  (addressed by `offset`/`size` fields, as used for models > 2GB) are copied
  out of `model_bytearray` into the returned object, and their offset/size
  fields are reset to 0.

  On big-endian hosts the buffers are byte-swapped from little- to big-endian
  only after that copy, so offset-addressed buffers are converted as well.

  Args:
    model_bytearray: TFLite model in bytearray format.

  Returns:
    A python object corresponding to the input tflite file.
  """
  model = convert_bytearray_to_object(model_bytearray)

  # Offset handling for models > 2GB: pull externally stored payloads inline.
  for buffer in model.buffers or []:
    if buffer.offset:
      start = buffer.offset
      buffer.data = model_bytearray[start : start + buffer.size]
      buffer.offset = 0
      buffer.size = 0
  for subgraph in model.subgraphs or []:
    for op in subgraph.operators or []:
      if op.largeCustomOptionsOffset:
        start = op.largeCustomOptionsOffset
        op.customOptions = model_bytearray[
            start : start + op.largeCustomOptionsSize
        ]
        op.largeCustomOptionsOffset = 0
        op.largeCustomOptionsSize = 0

  # Swap only once every buffer's data is inline; swapping earlier would skip
  # the offset-addressed buffers (their `data` is still None at that point).
  if sys.byteorder == 'big':
    byte_swap_tflite_model_obj(model, 'little', 'big')

  return model
112
+
113
+
114
def read_model_with_mutable_tensors(input_tflite_file):
  """Reads a tflite model as a python object with mutable tensors.

  Behaves like read_model(), but returns a deep copy of the parsed model so
  that its tensors can be modified (read_model() returns an object with
  immutable tensors).

  NOTE: This API only works for TFLite generated with
    _experimental_use_buffer_offset=false

  Args:
    input_tflite_file: Full path name to the input tflite file

  Raises:
    RuntimeError: If input_tflite_file does not exist.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A mutable python object corresponding to the input tflite file.
  """
  parsed_model = read_model(input_tflite_file)
  return copy.deepcopy(parsed_model)
134
+
135
+
136
def convert_object_to_bytearray(model_object, extra_buffer=b''):
  """Serializes a tflite model object into immutable bytes.

  Args:
    model_object: Object-API model (e.g. `schema_fb.ModelT`) to serialize.
    extra_buffer: Bytes appended verbatim after the serialized flatbuffer.

  Returns:
    A `bytes` object holding the flatbuffer, finished with the TFLite file
    identifier, followed by `extra_buffer`.
  """
  # 1024 is only the initial capacity; the builder grows automatically.
  builder = flatbuffers.Builder(1024)
  root_offset = model_object.Pack(builder)
  builder.Finish(root_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(builder.Output()) + extra_buffer
145
+
146
+
147
def write_model(model_object, output_tflite_file):
  """Writes the tflite model, a python object, into the output file.

  NOTE: This API only works for TFLite generated with
    _experimental_use_buffer_offset=false

  On big-endian hosts, a deep copy of the model is byte-swapped to
  little-endian before serialization, so `model_object` itself is not
  modified.

  Args:
    model_object: A tflite model as a python object
    output_tflite_file: Full path name to the output tflite file.

  Raises:
    IOError: If output_tflite_file path is invalid or cannot be opened.
  """
  to_serialize = model_object
  if sys.byteorder == 'big':
    to_serialize = copy.deepcopy(model_object)
    byte_swap_tflite_model_obj(to_serialize, 'big', 'little')
  serialized = convert_object_to_bytearray(to_serialize)
  with open(output_tflite_file, 'wb') as out_file:
    out_file.write(serialized)
166
+
167
+
168
def strip_strings(model):
  """Strips all nonessential strings from the model to reduce model size.

  We remove the following strings:
  (find strings by searching ":string" in the tensorflow lite flatbuffer schema)
    1. Model description
    2. SubGraph name
    3. Tensor names
  We retain OperatorCode custom_code and Metadata name.

  Signature definitions are removed entirely as well, since they are useless
  once the names are gone.

  Args:
    model: The model from which to remove nonessential strings. It is
      modified in place.
  """
  model.description = None
  # Vector fields may be None (this module already checks `operators` and
  # `inputs` for None), so treat None as an empty list.
  for subgraph in model.subgraphs or []:
    subgraph.name = None
    for tensor in subgraph.tensors or []:
      tensor.name = None
  # We clear all signature_def structure, since without names it is useless.
  model.signatureDefs = None
189
+
190
+
191
def type_to_name(tensor_type):
  """Converts a numerical TensorType enum value to a readable name.

  Args:
    tensor_type: A `schema_fb.TensorType` value.

  Returns:
    The matching enum member name (e.g. 'FLOAT32'), or None if no member
    matches.
  """
  for name, value in vars(schema_fb.TensorType).items():
    # Skip class internals such as __module__ and __doc__: they are not enum
    # members, and __doc__ (possibly None) could otherwise "match" a None
    # query.
    if name.startswith('_'):
      continue
    if value == tensor_type:
      return name
  return None
197
+
198
+
199
def randomize_weights(model, random_seed=0, buffers_to_skip=None):
  """Randomize weights in a model.

  Buffers are overwritten in place. A buffer consumed as an op input by a
  FLOAT16/FLOAT32/FLOAT64 tensor receives values drawn uniformly from
  [-0.5, 0.5], packed at that type's width. Every other buffer receives
  uniformly random raw bytes in [0, 255].

  Args:
    model: The model in which to randomize weights.
    random_seed: The input to the random number generator (default value is
      0). A private generator is used, so the global `random` state is left
      untouched.
    buffers_to_skip: The list of buffer indices to skip. The weights in these
      buffers are left unmodified.
  """
  rng = random.Random(random_seed)

  # Parse model buffers which store the model weights.
  buffers = model.buffers or []
  skip = set(buffers_to_skip) if buffers_to_skip is not None else set()
  # Index 0 is ignored, as it's always None.
  buffer_ids = [idx for idx in range(1, len(buffers)) if idx not in skip]

  # Record the type of the tensor that reads each buffer as an op input.
  buffer_types = {}
  for graph in model.subgraphs or []:
    tensors = graph.tensors or []
    for op in graph.operators or []:
      # An op without inputs must not stop the scan of the remaining ops.
      if op.inputs is None:
        continue
      for input_idx in op.inputs:
        # Negative indices mark omitted optional inputs; they refer to no
        # tensor (indexing with -1 would silently pick the last tensor).
        if input_idx < 0:
          continue
        tensor = tensors[input_idx]
        buffer_types[tensor.buffer] = type_to_name(tensor.type)

  float_formats = {'FLOAT16': 'e', 'FLOAT32': 'f', 'FLOAT64': 'd'}
  for i in buffer_ids:
    data = buffers[i].data
    # len() works both for numpy arrays and for bytearray payloads produced
    # by offset loading (which have no `.size`).
    size = 0 if data is None else len(data)
    if size == 0:
      continue

    # Raw data buffers are of type ubyte (or uint8) whose values lie in the
    # range [0, 255]. Those ubytes (or uint8s) are the underlying
    # representation of each datatype. For example, a bias tensor of type
    # int32 appears as a buffer 4 times its length of type ubyte (or uint8).
    # For floats, we need to generate a valid float and then pack it into
    # the raw bytes in place.
    buffer_type = buffer_types.get(i, 'INT8')
    if buffer_type and buffer_type.startswith('FLOAT'):
      format_code = float_formats.get(buffer_type, 'f')
      item_size = struct.calcsize(format_code)
      # Only whole elements are written; a trailing partial element (if any)
      # is left as-is instead of making pack_into raise struct.error.
      for offset in range(0, size - item_size + 1, item_size):
        value = rng.uniform(-0.5, 0.5)  # See http://b/152324470#comment2
        struct.pack_into(format_code, data, offset, value)
    else:
      for j in range(size):
        data[j] = rng.randint(0, 255)
248
+
249
+
250
def rename_custom_ops(model, map_custom_op_renames):
  """Rename custom ops so they use the same naming style as builtin ops.

  Operator codes without a custom code, or whose custom code is not a key of
  the mapping, are left unchanged. Names are decoded/encoded as ASCII.

  Args:
    model: The input tflite model.
    map_custom_op_renames: A mapping from old to new custom op names.
  """
  for op_code in model.operatorCodes:
    if not op_code.customCode:
      continue
    old_name = op_code.customCode.decode('ascii')
    if old_name not in map_custom_op_renames:
      continue
    op_code.customCode = map_custom_op_renames[old_name].encode('ascii')
262
+
263
+
264
def opcode_to_name(model, op_code):
  """Converts a TFLite op_code to the human readable name.

  Args:
    model: The input tflite model.
    op_code: Index into `model.operatorCodes` of the op code to resolve.

  Returns:
    A string containing the human readable op name, or None if not resolvable.
  """
  # Reuse the shared helper so the deprecated/extended builtin-code handling
  # lives in one place.
  code = get_builtin_code_from_operator_code(model.operatorCodes[op_code])
  for name, value in vars(schema_fb.BuiltinOperator).items():
    # Skip class internals (__module__, __doc__, ...), which are not members.
    if not name.startswith('_') and value == code:
      return name
  return None
280
+
281
+
282
def xxd_output_to_bytes(input_cc_file):
  """Converts xxd output C++ source file to bytes (immutable).

  Only lines that start (after optional non-word characters) with a `0x` hex
  literal are parsed. Other lines, such as the array declaration or the
  length variable, are ignored.

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd

  Raises:
    IOError: If input_cc_file does not exist or cannot be opened.

  Returns:
    A bytes object corresponding to the input cc file array.
  """
  # Match hex values in the string with comma as separator
  pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*')

  model_bytearray = bytearray()

  with open(input_cc_file) as file_handle:
    for line in file_handle:
      values_match = pattern.match(line)
      if values_match is None:
        continue

      # Match in the parentheses (hex array only),
      # e.g. 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c,
      list_text = values_match.group(1)

      # Strip each piece: trailing spaces after the final comma would
      # otherwise yield a whitespace-only token that int() rejects.
      pieces = (piece.strip() for piece in list_text.split(','))
      model_bytearray.extend(int(text, base=16) for text in pieces if text)

  return bytes(model_bytearray)
319
+
320
+
321
def xxd_output_to_object(input_cc_file):
  """Parses a C++ source file dumped by xxd into a TFLite model object.

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd

  Raises:
    IOError: If input_cc_file does not exist or cannot be opened.

  Returns:
    A python object corresponding to the input tflite file.
  """
  return convert_bytearray_to_object(xxd_output_to_bytes(input_cc_file))
336
+
337
+
338
def byte_swap_buffer_content(buffer, chunksize, from_endiness, to_endiness):
  """Byte-swaps `buffer.data` element by element, in place.

  The data is split into `chunksize`-byte elements; each one is decoded with
  `from_endiness` and re-encoded with `to_endiness`. A trailing short chunk is
  padded to `chunksize` bytes by the re-encoding. `buffer.data` is replaced
  with a `bytes` object.

  Args:
    buffer: Object with a `data` attribute holding the raw bytes.
    chunksize: Width in bytes of a single element.
    from_endiness: Byte order of the current data ('little' or 'big').
    to_endiness: Byte order to convert to ('little' or 'big').
  """
  raw = buffer.data
  swapped = bytearray()
  for start in range(0, len(raw), chunksize):
    element = int.from_bytes(raw[start:start + chunksize], from_endiness)
    swapped += element.to_bytes(chunksize, to_endiness)
  buffer.data = bytes(swapped)
348
+
349
+
350
def byte_swap_string_content(buffer, from_endiness, to_endiness):
  """Helper function for byte-swapping the string buffer.

  Only the int32 header is swapped: the first word (the string count N) and
  the N + 1 words that follow it (presumably the string offsets). The
  character payload after the header is copied unchanged.

  Args:
    buffer: TFLite string buffer of from_endiness format.
    from_endiness: The original endianness format of the string buffer.
    to_endiness: The target endianness format of the string buffer.
  """
  raw = buffer.data
  string_count = int.from_bytes(raw[0:4], from_endiness)
  # Header is (string_count + 2) int32 words.
  header_size = 4 * (string_count + 2)
  swapped_header = bytearray()
  for pos in range(0, header_size, 4):
    word = int.from_bytes(raw[pos:pos + 4], from_endiness)
    swapped_header += word.to_bytes(4, to_endiness)
  # The character payload is byte-oriented and needs no swapping.
  buffer.data = bytes(swapped_header) + bytearray(raw[header_size:])
367
+
368
+
369
def byte_swap_tflite_model_obj(model, from_endiness, to_endiness):
  """Byte swaps the buffers field in a TFLite model, in place.

  Each constant buffer referenced by a tensor is swapped at most once,
  according to the type of the first tensor (with a swappable type) that
  references it. STRING buffers have only their int32 header swapped. 2-, 4-
  and 8-byte numeric types are swapped element-wise. Other types (e.g.
  single-byte ones) are left untouched. Buffer index 0, out-of-range indices
  and buffers without data are skipped.

  Args:
    model: TFLite model object of from_endiness format.
    from_endiness: The original endianness format of the buffers in model.
    to_endiness: The target endianness format of the buffers in model.
  """
  if model is None:
    return
  # Element width in bytes for each multi-byte tensor type. Complex types are
  # swapped per component (float32 / float64 halves).
  chunk_size_by_type = {
      schema_fb.TensorType.FLOAT16: 2,
      schema_fb.TensorType.INT16: 2,
      schema_fb.TensorType.UINT16: 2,
      schema_fb.TensorType.FLOAT32: 4,
      schema_fb.TensorType.INT32: 4,
      schema_fb.TensorType.COMPLEX64: 4,
      schema_fb.TensorType.UINT32: 4,
      schema_fb.TensorType.INT64: 8,
      schema_fb.TensorType.FLOAT64: 8,
      schema_fb.TensorType.COMPLEX128: 8,
      schema_fb.TensorType.UINT64: 8,
  }
  buffers = model.buffers or []
  # A set keeps the "already swapped" check O(1) per tensor.
  swapped = set()
  for subgraph in model.subgraphs or []:
    for tensor in subgraph.tensors or []:
      buffer_idx = tensor.buffer
      if not 0 < buffer_idx < len(buffers) or buffer_idx in swapped:
        continue
      buffer = buffers[buffer_idx]
      if buffer.data is None:
        continue
      if tensor.type == schema_fb.TensorType.STRING:
        byte_swap_string_content(buffer, from_endiness, to_endiness)
      elif tensor.type in chunk_size_by_type:
        byte_swap_buffer_content(
            buffer, chunk_size_by_type[tensor.type], from_endiness, to_endiness
        )
      else:
        # Not swappable via this tensor; another tensor sharing the buffer
        # may still swap it.
        continue
      swapped.add(buffer_idx)
425
+
426
+
427
def byte_swap_tflite_buffer(tflite_model, from_endiness, to_endiness):
  """Generates a new model byte array after byte swapping its buffers field.

  The flatbuffer is parsed into an object, its constant buffers are swapped
  according to their tensor types, and the object is serialized again.

  Args:
    tflite_model: TFLite flatbuffer in a byte array.
    from_endiness: The original endianness format of the buffers in
      tflite_model.
    to_endiness: The target endianness format of the buffers in tflite_model.

  Returns:
    TFLite flatbuffer in a byte array, after being byte swapped to to_endiness
    format, or None if tflite_model is None.
  """
  if tflite_model is not None:
    model_obj = convert_bytearray_to_object(tflite_model)
    byte_swap_tflite_model_obj(model_obj, from_endiness, to_endiness)
    return convert_object_to_bytearray(model_obj)
  return None
450
+
451
+
452
def count_resource_variables(model):
  """Calculates the number of unique resource variables in a model.

  Every VAR_HANDLE operator contributes the `sharedName` of its builtin
  options; VAR_HANDLE ops with the same shared name count once.

  Args:
    model: the input tflite model, either as bytearray or object.

  Returns:
    An integer number representing the number of unique resource variables.
  """
  if not isinstance(model, schema_fb.ModelT):
    model = convert_bytearray_to_object(model)
  shared_names = set()
  for subgraph in model.subgraphs:
    # Subgraphs without operators store None here.
    for op in subgraph.operators or []:
      opcode = model.operatorCodes[op.opcodeIndex]
      is_var_handle = (
          get_builtin_code_from_operator_code(opcode)
          == schema_fb.BuiltinOperator.VAR_HANDLE
      )
      if is_var_handle:
        shared_names.add(op.builtinOptions.sharedName)
  return len(shared_names)
474
+
475
+
476
OptsT = TypeVar('OptsT')


def get_options_as(
    op: Union[schema_fb.Operator, schema_fb.OperatorT], opts_type: Type[OptsT]
) -> Optional[OptsT]:
  """Get the options of an operator as the specified type.

  Requested type must be an object-api type (ends in 'T'). Its name without
  the trailing 'T' is looked up in `BuiltinOptions` first and then in
  `BuiltinOptions2`. Whichever enum has it decides which of the operator's
  two options unions is inspected.

  Args:
    op: The operator to get the options from, either a table-backed
      `Operator` or an object-API `OperatorT`.
    opts_type: The type of the options to get.

  Returns:
    The options as the specified type, or None if the options are not of the
    specified type (or if `op` is neither an `Operator` nor an `OperatorT`).

  Raises:
    ValueError: If the specified type is not a valid options type.
  """
  unsupported = ValueError(f'Unsupported options type: {opts_type}')
  type_name: str = opts_type.__name__
  if not type_name.endswith('T'):
    raise unsupported
  base_type_name = type_name.removesuffix('T')
  in_first_union = hasattr(schema_fb.BuiltinOptions, base_type_name)
  if not (
      in_first_union or hasattr(schema_fb.BuiltinOptions2, base_type_name)
  ):
    raise unsupported

  if isinstance(op, schema_fb.Operator):
    # Table-backed operator: compare the stored union tag with the requested
    # one, then let the generated creator build the options object.
    if in_first_union:
      union_enum = schema_fb.BuiltinOptions
      creator = schema_fb.BuiltinOptionsCreator
      raw_table = op.BuiltinOptions()
      stored_tag = op.BuiltinOptionsType()
    else:
      union_enum = schema_fb.BuiltinOptions2
      creator = schema_fb.BuiltinOptions2Creator
      raw_table = op.BuiltinOptions2()
      stored_tag = op.BuiltinOptions2Type()
    wanted_tag = getattr(union_enum, base_type_name)
    if raw_table is None or stored_tag != wanted_tag:
      return None
    return creator(wanted_tag, raw_table)

  if isinstance(op, schema_fb.OperatorT):
    # Object-API operator: the union member is already an options object.
    unpacked = op.builtinOptions if in_first_union else op.builtinOptions2
    if unpacked is None or not isinstance(unpacked, opts_type):
      return None
    return unpacked

  return None
@@ -0,0 +1,53 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
4
+ # source: google/protobuf/type.proto
5
+ # Protobuf Python Version: 6.31.1
6
+ """Generated protocol buffer code."""
7
+ from google.protobuf import descriptor as _descriptor
8
+ from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
10
+ from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
# Generated by protoc from google/protobuf/type.proto; regenerate it rather
# than editing by hand.
#
# Check the installed protobuf runtime against the version this module was
# generated for (6.31.1).
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    6,
    31,
    1,
    '',
    'google/protobuf/type.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


# Dependencies of type.proto (both are named in the serialized descriptor
# below).
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
from google.protobuf import source_context_pb2 as google_dot_protobuf_dot_source__context__pb2


# Add the serialized descriptor of type.proto (messages Type, Field, Enum,
# EnumValue, Option and enum Syntax) to the default descriptor pool.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1agoogle/protobuf/type.proto\x12\x0fgoogle.protobuf\x1a\x19google/protobuf/any.proto\x1a$google/protobuf/source_context.proto\"\xe8\x01\n\x04Type\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x06\x66ields\x18\x02 \x03(\x0b\x32\x16.google.protobuf.Field\x12\x0e\n\x06oneofs\x18\x03 \x03(\t\x12(\n\x07options\x18\x04 \x03(\x0b\x32\x17.google.protobuf.Option\x12\x36\n\x0esource_context\x18\x05 \x01(\x0b\x32\x1e.google.protobuf.SourceContext\x12\'\n\x06syntax\x18\x06 \x01(\x0e\x32\x17.google.protobuf.Syntax\x12\x0f\n\x07\x65\x64ition\x18\x07 \x01(\t\"\xd5\x05\n\x05\x46ield\x12)\n\x04kind\x18\x01 \x01(\x0e\x32\x1b.google.protobuf.Field.Kind\x12\x37\n\x0b\x63\x61rdinality\x18\x02 \x01(\x0e\x32\".google.protobuf.Field.Cardinality\x12\x0e\n\x06number\x18\x03 \x01(\x05\x12\x0c\n\x04name\x18\x04 \x01(\t\x12\x10\n\x08type_url\x18\x06 \x01(\t\x12\x13\n\x0boneof_index\x18\x07 \x01(\x05\x12\x0e\n\x06packed\x18\x08 \x01(\x08\x12(\n\x07options\x18\t \x03(\x0b\x32\x17.google.protobuf.Option\x12\x11\n\tjson_name\x18\n \x01(\t\x12\x15\n\rdefault_value\x18\x0b \x01(\t\"\xc8\x02\n\x04Kind\x12\x10\n\x0cTYPE_UNKNOWN\x10\x00\x12\x0f\n\x0bTYPE_DOUBLE\x10\x01\x12\x0e\n\nTYPE_FLOAT\x10\x02\x12\x0e\n\nTYPE_INT64\x10\x03\x12\x0f\n\x0bTYPE_UINT64\x10\x04\x12\x0e\n\nTYPE_INT32\x10\x05\x12\x10\n\x0cTYPE_FIXED64\x10\x06\x12\x10\n\x0cTYPE_FIXED32\x10\x07\x12\r\n\tTYPE_BOOL\x10\x08\x12\x0f\n\x0bTYPE_STRING\x10\t\x12\x0e\n\nTYPE_GROUP\x10\n\x12\x10\n\x0cTYPE_MESSAGE\x10\x0b\x12\x0e\n\nTYPE_BYTES\x10\x0c\x12\x0f\n\x0bTYPE_UINT32\x10\r\x12\r\n\tTYPE_ENUM\x10\x0e\x12\x11\n\rTYPE_SFIXED32\x10\x0f\x12\x11\n\rTYPE_SFIXED64\x10\x10\x12\x0f\n\x0bTYPE_SINT32\x10\x11\x12\x0f\n\x0bTYPE_SINT64\x10\x12\"t\n\x0b\x43\x61rdinality\x12\x17\n\x13\x43\x41RDINALITY_UNKNOWN\x10\x00\x12\x18\n\x14\x43\x41RDINALITY_OPTIONAL\x10\x01\x12\x18\n\x14\x43\x41RDINALITY_REQUIRED\x10\x02\x12\x18\n\x14\x43\x41RDINALITY_REPEATED\x10\x03\"\xdf\x01\n\x04\x45num\x12\x0c\n\x04name\x18\x01 \x01(\t\x12-\n\tenumvalue\x18\x02 \x03(\x0b\x32\x1a.google.protobuf.EnumValue\x12(\n\x07options\x18\x03 \x03(\x0b\x32\x17.google.protobuf.Option\x12\x36\n\x0esource_context\x18\x04 \x01(\x0b\x32\x1e.google.protobuf.SourceContext\x12\'\n\x06syntax\x18\x05 \x01(\x0e\x32\x17.google.protobuf.Syntax\x12\x0f\n\x07\x65\x64ition\x18\x06 \x01(\t\"S\n\tEnumValue\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06number\x18\x02 \x01(\x05\x12(\n\x07options\x18\x03 \x03(\x0b\x32\x17.google.protobuf.Option\";\n\x06Option\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.google.protobuf.Any*C\n\x06Syntax\x12\x11\n\rSYNTAX_PROTO2\x10\x00\x12\x11\n\rSYNTAX_PROTO3\x10\x01\x12\x13\n\x0fSYNTAX_EDITIONS\x10\x02\x42{\n\x13\x63om.google.protobufB\tTypeProtoP\x01Z-google.golang.org/protobuf/types/known/typepb\xf8\x01\x01\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')

# Build the descriptor and message/enum objects into this module's
# namespace (`_globals`).
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.type_pb2', _globals)
# Only when C descriptors are not in use: attach file options and the
# `_serialized_start`/`_serialized_end` values (apparently offsets into the
# serialized descriptor) for each message and enum.
if not _descriptor._USE_C_DESCRIPTORS:
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\tTypeProtoP\001Z-google.golang.org/protobuf/types/known/typepb\370\001\001\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
  _globals['_SYNTAX']._serialized_start=1447
  _globals['_SYNTAX']._serialized_end=1514
  _globals['_TYPE']._serialized_start=113
  _globals['_TYPE']._serialized_end=345
  _globals['_FIELD']._serialized_start=348
  _globals['_FIELD']._serialized_end=1073
  _globals['_FIELD_KIND']._serialized_start=627
  _globals['_FIELD_KIND']._serialized_end=955
  _globals['_FIELD_CARDINALITY']._serialized_start=957
  _globals['_FIELD_CARDINALITY']._serialized_end=1073
  _globals['_ENUM']._serialized_start=1076
  _globals['_ENUM']._serialized_end=1299
  _globals['_ENUMVALUE']._serialized_start=1301
  _globals['_ENUMVALUE']._serialized_end=1384
  _globals['_OPTION']._serialized_start=1386
  _globals['_OPTION']._serialized_end=1445
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,53 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
4
+ # source: google/protobuf/wrappers.proto
5
+ # Protobuf Python Version: 6.31.1
6
+ """Generated protocol buffer code."""
7
+ from google.protobuf import descriptor as _descriptor
8
+ from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
10
+ from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
# Generated by protoc from google/protobuf/wrappers.proto; regenerate it
# rather than editing by hand.
#
# Check the installed protobuf runtime against the version this module was
# generated for (6.31.1).
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    6,
    31,
    1,
    '',
    'google/protobuf/wrappers.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




# Add the serialized descriptor of wrappers.proto (DoubleValue, FloatValue,
# Int64Value, UInt64Value, Int32Value, UInt32Value, BoolValue, StringValue,
# BytesValue) to the default descriptor pool.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1egoogle/protobuf/wrappers.proto\x12\x0fgoogle.protobuf\"\x1c\n\x0b\x44oubleValue\x12\r\n\x05value\x18\x01 \x01(\x01\"\x1b\n\nFloatValue\x12\r\n\x05value\x18\x01 \x01(\x02\"\x1b\n\nInt64Value\x12\r\n\x05value\x18\x01 \x01(\x03\"\x1c\n\x0bUInt64Value\x12\r\n\x05value\x18\x01 \x01(\x04\"\x1b\n\nInt32Value\x12\r\n\x05value\x18\x01 \x01(\x05\"\x1c\n\x0bUInt32Value\x12\r\n\x05value\x18\x01 \x01(\r\"\x1a\n\tBoolValue\x12\r\n\x05value\x18\x01 \x01(\x08\"\x1c\n\x0bStringValue\x12\r\n\x05value\x18\x01 \x01(\t\"\x1b\n\nBytesValue\x12\r\n\x05value\x18\x01 \x01(\x0c\x42\x83\x01\n\x13\x63om.google.protobufB\rWrappersProtoP\x01Z1google.golang.org/protobuf/types/known/wrapperspb\xf8\x01\x01\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')

# Build the descriptor and message objects into this module's namespace
# (`_globals`).
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.wrappers_pb2', _globals)
# Only when C descriptors are not in use: attach file options and the
# `_serialized_start`/`_serialized_end` values (apparently offsets into the
# serialized descriptor) for each message.
if not _descriptor._USE_C_DESCRIPTORS:
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\rWrappersProtoP\001Z1google.golang.org/protobuf/types/known/wrapperspb\370\001\001\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
  _globals['_DOUBLEVALUE']._serialized_start=51
  _globals['_DOUBLEVALUE']._serialized_end=79
  _globals['_FLOATVALUE']._serialized_start=81
  _globals['_FLOATVALUE']._serialized_end=108
  _globals['_INT64VALUE']._serialized_start=110
  _globals['_INT64VALUE']._serialized_end=137
  _globals['_UINT64VALUE']._serialized_start=139
  _globals['_UINT64VALUE']._serialized_end=167
  _globals['_INT32VALUE']._serialized_start=169
  _globals['_INT32VALUE']._serialized_end=196
  _globals['_UINT32VALUE']._serialized_start=198
  _globals['_UINT32VALUE']._serialized_end=226
  _globals['_BOOLVALUE']._serialized_start=228
  _globals['_BOOLVALUE']._serialized_end=254
  _globals['_STRINGVALUE']._serialized_start=256
  _globals['_STRINGVALUE']._serialized_end=284
  _globals['_BYTESVALUE']._serialized_start=286
  _globals['_BYTESVALUE']._serialized_end=313
# @@protoc_insertion_point(module_scope)