ai-edge-litert-nightly 2.2.0.dev20260102__cp312-cp312-manylinux_2_27_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_litert/__init__.py +1 -0
- ai_edge_litert/_pywrap_analyzer_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_compiled_model_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_interpreter_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_tensor_buffer_wrapper.so +0 -0
- ai_edge_litert/_pywrap_modify_model_interface.so +0 -0
- ai_edge_litert/_pywrap_string_util.so +0 -0
- ai_edge_litert/_pywrap_tensorflow_lite_calibration_wrapper.so +0 -0
- ai_edge_litert/_pywrap_tensorflow_lite_metrics_wrapper.so +0 -0
- ai_edge_litert/any_pb2.py +37 -0
- ai_edge_litert/aot/__init__.py +0 -0
- ai_edge_litert/aot/ai_pack/__init__.py +0 -0
- ai_edge_litert/aot/ai_pack/export_lib.py +300 -0
- ai_edge_litert/aot/aot_compile.py +153 -0
- ai_edge_litert/aot/core/__init__.py +0 -0
- ai_edge_litert/aot/core/apply_plugin.py +148 -0
- ai_edge_litert/aot/core/common.py +97 -0
- ai_edge_litert/aot/core/components.py +93 -0
- ai_edge_litert/aot/core/mlir_transforms.py +36 -0
- ai_edge_litert/aot/core/tflxx_util.py +30 -0
- ai_edge_litert/aot/core/types.py +374 -0
- ai_edge_litert/aot/prepare_for_npu.py +152 -0
- ai_edge_litert/aot/vendors/__init__.py +22 -0
- ai_edge_litert/aot/vendors/example/__init__.py +0 -0
- ai_edge_litert/aot/vendors/example/example_backend.py +157 -0
- ai_edge_litert/aot/vendors/fallback_backend.py +128 -0
- ai_edge_litert/aot/vendors/google_tensor/__init__.py +0 -0
- ai_edge_litert/aot/vendors/google_tensor/google_tensor_backend.py +168 -0
- ai_edge_litert/aot/vendors/google_tensor/target.py +84 -0
- ai_edge_litert/aot/vendors/import_vendor.py +132 -0
- ai_edge_litert/aot/vendors/mediatek/__init__.py +0 -0
- ai_edge_litert/aot/vendors/mediatek/mediatek_backend.py +196 -0
- ai_edge_litert/aot/vendors/mediatek/target.py +94 -0
- ai_edge_litert/aot/vendors/qualcomm/__init__.py +0 -0
- ai_edge_litert/aot/vendors/qualcomm/qualcomm_backend.py +161 -0
- ai_edge_litert/aot/vendors/qualcomm/target.py +75 -0
- ai_edge_litert/api_pb2.py +43 -0
- ai_edge_litert/compiled_model.py +250 -0
- ai_edge_litert/descriptor_pb2.py +3361 -0
- ai_edge_litert/duration_pb2.py +37 -0
- ai_edge_litert/empty_pb2.py +37 -0
- ai_edge_litert/field_mask_pb2.py +37 -0
- ai_edge_litert/format_converter_wrapper_pybind11.so +0 -0
- ai_edge_litert/hardware_accelerator.py +22 -0
- ai_edge_litert/internal/__init__.py +0 -0
- ai_edge_litert/internal/litertlm_builder.py +584 -0
- ai_edge_litert/internal/litertlm_core.py +58 -0
- ai_edge_litert/internal/litertlm_header_schema_py_generated.py +1596 -0
- ai_edge_litert/internal/llm_metadata_pb2.py +45 -0
- ai_edge_litert/internal/llm_model_type_pb2.py +51 -0
- ai_edge_litert/internal/sampler_params_pb2.py +39 -0
- ai_edge_litert/internal/token_pb2.py +38 -0
- ai_edge_litert/interpreter.py +1039 -0
- ai_edge_litert/libLiteRt.so +0 -0
- ai_edge_litert/libpywrap_litert_common.so +0 -0
- ai_edge_litert/metrics_interface.py +48 -0
- ai_edge_litert/metrics_portable.py +70 -0
- ai_edge_litert/model_runtime_info_pb2.py +66 -0
- ai_edge_litert/plugin_pb2.py +46 -0
- ai_edge_litert/profiling_info_pb2.py +47 -0
- ai_edge_litert/pywrap_genai_ops.so +0 -0
- ai_edge_litert/schema_py_generated.py +19640 -0
- ai_edge_litert/source_context_pb2.py +37 -0
- ai_edge_litert/struct_pb2.py +47 -0
- ai_edge_litert/tensor_buffer.py +167 -0
- ai_edge_litert/timestamp_pb2.py +37 -0
- ai_edge_litert/tools/__init__.py +0 -0
- ai_edge_litert/tools/apply_plugin_main +0 -0
- ai_edge_litert/tools/flatbuffer_utils.py +534 -0
- ai_edge_litert/type_pb2.py +53 -0
- ai_edge_litert/vendors/google_tensor/compiler/libLiteRtCompilerPlugin_google_tensor.so +0 -0
- ai_edge_litert/vendors/mediatek/compiler/libLiteRtCompilerPlugin_MediaTek.so +0 -0
- ai_edge_litert/vendors/qualcomm/compiler/libLiteRtCompilerPlugin_Qualcomm.so +0 -0
- ai_edge_litert/wrappers_pb2.py +53 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/METADATA +52 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/RECORD +78 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/WHEEL +5 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Implementations for the main public API functionalities."""
|
|
17
|
+
|
|
18
|
+
import pathlib
|
|
19
|
+
from typing import cast
|
|
20
|
+
|
|
21
|
+
# pylint: disable=g-import-not-at-top
|
|
22
|
+
# pytype: disable=import-error
|
|
23
|
+
try:
|
|
24
|
+
from tqdm import auto as autotqdm
|
|
25
|
+
except ImportError:
|
|
26
|
+
from tqdm.tqdm import auto as autotqdm
|
|
27
|
+
# pytype: enable=import-error
|
|
28
|
+
|
|
29
|
+
from ai_edge_litert.aot.core import common
|
|
30
|
+
from ai_edge_litert.aot.core import components
|
|
31
|
+
from ai_edge_litert.aot.core import types
|
|
32
|
+
from ai_edge_litert.aot.vendors import import_vendor
|
|
33
|
+
|
|
34
|
+
# pylint: enable=g-import-not-at-top
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def resolve_backend(config: types.Config) -> types.BackendT:
  """Resolves the backend named by the "backend_id" entry of `config`.

  Args:
    config: Backend configuration; must carry a non-None "backend_id".

  Returns:
    Whatever `import_vendor.import_vendor` returns for that ID.

  Raises:
    ValueError: If "backend_id" is absent from the config or is None.
  """
  requested_id = config.get("backend_id", None)
  if requested_id is not None:
    return import_vendor.import_vendor(requested_id)
  raise ValueError("Backend ID is required.")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def prepare_for_npu_multiple_configs(
    flatbuffer: types.Model,
    output_dir: pathlib.Path,
    configs: list[tuple[types.BackendT, types.Config]],
    plugin: components.ApplyPluginT,
    transforms: components.MlirTransformsT | None = None,
    quantizer: components.AieQuantizerT | None = None,
    keep_going: bool = False,
) -> types.CompilationResult:
  """Compiles one model for every backend described by `configs`.

  Args:
    flatbuffer: The input model.
    output_dir: Directory handed to `compile_model` for the output models.
    configs: (backend class, config) pairs; each pair is instantiated with
      `create` and expanded with `specialize`.
    plugin: The plugin component; always the last pipeline stage.
    transforms: Optional first pipeline stage.
    quantizer: Optional pipeline stage placed between transforms and plugin.
    keep_going: Forwarded to `compile_model`.

  Returns:
    The `types.CompilationResult` produced by `compile_model`.
  """
  specialized_backends = []
  for backend_class, config in configs:
    specialized_backends.extend(backend_class.create(config).specialize())

  # Stage order is fixed; stages that were not supplied are dropped.
  stages: list[types.Component] = [
      stage for stage in (transforms, quantizer, plugin) if stage is not None
  ]
  return compile_model(
      flatbuffer, output_dir, specialized_backends, stages, keep_going
  )
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def prepare_for_npu(
    flatbuffer: types.Model,
    output_dir: pathlib.Path,
    backend_class: types.BackendT,
    config: types.Config,
    plugin: components.ApplyPluginT,
    transforms: components.MlirTransformsT | None = None,
    quantizer: components.AieQuantizerT | None = None,
    keep_going: bool = False,
) -> types.CompilationResult:
  """Prepares a TFLite model for NPU execution on a single backend.

  High level command that performs the backend specific pre-processing steps
  that were supplied and then applies the NPU compiler plugin to the model.

  Args:
    flatbuffer: The input model.
    output_dir: Directory handed to `compile_model` for the output models.
    backend_class: The backend class to prepare the model for.
    config: The configuration used to create the backend.
    plugin: The plugin to apply to the model; always the last stage.
    transforms: Optional transforms to apply to the model first.
    quantizer: Optional quantizer applied after the transforms.
    keep_going: Whether to keep going if some backends fail.

  Returns:
    The `types.CompilationResult` produced by `compile_model`.

  Raises:
    ValueError: Propagated from `compile_model` when a stage fails to produce
      a TFLite model and `keep_going` is False.
  """
  created = backend_class.create(config)

  # Stage order is fixed; stages that were not supplied are dropped.
  stages: list[types.Component] = [
      stage for stage in (transforms, quantizer, plugin) if stage is not None
  ]
  specialized_backends = list(created.specialize())
  return compile_model(
      flatbuffer, output_dir, specialized_backends, stages, keep_going
  )
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def compile_model(
    flatbuffer: types.Model,
    output_dir: pathlib.Path,
    backends: list[types.Backend],
    pipeline: list[types.Component],
    keep_going: bool = False,
) -> types.CompilationResult:
  """Runs every pipeline component for every backend on the given model.

  For each backend the components run in order, each one consuming the model
  produced by the previous one (the first consumes `flatbuffer`).

  Args:
    flatbuffer: The input model.
    output_dir: Directory under which each component's output path is built.
    backends: Backends to compile for.
    pipeline: Components to run, in order, for each backend.
    keep_going: When True, a ValueError raised while compiling for a backend
      is recorded in the result instead of being re-raised.

  Returns:
    A `types.CompilationResult` holding (backend, final model) pairs and, when
    `keep_going` is set, (backend, error message) pairs for failures.

  Raises:
    ValueError: If a component leaves behind a non in-memory output that is
      not a TFLite file and `keep_going` is False.
  """
  base_name = (
      "model"
      if flatbuffer.in_memory
      else flatbuffer.path.name.removesuffix(common.DOT_TFLITE)
  )
  result = types.CompilationResult()
  with autotqdm.tqdm(backends, desc="Backend") as progress:
    for backend in progress:
      backend = cast(types.Backend, backend)
      current_model = flatbuffer
      name_prefix = base_name + backend.target_id_suffix
      progress.set_description(f"Compiling {backend.target_id}")
      try:
        for component in pipeline:
          component = cast(types.Component, component)
          progress.set_description(
              f"Compiling {backend.target_id}: {component.component_name}"
          )
          output_path = (
              output_dir
              / f"{name_prefix}_{component.component_name}{common.DOT_TFLITE}"
          )
          produced_model = types.Model.create_from_path(output_path)
          backend.call_component(current_model, produced_model, component)
          # An in-memory result is accepted as is; a file result must be a
          # TFLite file.
          if not (
              produced_model.in_memory
              or common.is_tflite(produced_model.path)
          ):
            raise ValueError(
                f"{component.component_name} failed to produce a TFLite model."
            )
          current_model = produced_model
        result.models_with_backend.append((backend, current_model))
      except ValueError as e:
        if not keep_going:
          raise
        print(f"Skipping failed compilation for {backend.target}. Error: {e}")
        result.failed_backends.append((backend, str(e)))

  return result
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Vendor backends for LiteRt."""
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from ai_edge_litert.aot.vendors.mediatek import mediatek_backend as _
|
|
19
|
+
from ai_edge_litert.aot.vendors.qualcomm import qualcomm_backend as _
|
|
20
|
+
|
|
21
|
+
if os.environ.get("GOOGLE_TENSOR_COMPILER_LIB") is not None:
|
|
22
|
+
from ai_edge_litert.aot.vendors.google_tensor import google_tensor_backend as _ # pylint: disable=g-import-not-at-top
|
|
File without changes
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Backend implementation for the example compiler plugin.."""
|
|
17
|
+
|
|
18
|
+
import functools
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from ai_edge_litert.aot.core import components
|
|
22
|
+
from ai_edge_litert.aot.core import types
|
|
23
|
+
from ai_edge_litert.aot.vendors import import_vendor
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ExampleTarget(types.Target):
  """Compilation target for the example backend.

  Identified by an SoC manufacturer name and an SoC model name; both take part
  in hashing, equality and the string representation.
  """

  def __init__(self, soc_manufacturer: str, soc_model: str):
    """Stores the manufacturer and model names that identify this target."""
    self.soc_manufacturer = soc_manufacturer
    self.soc_model = soc_model

  def __hash__(self) -> int:
    return hash((self.soc_manufacturer, self.soc_model))

  def __eq__(self, other) -> bool:
    """Compares manufacturer and model with those of `other`.

    Returns NotImplemented (instead of raising AttributeError) when `other`
    does not expose both attributes, so comparisons against unrelated objects
    fall back to Python's default handling.
    """
    try:
      other_manufacturer = other.soc_manufacturer
      other_model = other.soc_model
    except AttributeError:
      return NotImplemented
    return (
        self.soc_manufacturer == other_manufacturer
        and self.soc_model == other_model
    )

  def __repr__(self) -> str:
    return f"{self.soc_manufacturer}_{self.soc_model}"

  def flatten(self) -> dict[str, Any]:
    """Returns the target's fields as a plain dictionary."""
    return {
        "soc_manufacturer": self.soc_manufacturer,
        "soc_model": self.soc_model,
    }

  @classmethod
  def backend_id(cls) -> str:
    """Returns the ID of the backend this target belongs to."""
    return "example"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Note this is not a real target so not auto-registered unless the module is
|
|
57
|
+
# imported.
|
|
58
|
+
@import_vendor.register_backend
class ExampleBackend(types.Backend):
  """Backend for the example compiler plugin.

  All instances report the same fixed `ExampleTarget`.
  """

  def __init__(self, config: types.Config):
    """Initializes the base backend and keeps the compilation config."""
    super().__init__(config)
    # None when the config has no "compilation_config" entry.
    self._compilation_config = config.get("compilation_config", None)

  @classmethod
  def target_(cls) -> ExampleTarget:
    """Builds the fixed target used by this backend class."""
    return ExampleTarget(
        soc_manufacturer="ExampleSocManufacturer",
        soc_model="ExampleSocModel",
    )

  @property
  def target(self) -> ExampleTarget:
    """The fixed target, exposed on instances."""
    return self.target_()

  @classmethod
  def soc_manufacturer(cls) -> str:
    """Returns the manufacturer name of the fixed target."""
    fixed_target = cls.target_()
    return fixed_target.soc_manufacturer

  @classmethod
  def soc_model(cls) -> str:
    """Returns the model name of the fixed target."""
    fixed_target = cls.target_()
    return fixed_target.soc_model

  @classmethod
  def id(cls) -> str:
    """Returns the backend ID that configs must name to select this class."""
    return "example"

  @property
  def target_id(self) -> str:
    """Always the empty string for the example backend."""
    return ""

  @property
  def shared_pass_names(self) -> list[str]:
    """Pass names handed to the MLIR transforms component."""
    return ["example-pass"]

  @classmethod
  def create(cls, config: types.Config) -> "ExampleBackend":
    """Creates a backend from `config`.

    Raises:
      ValueError: If the config's "backend_id" is not this backend's ID.
    """
    if config.get("backend_id", "") == cls.id():
      return cls(config)
    raise ValueError("Invalid backend id")

  def call_component(
      self,
      input_model: types.Model,
      output_model: types.Model,
      component: types.Component,
  ):
    """Dispatches `component` to the handler registered for its type."""
    return _call_component(component, self, input_model, output_model)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@functools.singledispatch
def _call_component(
    component: types.Component,
    backend: ExampleBackend,
    unused_input_model: types.Model,
    unused_output_model: types.Model,
):
  """Default handler: reached when no handler matches the component's type.

  Raises:
    NotImplementedError: Always.
  """
  message = (
      f"{backend.id()} backend does not support"
      f" {component.component_name} component."
  )
  raise NotImplementedError(message)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@_call_component.register
def _apply_plugin(
    component: components.ApplyPluginT,
    backend: ExampleBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Invokes the apply-plugin component for the example backend.

  Args:
    component: The apply-plugin component to invoke.
    backend: The backend whose SoC manufacturer and model are passed along.
    input_model: Model handed to the component as input.
    output_model: Model handed to the component as output.

  Returns:
    Whatever the component call returns.
  """
  # `soc_manufacturer` and `soc_model` are classmethods on ExampleBackend, so
  # they must be called; passing them uncalled would hand the component bound
  # method objects instead of the names.
  return component(
      input_model,
      output_model,
      backend.soc_manufacturer(),
      backend.soc_model(),
  )
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@_call_component.register
def _aie_quantizer(
    component: components.AieQuantizerT,
    unused_backend: ExampleBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Invokes the quantizer component with just the input and output models."""
  outcome = component(input_model, output_model)
  return outcome
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
@_call_component.register
def _mlir_transforms(
    component: components.MlirTransformsT,
    backend: ExampleBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Invokes the MLIR transforms component with the backend's pass names."""
  pass_names = backend.shared_pass_names
  return component(input_model, output_model, pass_names)
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""A Fallback backend for LITERT."""
|
|
17
|
+
|
|
18
|
+
import functools
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from ai_edge_litert.aot.core import components
|
|
22
|
+
from ai_edge_litert.aot.core import types
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class FallbackTarget(types.Target):
  """A virtual compilation target identified only by its backend ID."""

  def __hash__(self) -> int:
    return hash(self.backend_id())

  def __eq__(self, other: types.Target) -> bool:
    """Compares backend IDs.

    Returns NotImplemented (instead of raising AttributeError) when `other`
    has no callable `backend_id`, so comparisons against unrelated objects
    fall back to Python's default handling.
    """
    other_backend_id = getattr(other, "backend_id", None)
    if not callable(other_backend_id):
      return NotImplemented
    return self.backend_id() == other_backend_id()

  def __repr__(self) -> str:
    return f"{self.backend_id()}"

  @classmethod
  def backend_id(cls) -> str:
    """Returns the ID of the backend this target belongs to."""
    return "fallback"

  def flatten(self) -> dict[str, Any]:
    """Returns the target's fields as a plain dictionary."""
    return {"backend_id": self.backend_id()}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class FallbackBackend(types.Backend):
  """Fallback backend for LITERT, targeting the virtual `FallbackTarget`."""

  @property
  def target(self) -> FallbackTarget:
    """A freshly constructed fallback target."""
    return FallbackTarget()

  @property
  def target_id(self) -> str:
    """The `repr` of the target."""
    fallback_target = self.target
    return repr(fallback_target)

  @classmethod
  def id(cls) -> str:
    """Returns the backend ID that configs must name to select this class."""
    return "fallback"

  @classmethod
  def create(cls, config: types.Config) -> "FallbackBackend":
    """Creates a backend from `config`.

    Raises:
      ValueError: If the config's "backend_id" is not this backend's ID.
    """
    if config.get("backend_id", "") == cls.id():
      return cls(config)
    raise ValueError("Invalid backend id")

  @property
  def quantize_recipe(self) -> str | None:
    """The "quantize_recipe" config entry, or None when it is not set."""
    return self.config.get("quantize_recipe", None)

  def call_component(
      self,
      input_model: types.Model,
      output_model: types.Model,
      component: types.Component,
  ):
    """Dispatches `component` to the handler registered for its type."""
    return _call_component(component, self, input_model, output_model)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@functools.singledispatch
def _call_component(
    component: types.Component,
    backend: FallbackBackend,
    unused_input_model: types.Model,
    unused_output_model: types.Model,
):
  """Default handler: reached when no handler matches the component's type.

  Raises:
    NotImplementedError: Always.
  """
  message = (
      f"{backend.id()} backend does not support"
      f" {component.component_name} component."
  )
  raise NotImplementedError(message)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@_call_component.register
def _apply_plugin(
    component: components.ApplyPluginT,
    backend: FallbackBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """A no-op stage: makes the output model mirror the input model.

  The component and backend are not used. An in-memory input has its bytes
  set on the output; otherwise the output is pointed at the input's path.
  """
  del component, backend
  if not input_model.in_memory:
    output_model.set_path(input_model.path)
    return
  output_model.set_bytes(input_model.model_bytes)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@_call_component.register
def _aie_quantizer(
    component: components.AieQuantizerT,
    backend: FallbackBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Invokes the quantizer component with the backend's quantize recipe."""
  recipe = backend.quantize_recipe
  return component(input_model, output_model, quantization_recipe=recipe)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@_call_component.register
def _mlir_transforms(
    component: components.MlirTransformsT,
    unused_backend: FallbackBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Invokes the MLIR transforms component with an empty pass list."""
  no_passes = []
  return component(input_model, output_model, no_passes)
|
|
File without changes
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Backend implementation for the Google Tensor compiler plugin.."""
|
|
16
|
+
|
|
17
|
+
import copy
|
|
18
|
+
import functools
|
|
19
|
+
import os
|
|
20
|
+
import pathlib
|
|
21
|
+
from typing import Iterable
|
|
22
|
+
|
|
23
|
+
from ai_edge_litert.aot.core import common
|
|
24
|
+
from ai_edge_litert.aot.core import components
|
|
25
|
+
from ai_edge_litert.aot.core import types
|
|
26
|
+
from ai_edge_litert.aot.vendors import import_vendor
|
|
27
|
+
from ai_edge_litert.aot.vendors.google_tensor import target as target_lib
|
|
28
|
+
|
|
29
|
+
# Relative location of the Google Tensor compiler plugin shared library; it is
# resolved through `common.get_resource` when the plugin is applied.
_PLUGIN_LIB_RELATIVE = (
    "vendors/google_tensor/compiler/libLiteRtCompilerPlugin_google_tensor.so"
)
COMPILER_PLUGIN_LIB_PATH = pathlib.Path(_PLUGIN_LIB_RELATIVE)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Returns true if the flag is a google_tensor flag.
|
|
35
|
+
def _is_google_tensor_flag(flag: str) -> bool:
  """Tells whether `flag` carries the "google_tensor_" prefix."""
  prefix = "google_tensor_"
  return flag.startswith(prefix)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@import_vendor.register_backend
class GoogleTensorBackend(types.Backend):
  """Backend for the Google Tensor compiler plugin."""

  def __init__(self, config: types.Config):
    """Initializes the base backend and keeps the compilation config."""
    super().__init__(config)
    # None when the config has no "compilation_config" entry.
    self._compilation_config = config.get("compilation_config", None)

  @property
  def soc_manufacturer(self) -> target_lib.SocManufacturer:
    """Always the GOOGLE manufacturer."""
    return target_lib.SocManufacturer.GOOGLE

  @property
  def soc_model(self) -> target_lib.SocModel:
    """The configured SoC model; "ALL" when the config names none."""
    configured_model = self.config.get("soc_model", "ALL")
    return target_lib.SocModel(configured_model)

  @property
  def target(self) -> target_lib.Target:
    """The target built from this backend's SoC model and manufacturer."""
    return target_lib.Target(self.soc_model, self.soc_manufacturer)

  @property
  def target_id(self) -> str:
    """The `repr` of the target."""
    return repr(self.target)

  def specialize(self) -> Iterable["GoogleTensorBackend"]:
    """Yields backends that each name one concrete SoC model.

    A backend already configured for a concrete model yields itself. A backend
    configured for ALL yields one new backend per other `SocModel` member,
    each created from a deep copy of this backend's config.
    """
    if self.soc_model != target_lib.SocModel.ALL:
      yield self
      return
    for concrete_model in target_lib.SocModel:
      if concrete_model == target_lib.SocModel.ALL:
        continue
      specialized_config = copy.deepcopy(self.config)
      specialized_config["soc_model"] = concrete_model.value
      yield self.create(specialized_config)

  @classmethod
  def id(cls) -> str:
    """Returns the backend ID that configs must name to select this class."""
    return target_lib._GOOGLE_TENSOR_BACKEND_ID  # pylint: disable=protected-access

  @classmethod
  def create(cls, config: types.Config) -> "GoogleTensorBackend":
    """Creates a backend from `config`.

    Raises:
      ValueError: If the config's "backend_id" is not this backend's ID.
    """
    if config.get("backend_id", "") == cls.id():
      return cls(config)
    raise ValueError("Invalid backend id")

  @property
  def quantize_recipe(self) -> str | None:
    """The "quantize_recipe" config entry, or None when it is not set."""
    return self.config.get("quantize_recipe", None)

  def call_component(
      self,
      input_model: types.Model,
      output_model: types.Model,
      component: types.Component,
  ):
    """Dispatches `component` to the handler registered for its type."""
    return _call_component(component, self, input_model, output_model)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@functools.singledispatch
def _call_component(
    component: types.Component,
    backend: GoogleTensorBackend,
    unused_input_model: types.Model,
    unused_output_model: types.Model,
):
  """Default handler: reached when no handler matches the component's type.

  Raises:
    NotImplementedError: Always.
  """
  message = (
      f"{backend.id()} backend does not support"
      f" {component.component_name} component."
  )
  raise NotImplementedError(message)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@_call_component.register
def _apply_plugin(
    component: components.ApplyPluginT,
    backend: GoogleTensorBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Calls the apply plugin component.

  Args:
    component: The apply-plugin component to invoke.
    backend: The backend supplying the SoC manufacturer, the SoC model and the
      config from which "google_tensor_" flags are collected.
    input_model: Model handed to the component as input.
    output_model: Model handed to the component as output.

  Returns:
    Whatever the component call returns.
  """
  # If the plugin is not built from source (i.e. using ai_edge_litert wheel),
  # we find the plugin library directory from the package path.
  # Otherwise we use the default library path.
  try:
    plugin_path = common.get_resource(COMPILER_PLUGIN_LIB_PATH)
  except FileNotFoundError:
    extra_kwargs = {}
  else:
    extra_kwargs = {
        "libs": os.path.dirname(plugin_path),
        "sdk_libs_path": os.environ.get("GOOGLE_TENSOR_COMPILER_LIB", None),
    }

  # Add google_tensor specific flags: top-level config entries first, then the
  # nested "compilation_config" entries, which win on a name clash. The nested
  # entry may be present but None, so treat that like an empty mapping.
  flag_sources = (
      backend.config,
      backend.config.get("compilation_config") or {},
  )
  for flag_source in flag_sources:
    for flag, value in flag_source.items():
      if _is_google_tensor_flag(flag):
        extra_kwargs[flag] = value

  return component(
      input_model,
      output_model,
      backend.soc_manufacturer,
      backend.soc_model,
      **extra_kwargs,
  )
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@_call_component.register
def _aie_quantizer(
    component: components.AieQuantizerT,
    backend: GoogleTensorBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Invokes the quantizer component with the backend's quantize recipe."""
  recipe = backend.quantize_recipe
  return component(input_model, output_model, quantization_recipe=recipe)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@_call_component.register
def _mlir_transforms(
    component: components.MlirTransformsT,
    unused_backend: GoogleTensorBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Invokes the MLIR transforms component with an empty pass list."""
  no_passes = []
  return component(input_model, output_model, no_passes)
|