ai-edge-litert-nightly 2.2.0.dev20260102__cp312-cp312-manylinux_2_27_x86_64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_litert/__init__.py +1 -0
- ai_edge_litert/_pywrap_analyzer_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_compiled_model_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_interpreter_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_tensor_buffer_wrapper.so +0 -0
- ai_edge_litert/_pywrap_modify_model_interface.so +0 -0
- ai_edge_litert/_pywrap_string_util.so +0 -0
- ai_edge_litert/_pywrap_tensorflow_lite_calibration_wrapper.so +0 -0
- ai_edge_litert/_pywrap_tensorflow_lite_metrics_wrapper.so +0 -0
- ai_edge_litert/any_pb2.py +37 -0
- ai_edge_litert/aot/__init__.py +0 -0
- ai_edge_litert/aot/ai_pack/__init__.py +0 -0
- ai_edge_litert/aot/ai_pack/export_lib.py +300 -0
- ai_edge_litert/aot/aot_compile.py +153 -0
- ai_edge_litert/aot/core/__init__.py +0 -0
- ai_edge_litert/aot/core/apply_plugin.py +148 -0
- ai_edge_litert/aot/core/common.py +97 -0
- ai_edge_litert/aot/core/components.py +93 -0
- ai_edge_litert/aot/core/mlir_transforms.py +36 -0
- ai_edge_litert/aot/core/tflxx_util.py +30 -0
- ai_edge_litert/aot/core/types.py +374 -0
- ai_edge_litert/aot/prepare_for_npu.py +152 -0
- ai_edge_litert/aot/vendors/__init__.py +22 -0
- ai_edge_litert/aot/vendors/example/__init__.py +0 -0
- ai_edge_litert/aot/vendors/example/example_backend.py +157 -0
- ai_edge_litert/aot/vendors/fallback_backend.py +128 -0
- ai_edge_litert/aot/vendors/google_tensor/__init__.py +0 -0
- ai_edge_litert/aot/vendors/google_tensor/google_tensor_backend.py +168 -0
- ai_edge_litert/aot/vendors/google_tensor/target.py +84 -0
- ai_edge_litert/aot/vendors/import_vendor.py +132 -0
- ai_edge_litert/aot/vendors/mediatek/__init__.py +0 -0
- ai_edge_litert/aot/vendors/mediatek/mediatek_backend.py +196 -0
- ai_edge_litert/aot/vendors/mediatek/target.py +94 -0
- ai_edge_litert/aot/vendors/qualcomm/__init__.py +0 -0
- ai_edge_litert/aot/vendors/qualcomm/qualcomm_backend.py +161 -0
- ai_edge_litert/aot/vendors/qualcomm/target.py +75 -0
- ai_edge_litert/api_pb2.py +43 -0
- ai_edge_litert/compiled_model.py +250 -0
- ai_edge_litert/descriptor_pb2.py +3361 -0
- ai_edge_litert/duration_pb2.py +37 -0
- ai_edge_litert/empty_pb2.py +37 -0
- ai_edge_litert/field_mask_pb2.py +37 -0
- ai_edge_litert/format_converter_wrapper_pybind11.so +0 -0
- ai_edge_litert/hardware_accelerator.py +22 -0
- ai_edge_litert/internal/__init__.py +0 -0
- ai_edge_litert/internal/litertlm_builder.py +584 -0
- ai_edge_litert/internal/litertlm_core.py +58 -0
- ai_edge_litert/internal/litertlm_header_schema_py_generated.py +1596 -0
- ai_edge_litert/internal/llm_metadata_pb2.py +45 -0
- ai_edge_litert/internal/llm_model_type_pb2.py +51 -0
- ai_edge_litert/internal/sampler_params_pb2.py +39 -0
- ai_edge_litert/internal/token_pb2.py +38 -0
- ai_edge_litert/interpreter.py +1039 -0
- ai_edge_litert/libLiteRt.so +0 -0
- ai_edge_litert/libpywrap_litert_common.so +0 -0
- ai_edge_litert/metrics_interface.py +48 -0
- ai_edge_litert/metrics_portable.py +70 -0
- ai_edge_litert/model_runtime_info_pb2.py +66 -0
- ai_edge_litert/plugin_pb2.py +46 -0
- ai_edge_litert/profiling_info_pb2.py +47 -0
- ai_edge_litert/pywrap_genai_ops.so +0 -0
- ai_edge_litert/schema_py_generated.py +19640 -0
- ai_edge_litert/source_context_pb2.py +37 -0
- ai_edge_litert/struct_pb2.py +47 -0
- ai_edge_litert/tensor_buffer.py +167 -0
- ai_edge_litert/timestamp_pb2.py +37 -0
- ai_edge_litert/tools/__init__.py +0 -0
- ai_edge_litert/tools/apply_plugin_main +0 -0
- ai_edge_litert/tools/flatbuffer_utils.py +534 -0
- ai_edge_litert/type_pb2.py +53 -0
- ai_edge_litert/vendors/google_tensor/compiler/libLiteRtCompilerPlugin_google_tensor.so +0 -0
- ai_edge_litert/vendors/mediatek/compiler/libLiteRtCompilerPlugin_MediaTek.so +0 -0
- ai_edge_litert/vendors/qualcomm/compiler/libLiteRtCompilerPlugin_Qualcomm.so +0 -0
- ai_edge_litert/wrappers_pb2.py +53 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/METADATA +52 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/RECORD +78 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/WHEEL +5 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/top_level.txt +1 -0

ai_edge_litert/__init__.py
@@ -0,0 +1 @@
__version__ = "2.2.0.dev20260102"
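
The package root exposes just this version string. A quick install sanity check (a usage sketch, not part of the wheel):

# Sketch: confirm the installed nightly matches this diff.
import ai_edge_litert

assert ai_edge_litert.__version__ == "2.2.0.dev20260102"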

ai_edge_litert/_pywrap_analyzer_wrapper.so: Binary file
ai_edge_litert/_pywrap_litert_compiled_model_wrapper.so: Binary file
ai_edge_litert/_pywrap_litert_interpreter_wrapper.so: Binary file
ai_edge_litert/_pywrap_litert_tensor_buffer_wrapper.so: Binary file
ai_edge_litert/_pywrap_modify_model_interface.so: Binary file
ai_edge_litert/_pywrap_string_util.so: Binary file
ai_edge_litert/_pywrap_tensorflow_lite_calibration_wrapper.so: Binary file
ai_edge_litert/_pywrap_tensorflow_lite_metrics_wrapper.so: Binary file

ai_edge_litert/any_pb2.py
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: google/protobuf/any.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    6,
    31,
    1,
    '',
    'google/protobuf/any.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19google/protobuf/any.proto\x12\x0fgoogle.protobuf\"&\n\x03\x41ny\x12\x10\n\x08type_url\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x42v\n\x13\x63om.google.protobufB\x08\x41nyProtoP\x01Z,google.golang.org/protobuf/types/known/anypb\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.any_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\010AnyProtoP\001Z,google.golang.org/protobuf/types/known/anypb\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
  _globals['_ANY']._serialized_start=46
  _globals['_ANY']._serialized_end=84
# @@protoc_insertion_point(module_scope)
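
This is stock protoc output for the Any well-known type, vendored under ai_edge_litert so the wheel carries its own gencode pinned to protobuf runtime 6.31.1. A minimal sketch of exercising it (a usage example, not part of the package; assumes the installed protobuf runtime accepts the registration):

# Usage sketch: the _builder calls above materialize the Any message class
# at import time, so the module behaves like google.protobuf.any_pb2.
from ai_edge_litert import any_pb2

msg = any_pb2.Any(type_url="type.googleapis.com/google.protobuf.Empty", value=b"")
print(msg.type_url)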

ai_edge_litert/aot/__init__.py: File without changes
ai_edge_litert/aot/ai_pack/__init__.py: File without changes

ai_edge_litert/aot/ai_pack/export_lib.py
@@ -0,0 +1,300 @@
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for exporting models to AI pack format."""

import itertools
import os
import pathlib
from typing import cast

from ai_edge_litert.aot.core import common
from ai_edge_litert.aot.core import types
from ai_edge_litert.aot.vendors import fallback_backend
from ai_edge_litert.aot.vendors.google_tensor import target as google_tensor_target
from ai_edge_litert.aot.vendors.mediatek import mediatek_backend
from ai_edge_litert.aot.vendors.mediatek import target as mtk_target
from ai_edge_litert.aot.vendors.qualcomm import qualcomm_backend
from ai_edge_litert.aot.vendors.qualcomm import target as qnn_target

# TODO: b/407453529 - Add unittests.


_DEVICE_TARGETING_CONFIGURATION = """<config:device-targeting-config
    xmlns:config="http://schemas.android.com/apk/config">
{device_groups}
</config:device-targeting-config>"""

_DEVICE_GROUP_TEMPLATE = """  <config:device-group name="{device_group_name}">
{device_selectors}
  </config:device-group>"""

_DEVICE_SELECTOR_TEMPLATE = """    <config:device-selector>
      <config:system-on-chip manufacturer="{soc_man}" model="{soc_model}"/>
    </config:device-selector>"""


def _is_mobile_device_backend(backend: types.Backend):
  target = backend.target
  if backend.id() == qualcomm_backend.QualcommBackend.id():
    target = cast(qnn_target.Target, target)
    # Non Android QNN targets.
    if target.soc_model in (
        qnn_target.SocModel.SA8255,
        qnn_target.SocModel.SA8295,
    ):
      return False
  return True


def _export_model_files_to_ai_pack(
    compiled_models: types.CompilationResult,
    ai_pack_dir: pathlib.Path,
    ai_pack_name: str,
    litert_model_name: str,
    *,
    separate_mtk_ai_pack: bool = True,
):
  """Exports the model tflite files to the AI pack directory structure.

  Args:
    compiled_models: The compiled models to export.
    ai_pack_dir: The directory to export the AI pack to.
    ai_pack_name: The name of the AI pack.
    litert_model_name: The name of the model in the litert format.
    separate_mtk_ai_pack: Whether to separate the MTK AI pack. If True, the
      main AI pack will use the fallback model for MTK targets. The MTK AI
      pack will contain all MTK models, and empty directories for non-MTK
      targets.
  """
  fallback_model = None
  for backend, model in compiled_models.models_with_backend:
    if backend.target_id == fallback_backend.FallbackBackend.id():
      fallback_model = model
  assert fallback_model is not None, 'Fallback model is required.'

  model_export_dir = ai_pack_dir / ai_pack_name / 'src/main/assets'
  os.makedirs(model_export_dir, exist_ok=True)
  for backend, model in compiled_models.models_with_backend:
    if not _is_mobile_device_backend(backend):
      continue
    target_id = backend.target_id
    backend_id = backend.id()
    if backend_id == fallback_backend.FallbackBackend.id():
      target_id = 'other'
    elif backend_id == mediatek_backend.MediaTekBackend.id():
      target_id = backend.target_id.replace(
          mtk_target.SocManufacturer.MEDIATEK, 'Mediatek'
      )
    group_name = 'model#group_' + target_id
    export_dir = model_export_dir / group_name
    os.makedirs(export_dir, exist_ok=True)
    model_export_path = export_dir / (litert_model_name + common.DOT_TFLITE)
    if (
        separate_mtk_ai_pack
        and backend_id == mediatek_backend.MediaTekBackend.id()
    ):
      # Use the fallback model for MTK targets in main AI pack.
      model_to_export = fallback_model
    else:
      model_to_export = model
    if not model_to_export.in_memory:
      model_to_export.load()
    model_to_export.save(model_export_path, export_only=True)

  if separate_mtk_ai_pack:
    _export_model_files_to_mtk_ai_pack(
        compiled_models=compiled_models,
        ai_pack_dir=ai_pack_dir,
        ai_pack_name=ai_pack_name + '_mtk',
        litert_model_name=litert_model_name + '_mtk',
    )


def _export_model_files_to_mtk_ai_pack(
    compiled_models: types.CompilationResult,
    ai_pack_dir: pathlib.Path,
    ai_pack_name: str,
    litert_model_name: str,
):
  """Exports the model tflite files to the MTK AI pack directory structure."""
  model_export_dir = ai_pack_dir / ai_pack_name / 'src/main/assets'
  os.makedirs(model_export_dir, exist_ok=True)
  for backend, model in compiled_models.models_with_backend:
    if not _is_mobile_device_backend(backend):
      continue
    backend_id = backend.id()
    target_id = backend.target_id
    if backend_id == fallback_backend.FallbackBackend.id():
      target_id = 'other'
    elif backend_id == mediatek_backend.MediaTekBackend.id():
      target_id = backend.target_id.replace(
          mtk_target.SocManufacturer.MEDIATEK, 'Mediatek'
      )
    group_name = 'model#group_' + target_id
    export_dir = model_export_dir / group_name
    os.makedirs(export_dir, exist_ok=True)
    if backend_id != mediatek_backend.MediaTekBackend.id():
      # Skip non-MTK targets, just create a placeholder file.
      placeholder_file = export_dir / 'placeholder.txt'
      placeholder_file.touch()
      continue
    model_export_path = export_dir / (litert_model_name + common.DOT_TFLITE)
    if not model.in_memory:
      model.load()
    model.save(model_export_path, export_only=True)


def _build_targeting_config(compiled_backends: list[types.Backend]) -> str:
  """Builds device-targeting-config in device_targeting_configuration.xml."""
  device_groups = []
  for backend in compiled_backends:
    if not _is_mobile_device_backend(backend):
      continue
    target = backend.target
    device_group = _target_to_ai_pack_info(target)
    if device_group:
      device_groups.append(device_group)
  device_groups = '\n'.join(device_groups)
  return _DEVICE_TARGETING_CONFIGURATION.format(device_groups=device_groups)


def _target_to_ai_pack_info(target: types.Target) -> str | None:
  """Builds the device group used in device_targeting_configuration.xml."""
  if isinstance(target, qnn_target.Target):
    group_name = str(target)
    selector = _process_qnn_target(target)
    device_selectors = [
        _DEVICE_SELECTOR_TEMPLATE.format(soc_man=man, soc_model=model)
        for man, model in selector
    ]
    device_selectors = '\n'.join(device_selectors)
    device_group = _DEVICE_GROUP_TEMPLATE.format(
        device_group_name=group_name, device_selectors=device_selectors
    )
    return device_group
  elif isinstance(target, mtk_target.Target):
    group_name = str(target).replace(
        mtk_target.SocManufacturer.MEDIATEK, 'Mediatek'
    )
    # TODO: b/407453529 - Support MTK SDK Version / OS version in selector.
    selector = _process_mtk_target(target)
    device_selector = _DEVICE_SELECTOR_TEMPLATE.format(
        soc_man=selector[0], soc_model=selector[1]
    )
    device_group = _DEVICE_GROUP_TEMPLATE.format(
        device_group_name=group_name, device_selectors=device_selector
    )
    return device_group
  elif isinstance(target, google_tensor_target.Target):
    group_name = str(target)
    soc_manufacturer, soc_model = _process_google_tensor_target(target)
    device_selector = _DEVICE_SELECTOR_TEMPLATE.format(
        soc_man=soc_manufacturer, soc_model=soc_model
    )
    device_group = _DEVICE_GROUP_TEMPLATE.format(
        device_group_name=group_name, device_selectors=device_selector
    )
    return device_group
  elif isinstance(target, fallback_backend.FallbackTarget):
    # Don't need to have device selector for fallback target.
    return None
  else:
    print('unsupported target ', target)
    return None


# TODO: b/407453529 - Auto-generate this function from CSVs.
def _process_qnn_target(target: qnn_target.Target) -> list[tuple[str, str]]:
  """Returns the list of (manufacturer, model) for the given QNN target."""
  # Play cannot distinguish between Qualcomm and QTI for now.
  manufacturer = ['Qualcomm', 'QTI']
  models = [str(target.soc_model)]
  return list(itertools.product(manufacturer, models))


# TODO: b/407453529 - Auto-generate this function from CSVs.
def _process_mtk_target(
    target: mtk_target.Target,
) -> tuple[str, str]:
  """Returns tuple of (manufacturer, model) for the given MTK target."""
  # Play cannot distinguish between Qualcomm and QTI for now.
  return str(target.soc_manufacturer).replace(
      mtk_target.SocManufacturer.MEDIATEK, 'Mediatek'
  ), str(target.soc_model)


# TODO: b/407453529 - Auto-generate this function from CSVs.
def _process_google_tensor_target(
    target: google_tensor_target.Target,
) -> tuple[str, str]:
  """Returns tuple of (manufacturer, model) for the given Google Tensor target."""
  return str(target.soc_manufacturer), str(target.soc_model).replace('_', ' ')


def _write_targeting_config(
    compiled_models: types.CompilationResult, ai_pack_dir: pathlib.Path
) -> None:
  """Writes device_targeting_configuration.xml for the given compiled models."""
  compiled_backends = [x for x, _ in compiled_models.models_with_backend]
  targeting_config = _build_targeting_config(
      compiled_backends=compiled_backends
  )

  targeting_config_path = ai_pack_dir / 'device_targeting_configuration.xml'
  targeting_config_path.write_text(targeting_config)


def export(
    compiled_models: types.CompilationResult,
    ai_pack_dir: pathlib.Path | str,
    ai_pack_name: str,
    litert_model_name: str,
) -> None:
  """Exports the compiled models to AI pack format.

  This function will export the compiled models to corresponding directory
  structure:

  {ai_pack_dir}/
    AiPackManifest.xml
    device_targeting_configuration.xml
    {ai_pack_name}/src/main/assets/
      model#group_target_1/
        {litert_model_name}.tflite
      model#group_target_2/
        {litert_model_name}.tflite
      model#group_target_3/
        {litert_model_name}.tflite
      model#group_other/
        {litert_model_name}.tflite

  Args:
    compiled_models: The compiled models to export.
    ai_pack_dir: The directory to export the AI pack to.
    ai_pack_name: The name of the AI pack.
    litert_model_name: The name of the model in the litert format.
  """
  if isinstance(ai_pack_dir, str):
    ai_pack_dir = pathlib.Path(ai_pack_dir)

  ai_pack_dir.mkdir(parents=True, exist_ok=True)

  _export_model_files_to_ai_pack(
      compiled_models=compiled_models,
      ai_pack_dir=ai_pack_dir,
      ai_pack_name=ai_pack_name,
      litert_model_name=litert_model_name,
  )
  _write_targeting_config(
      compiled_models=compiled_models, ai_pack_dir=ai_pack_dir
  )
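
Taken together, export() writes one asset directory per device group plus the Play device-targeting XML. A hypothetical end-to-end call (the CompilationResult, paths, and names below are placeholders, not values shipped in this package):

# Usage sketch for export_lib.export. `compiled` stands in for a
# types.CompilationResult produced elsewhere (e.g. by aot_compile, next file);
# it must include a fallback model, which the exporter asserts on.
import pathlib

from ai_edge_litert.aot.ai_pack import export_lib

compiled = ...  # types.CompilationResult with a fallback model.

export_lib.export(
    compiled_models=compiled,
    ai_pack_dir=pathlib.Path("/tmp/my_ai_pack"),
    ai_pack_name="model_pack",
    litert_model_name="model",
)
# Expected layout: /tmp/my_ai_pack/device_targeting_configuration.xml, a main
# model_pack/ tree with model#group_<target>/model.tflite assets, and a
# model_pack_mtk/ sibling pack (separate_mtk_ai_pack defaults to True).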

ai_edge_litert/aot/aot_compile.py
@@ -0,0 +1,153 @@
# Copyright 2025 The LiteRT Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""AOT Compilation for LiteRT model."""
import pathlib
import tempfile

from ai_edge_litert.aot import prepare_for_npu as core
from ai_edge_litert.aot.core import apply_plugin
from ai_edge_litert.aot.core import components
from ai_edge_litert.aot.core import mlir_transforms
from ai_edge_litert.aot.core import types
from ai_edge_litert.aot.vendors import import_vendor


def aot_compile(
    input_model: types.Model | str,
    output_dir: str | pathlib.Path | None = None,
    target: types.Target | list[types.Target] | None = None,
    config: (
        types.CompilationConfig | list[types.CompilationConfig] | None
    ) = None,
    quantizer: components.AieQuantizerT | None = None,
    keep_going: bool = True,
    subgraphs_to_compile: list[int] | None = None,
    **kwargs,
) -> types.CompilationResult:
  """Prepares a TFLite model for NPU execution.

  High level command that performs various backend-specific pre-processing
  steps and then applies an NPU compiler to the given model.

  Args:
    input_model: The input model to compile.
    output_dir: Directory to write the output files to. If not specified, the
      output files will be written to the same directory as the input file.
    target: The target to compile for. If not specified, will compile to all
      registered targets.
    config: The compilation config(s). Cannot be specified with target.
    quantizer: The quantizer to use for quantization.
    keep_going: Whether to keep going if some backends fail. If False, fail
      fast on the first error and raise an exception.
    subgraphs_to_compile: The subgraph index list to compile to NPU. If None,
      compile all subgraphs.
    **kwargs: Additional arguments to pass to the backend.

  Returns:
    Compiled models.
  """
  # Only one of target or config is needed.
  if target and config:
    raise ValueError("Cannot specify both target and config.")

  if config is None:
    if target is None:
      target = import_vendor.AllRegisteredTarget()
    if isinstance(target, types.Target):
      config = types.CompilationConfig(target=target)
    elif isinstance(target, list):
      config = [types.CompilationConfig(target=t) for t in target]
    else:
      raise ValueError("Unsupported target type.")

  if isinstance(input_model, str):
    input_path = pathlib.Path(input_model)
    input_model = types.Model.create_from_path(input_path)

  # Resolve output paths.
  temp_dir = None
  if not output_dir:
    if input_model.in_memory:
      # Use a temp dir for in-memory models.
      # The temp dir will be cleaned up after the models are compiled and
      # loaded back to memory (i.e. function returns).
      temp_dir = tempfile.TemporaryDirectory()
      output_dir = temp_dir.name
    else:
      input_path = input_model.path
      output_dir = input_path.parent / "_compiled_models"
      output_dir.mkdir(parents=True, exist_ok=True)
      output_dir = str(output_dir)
  output_dir_path = pathlib.Path(output_dir)
  output_dir_path.mkdir(parents=True, exist_ok=True)

  if isinstance(config, types.CompilationConfig) or not config:
    if config:
      # Make pytype happy.
      assert isinstance(config, types.CompilationConfig)
      kw_config = config.to_dict() | kwargs
    else:
      kw_config = kwargs

    backend_class = core.resolve_backend(kw_config)

    quant_recipe = kw_config.get("quantize_recipe", None)
    if quant_recipe:
      assert quantizer is not None, "Quantizer is required for quantization."

    results = core.prepare_for_npu(
        input_model,
        output_dir_path,
        backend_class,
        kw_config,
        transforms=mlir_transforms.MlirTransforms(),
        quantizer=quantizer,
        plugin=apply_plugin.ApplyPlugin(
            experimental_capture_stderr=True,
            subgraphs_to_compile=subgraphs_to_compile,
        ),
        keep_going=keep_going,
    )
  elif isinstance(config, list):
    kw_configs = [c.to_dict() | kwargs for c in config]

    configs_with_backend = [(core.resolve_backend(c), c) for c in kw_configs]
    requires_quantizer = any("quantize_recipe" in c for c in kw_configs)
    if requires_quantizer and quantizer is None:
      raise ValueError("Quantizer is required for quantization.")

    results = core.prepare_for_npu_multiple_configs(
        input_model,
        output_dir_path,
        configs_with_backend,
        transforms=mlir_transforms.MlirTransforms(),
        quantizer=quantizer,
        plugin=apply_plugin.ApplyPlugin(
            experimental_capture_stderr=True,
            subgraphs_to_compile=subgraphs_to_compile,
        ),
        keep_going=keep_going,
    )
  else:
    # Should not reach here.
    raise ValueError("Unsupported config type.")

  if temp_dir:
    # Load the models to memory before cleaning up the temp dir.
    results.load()
    temp_dir.cleanup()

  return results
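
A hypothetical invocation of the entry point above (the model path and output directory are placeholders):

# Usage sketch for aot_compile: with no target or config given it compiles
# for every registered vendor target and, with the default keep_going=True,
# continues past per-backend failures.
from ai_edge_litert.aot import aot_compile as aot_compile_lib

results = aot_compile_lib.aot_compile(
    "model.tflite",
    output_dir="compiled_out",
)
for backend, model in results.models_with_backend:
  print(backend.id(), model.path)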

ai_edge_litert/aot/core/__init__.py: File without changes

ai_edge_litert/aot/core/apply_plugin.py
@@ -0,0 +1,148 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Wrapper for calling the apply plugin tooling."""


import os
import pathlib
import re
import subprocess
import tempfile

from ai_edge_litert.aot.core import common
from ai_edge_litert.aot.core import components
from ai_edge_litert.aot.core import types

_BINARY = pathlib.Path("tools/apply_plugin_main")

_RE_PARTITION_STATS = re.compile(
    r"Partitioned subgraph<(\d+)>, selected (\d+) ops, from a total of "
    r"(\d+) ops. resulted in (\d+) partitions."
)


class ApplyPlugin(components.ApplyPluginT):
  """Wrapper for calling the apply plugin tooling."""

  def __init__(
      self,
      experimental_capture_stderr: bool = False,
      subgraphs_to_compile: list[int] | None = None,
  ):
    self._experimental_capture_stderr = experimental_capture_stderr
    self._subgraphs_to_compile = subgraphs_to_compile

  @property
  def default_err(self) -> str:
    # NOTE: Capture stderr from underlying binary.
    return "--"

  @property
  def component_name(self) -> str:
    return "apply_plugin"

  def __call__(
      self,
      input_model: types.Model,
      output_model: types.Model,
      soc_manufacturer: str,
      soc_model: str,
      sdk_libs_path: str | None = None,
      **kwargs,
  ):
    """Applies a plugin to the input model.

    Args:
      input_model: The path to the input model.
      output_model: The path to the output model.
      soc_manufacturer: The SOC manufacturer of the plugin.
      soc_model: The SOC model of the plugin.
      sdk_libs_path: The path to the SDK libs. If not provided, the default
        SDK path will be used.
      **kwargs: Additional arguments to pass to the underlying binary.

    Returns:
      The output model.

    Raises:
      ValueError: If no tflite model was created by the underlying binary.
    """
    if input_model.in_memory:
      tmp_file = tempfile.NamedTemporaryFile(mode="wb")
      input_model.save(tmp_file.name)
    else:
      tmp_file = None

    binary = common.get_resource(_BINARY)
    args = [
        str(binary),
        "--cmd=apply",
        f"--model={str(input_model.path)}",
        f"--o={str(output_model.path)}",
        f"--soc_manufacturer={soc_manufacturer}",
        f"--soc_model={soc_model}",
        f"--err={self.default_err}",
    ]
    extra_args = [f"--{key}={value}" for key, value in kwargs.items()]
    args.extend(extra_args)
    if self._subgraphs_to_compile:
      subgraphs_to_compile = ",".join(
          str(s) for s in self._subgraphs_to_compile
      )
      args.append(f"--subgraphs={subgraphs_to_compile}")
    env = os.environ.copy()
    ld_library_path = common.construct_ld_library_path()
    if ld_library_path:
      if sdk_libs_path:
        ld_library_path = f"{sdk_libs_path}{os.pathsep}{ld_library_path}"
      env["LD_LIBRARY_PATH"] = ld_library_path

    result = subprocess.run(
        args,
        check=False,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        env=env,
    )
    if result.returncode:
      log_file = tempfile.NamedTemporaryFile(
          suffix=".error", mode="w", delete=False
      )
      log_file.write(result.stdout)
      log_file.close()
      raise ValueError(
          f"{self.component_name} failed to apply plugin. See"
          f" {log_file.name} for details."
      )

    if not common.is_tflite(output_model.path):
      raise ValueError(f"{output_model.path} is not a TFLite model.")

    partition_stats = _RE_PARTITION_STATS.findall(result.stdout)
    output_model.partition_stats = types.PartitionStats(
        subgraph_stats=[
            types.SubgraphPartitionStats(
                subgraph_index=int(s[0]),
                num_ops_offloaded=int(s[1]),
                num_total_ops=int(s[2]),
                num_partitions_offloaded=int(s[3]),
            )
            for s in partition_stats
        ]
    )
    if tmp_file is not None:
      tmp_file.close()