ai-edge-litert-nightly 1.4.0.dev20250813__cp310-cp310-manylinux_2_27_aarch64.whl → 1.4.0.dev20250815__cp310-cp310-manylinux_2_27_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-litert-nightly might be problematic. Click here for more details.
- ai_edge_litert/__init__.py +1 -1
- ai_edge_litert/libpywrap_litert_common.so +0 -0
- {ai_edge_litert_nightly-1.4.0.dev20250813.dist-info → ai_edge_litert_nightly-1.4.0.dev20250815.dist-info}/METADATA +1 -1
- {ai_edge_litert_nightly-1.4.0.dev20250813.dist-info → ai_edge_litert_nightly-1.4.0.dev20250815.dist-info}/RECORD +6 -31
- ai_edge_litert/aot/__init__.py +0 -0
- ai_edge_litert/aot/ai_pack/__init__.py +0 -0
- ai_edge_litert/aot/ai_pack/export_lib.py +0 -281
- ai_edge_litert/aot/aot_compile.py +0 -152
- ai_edge_litert/aot/core/__init__.py +0 -0
- ai_edge_litert/aot/core/apply_plugin.py +0 -146
- ai_edge_litert/aot/core/common.py +0 -95
- ai_edge_litert/aot/core/components.py +0 -93
- ai_edge_litert/aot/core/mlir_transforms.py +0 -36
- ai_edge_litert/aot/core/tflxx_util.py +0 -30
- ai_edge_litert/aot/core/types.py +0 -374
- ai_edge_litert/aot/prepare_for_npu.py +0 -152
- ai_edge_litert/aot/vendors/__init__.py +0 -18
- ai_edge_litert/aot/vendors/example/__init__.py +0 -0
- ai_edge_litert/aot/vendors/example/example_backend.py +0 -157
- ai_edge_litert/aot/vendors/fallback_backend.py +0 -128
- ai_edge_litert/aot/vendors/import_vendor.py +0 -132
- ai_edge_litert/aot/vendors/mediatek/__init__.py +0 -0
- ai_edge_litert/aot/vendors/mediatek/mediatek_backend.py +0 -196
- ai_edge_litert/aot/vendors/mediatek/target.py +0 -91
- ai_edge_litert/aot/vendors/qualcomm/__init__.py +0 -0
- ai_edge_litert/aot/vendors/qualcomm/qualcomm_backend.py +0 -161
- ai_edge_litert/aot/vendors/qualcomm/target.py +0 -74
- ai_edge_litert/libLiteRtRuntimeCApi.so +0 -0
- ai_edge_litert/tools/apply_plugin_main +0 -0
- {ai_edge_litert_nightly-1.4.0.dev20250813.dist-info → ai_edge_litert_nightly-1.4.0.dev20250815.dist-info}/WHEEL +0 -0
- {ai_edge_litert_nightly-1.4.0.dev20250813.dist-info → ai_edge_litert_nightly-1.4.0.dev20250815.dist-info}/top_level.txt +0 -0
|
@@ -1,152 +0,0 @@
|
|
|
1
|
-
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
"""Implementations for the main public API functionalities."""
|
|
17
|
-
|
|
18
|
-
import pathlib
|
|
19
|
-
from typing import cast
|
|
20
|
-
|
|
21
|
-
# pylint: disable=g-import-not-at-top
|
|
22
|
-
# pytype: disable=import-error
|
|
23
|
-
try:
|
|
24
|
-
from tqdm import auto as autotqdm
|
|
25
|
-
except ImportError:
|
|
26
|
-
from tqdm.tqdm import auto as autotqdm
|
|
27
|
-
# pytype: enable=import-error
|
|
28
|
-
|
|
29
|
-
from ai_edge_litert.aot.core import common
|
|
30
|
-
from ai_edge_litert.aot.core import components
|
|
31
|
-
from ai_edge_litert.aot.core import types
|
|
32
|
-
from ai_edge_litert.aot.vendors import import_vendor
|
|
33
|
-
|
|
34
|
-
# pylint: enable=g-import-not-at-top
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def resolve_backend(config: types.Config) -> types.BackendT:
  """Returns the backend class selected by the config's "backend_id" entry.

  Args:
    config: Configuration mapping; its "backend_id" value names the backend.

  Returns:
    The result of `import_vendor.import_vendor` for the configured ID.

  Raises:
    ValueError: If "backend_id" is missing from the config or is None.
  """
  requested_id = config.get("backend_id", None)
  if requested_id is not None:
    return import_vendor.import_vendor(requested_id)
  raise ValueError("Backend ID is required.")
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
def prepare_for_npu_multiple_configs(
    flatbuffer: types.Model,
    output_dir: pathlib.Path,
    configs: list[tuple[types.BackendT, types.Config]],
    plugin: components.ApplyPluginT,
    transforms: components.MlirTransformsT | None = None,
    quantizer: components.AieQuantizerT | None = None,
    keep_going: bool = False,
) -> types.CompilationResult:
  """Compiles one model for every backend described by `configs`.

  Each (backend class, config) pair is instantiated through `create` and
  expanded with `specialize`; the combined backends are then passed to
  `compile_model` together with the components that were supplied.

  Args:
    flatbuffer: The input model.
    output_dir: Directory forwarded to `compile_model` for the outputs.
    configs: Pairs of backend class and the config used to create it.
    plugin: The plugin component; always the last pipeline stage.
    transforms: Optional transforms component; first stage when given.
    quantizer: Optional quantizer component; runs after the transforms.
    keep_going: Forwarded to `compile_model`.

  Returns:
    The `types.CompilationResult` produced by `compile_model`.
  """
  specialized = []
  for backend_cls, backend_config in configs:
    specialized.extend(backend_cls.create(backend_config).specialize())

  # Stage order is fixed; stages that were not provided are dropped.
  stages = (transforms, quantizer, plugin)
  pipeline: list[types.Component] = [
      stage for stage in stages if stage is not None
  ]
  return compile_model(
      flatbuffer, output_dir, specialized, pipeline, keep_going
  )
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
def prepare_for_npu(
    flatbuffer: types.Model,
    output_dir: pathlib.Path,
    backend_class: types.BackendT,
    config: types.Config,
    plugin: components.ApplyPluginT,
    transforms: components.MlirTransformsT | None = None,
    quantizer: components.AieQuantizerT | None = None,
    keep_going: bool = False,
) -> types.CompilationResult:
  """Prepares a TFLite model for NPU execution.

  High level command that creates a single backend from `config`, expands it
  with `specialize`, and runs the supplied components (transforms, quantizer,
  plugin, in that order) for each specialized backend via `compile_model`.

  Args:
    flatbuffer: The input model.
    output_dir: Directory forwarded to `compile_model` for the outputs.
    backend_class: The backend class to prepare the model for.
    config: The configuration used to create the backend.
    plugin: The plugin component to apply to the model.
    transforms: Optional transforms component to apply to the model.
    quantizer: Optional quantizer component to apply to the model.
    keep_going: Whether to keep going if some backends fail.

  Returns:
    The `types.CompilationResult` produced by `compile_model`.

  Raises:
    ValueError: Propagated from `compile_model` when a pipeline stage fails
      and `keep_going` is False.
  """
  specialized = list(backend_class.create(config).specialize())

  # Stage order is fixed; stages that were not provided are dropped.
  stages = (transforms, quantizer, plugin)
  pipeline: list[types.Component] = [
      stage for stage in stages if stage is not None
  ]
  return compile_model(
      flatbuffer, output_dir, specialized, pipeline, keep_going
  )
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
def compile_model(
    flatbuffer: types.Model,
    output_dir: pathlib.Path,
    backends: list[types.Backend],
    pipeline: list[types.Component],
    keep_going: bool = False,
) -> types.CompilationResult:
  """Runs every pipeline component on the model, once per backend.

  For each backend the components run in order, each one consuming the output
  of the previous one (the first consumes `flatbuffer`). Every stage writes to
  `output_dir / "<base><target_id_suffix>_<component_name><DOT_TFLITE>"`,
  where `<base>` is "model" for an in-memory input and otherwise the input
  file name without its `common.DOT_TFLITE` suffix.

  Args:
    flatbuffer: The input model.
    output_dir: Directory in which the per-stage output paths are created.
    backends: Backends to compile for.
    pipeline: Components to run, in order, for each backend.
    keep_going: If True, a ValueError raised while compiling for a backend is
      printed and recorded in `failed_backends`, and the loop moves on to the
      next backend. If False, the error is re-raised.

  Returns:
    A `types.CompilationResult`; successful (backend, final model) pairs are
    appended to `models_with_backend`.

  Raises:
    ValueError: When a stage output is neither in memory nor a TFLite file, or
      when a component raises ValueError, and `keep_going` is False.
  """
  base_name = (
      "model"
      if flatbuffer.in_memory
      else flatbuffer.path.name.removesuffix(common.DOT_TFLITE)
  )
  result = types.CompilationResult()
  with autotqdm.tqdm(backends, desc="Backend") as progress:
    for backend in progress:
      stage_input = flatbuffer
      backend = cast(types.Backend, backend)
      name_prefix = base_name + backend.target_id_suffix
      progress.set_description(f"Compiling {backend.target_id}")
      try:
        for component in pipeline:
          component = cast(types.Component, component)
          progress.set_description(
              f"Compiling {backend.target_id}: {component.component_name}"
          )
          output_file = (
              f"{name_prefix}_{component.component_name}{common.DOT_TFLITE}"
          )
          stage_output = types.Model.create_from_path(output_dir / output_file)
          backend.call_component(stage_input, stage_output, component)
          # An on-disk result must actually be a TFLite file.
          if not (
              stage_output.in_memory or common.is_tflite(stage_output.path)
          ):
            raise ValueError(
                f"{component.component_name} failed to produce a TFLite model."
            )
          # Chain the stages: this output feeds the next component.
          stage_input = stage_output
        result.models_with_backend.append((backend, stage_input))
      except ValueError as e:
        # Only ValueError is handled here; other exceptions propagate.
        if not keep_going:
          raise
        print(f"Skipping failed compilation for {backend.target}. Error: {e}")
        result.failed_backends.append((backend, str(e)))

  return result
|
|
@@ -1,18 +0,0 @@
|
|
|
1
|
-
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
"""Vendor backends for LiteRt."""
|
|
16
|
-
|
|
17
|
-
from ai_edge_litert.aot.vendors.mediatek import mediatek_backend as _
|
|
18
|
-
from ai_edge_litert.aot.vendors.qualcomm import qualcomm_backend as _
|
|
File without changes
|
|
@@ -1,157 +0,0 @@
|
|
|
1
|
-
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
"""Backend implementation for the example compiler plugin."""
|
|
17
|
-
|
|
18
|
-
import functools
|
|
19
|
-
from typing import Any
|
|
20
|
-
|
|
21
|
-
from ai_edge_litert.aot.core import components
|
|
22
|
-
from ai_edge_litert.aot.core import types
|
|
23
|
-
from ai_edge_litert.aot.vendors import import_vendor
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class ExampleTarget(types.Target):
  """Compilation target for the example backend.

  A target is described by an SoC manufacturer and an SoC model; equality and
  hashing are both derived from that pair.
  """

  def __init__(self, soc_manufacturer: str, soc_model: str):
    self.soc_manufacturer = soc_manufacturer
    self.soc_model = soc_model

  def __hash__(self) -> int:
    return hash((self.soc_manufacturer, self.soc_model))

  def __eq__(self, other: object) -> bool:
    # Defer to the other operand for foreign types instead of raising
    # AttributeError on objects that carry no SoC attributes.
    if not isinstance(other, ExampleTarget):
      return NotImplemented
    return (
        self.soc_manufacturer == other.soc_manufacturer
        and self.soc_model == other.soc_model
    )

  def __repr__(self) -> str:
    return f"{self.soc_manufacturer}_{self.soc_model}"

  def flatten(self) -> dict[str, Any]:
    """Returns the target as a plain dict of its two SoC fields."""
    return {
        "soc_manufacturer": self.soc_manufacturer,
        "soc_model": self.soc_model,
    }

  @classmethod
  def backend_id(cls) -> str:
    """Returns the ID of the backend this target belongs to, "example"."""
    return "example"
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
# Note this is not a real target so not auto-registered unless the module is
|
|
57
|
-
# imported.
|
|
58
|
-
@import_vendor.register_backend
class ExampleBackend(types.Backend):
  """Backend for the example compiler plugin.

  The backend always reports the same fixed `ExampleTarget` and forwards
  component calls to the module-level `_call_component` dispatcher.
  """

  def __init__(self, config: types.Config):
    super().__init__(config)
    # Optional entry; None when the config does not provide one.
    self._compilation_config = config.get("compilation_config", None)

  @classmethod
  def id(cls) -> str:
    """Returns the backend identifier, "example"."""
    return "example"

  @classmethod
  def create(cls, config: types.Config) -> "ExampleBackend":
    """Builds the backend after checking the config's "backend_id".

    Raises:
      ValueError: If the config's "backend_id" differs from `cls.id()`.
    """
    if config.get("backend_id", "") == cls.id():
      return cls(config)
    raise ValueError("Invalid backend id")

  @classmethod
  def target_(cls) -> ExampleTarget:
    """Returns a new instance of the fixed example target."""
    return ExampleTarget(
        soc_manufacturer="ExampleSocManufacturer",
        soc_model="ExampleSocModel",
    )

  @property
  def target(self) -> ExampleTarget:
    """The target of this backend; same value as `target_()`."""
    return self.target_()

  @classmethod
  def soc_manufacturer(cls) -> str:
    """Returns the SoC manufacturer of the fixed example target."""
    example_target = cls.target_()
    return example_target.soc_manufacturer

  @classmethod
  def soc_model(cls) -> str:
    """Returns the SoC model of the fixed example target."""
    example_target = cls.target_()
    return example_target.soc_model

  @property
  def target_id(self) -> str:
    """The target ID string; empty for the example backend."""
    return ""

  @property
  def shared_pass_names(self) -> list[str]:
    """Pass names handed to the transforms component."""
    return ["example-pass"]

  def call_component(
      self,
      input_model: types.Model,
      output_model: types.Model,
      component: types.Component,
  ):
    """Dispatches `component` to the handler registered for its type."""
    return _call_component(component, self, input_model, output_model)
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
@functools.singledispatch
def _call_component(
    component: types.Component,
    backend: ExampleBackend,
    unused_input_model: types.Model,
    unused_output_model: types.Model,
):
  """Default dispatch target for component types without a handler.

  Raises:
    NotImplementedError: Always; the message names the backend ID and the
      component.
  """
  backend_name = backend.id()
  component_name = component.component_name
  raise NotImplementedError(
      f"{backend_name} backend does not support {component_name} component."
  )
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
@_call_component.register
def _apply_plugin(
    component: components.ApplyPluginT,
    backend: ExampleBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Runs the apply-plugin component for the example backend's SoC.

  Args:
    component: The plugin component; called with the two models followed by
      the SoC manufacturer and SoC model.
    backend: The example backend that supplies the SoC identification.
    input_model: Model handed to the component as input.
    output_model: Model handed to the component as output.

  Returns:
    Whatever the component call returns.
  """
  # `soc_manufacturer` and `soc_model` are classmethods on ExampleBackend, so
  # they have to be called; passing the bare attributes would hand
  # bound-method objects to the component instead of the SoC names.
  return component(
      input_model,
      output_model,
      backend.soc_manufacturer(),
      backend.soc_model(),
  )
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
@_call_component.register
def _aie_quantizer(
    component: components.AieQuantizerT,
    unused_backend: ExampleBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Runs the quantizer component; the backend contributes no arguments."""
  quantizer_result = component(input_model, output_model)
  return quantizer_result
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
@_call_component.register
def _mlir_transforms(
    component: components.MlirTransformsT,
    backend: ExampleBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Runs the transforms component with the backend's shared pass names."""
  pass_names = backend.shared_pass_names
  return component(input_model, output_model, pass_names)
|
|
@@ -1,128 +0,0 @@
|
|
|
1
|
-
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
"""A Fallback backend for LITERT."""
|
|
17
|
-
|
|
18
|
-
import functools
|
|
19
|
-
from typing import Any
|
|
20
|
-
|
|
21
|
-
from ai_edge_litert.aot.core import components
|
|
22
|
-
from ai_edge_litert.aot.core import types
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class FallbackTarget(types.Target):
  """A virtual compilation target.

  The target carries no state of its own: hashing, equality and the repr are
  all derived from `backend_id()`.
  """

  def __hash__(self) -> int:
    return hash(self.backend_id())

  def __eq__(self, other: types.Target) -> bool:
    # Defer to the other operand for non-Target objects instead of raising
    # AttributeError on a missing `backend_id`.
    if not isinstance(other, types.Target):
      return NotImplemented
    return self.backend_id() == other.backend_id()

  def __repr__(self) -> str:
    return f"{self.backend_id()}"

  @classmethod
  def backend_id(cls) -> str:
    """Returns the ID of the backend this target belongs to, "fallback"."""
    return "fallback"

  def flatten(self) -> dict[str, Any]:
    """Returns the target as a plain dict holding only the backend ID."""
    return {"backend_id": self.backend_id()}
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
class FallbackBackend(types.Backend):
  """Fallback backend for LITERT.

  Reports a `FallbackTarget` and routes component calls through the
  module-level `_call_component` dispatcher.
  """

  @classmethod
  def id(cls) -> str:
    """Returns the backend identifier, "fallback"."""
    return "fallback"

  @classmethod
  def create(cls, config: types.Config) -> "FallbackBackend":
    """Builds the backend after checking the config's "backend_id".

    Raises:
      ValueError: If the config's "backend_id" differs from `cls.id()`.
    """
    if config.get("backend_id", "") == cls.id():
      return cls(config)
    raise ValueError("Invalid backend id")

  @property
  def target(self) -> FallbackTarget:
    """A new `FallbackTarget` instance."""
    return FallbackTarget()

  @property
  def target_id(self) -> str:
    """The repr of `target`."""
    fallback_target = self.target
    return repr(fallback_target)

  @property
  def quantize_recipe(self) -> str | None:
    """The config's "quantize_recipe" entry, or None when it is absent."""
    return self.config.get("quantize_recipe", None)

  def call_component(
      self,
      input_model: types.Model,
      output_model: types.Model,
      component: types.Component,
  ):
    """Dispatches `component` to the handler registered for its type."""
    return _call_component(component, self, input_model, output_model)
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
@functools.singledispatch
def _call_component(
    component: types.Component,
    backend: FallbackBackend,
    unused_input_model: types.Model,
    unused_output_model: types.Model,
):
  """Default dispatch target for component types without a handler.

  Raises:
    NotImplementedError: Always; the message names the backend ID and the
      component.
  """
  backend_name = backend.id()
  component_name = component.component_name
  raise NotImplementedError(
      f"{backend_name} backend does not support {component_name} component."
  )
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
@_call_component.register
def _apply_plugin(
    component: components.ApplyPluginT,
    backend: FallbackBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Passes the input model through unchanged; no plugin is applied.

  An in-memory input has its bytes set on the output model; otherwise the
  output model is pointed at the input model's path.
  """
  del component, backend  # Neither is needed for a pass-through.
  if not input_model.in_memory:
    output_model.set_path(input_model.path)
    return
  output_model.set_bytes(input_model.model_bytes)
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
@_call_component.register
def _aie_quantizer(
    component: components.AieQuantizerT,
    backend: FallbackBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Runs the quantizer component with the backend's quantize recipe."""
  recipe = backend.quantize_recipe
  return component(input_model, output_model, quantization_recipe=recipe)
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
@_call_component.register
def _mlir_transforms(
    component: components.MlirTransformsT,
    unused_backend: FallbackBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Runs the transforms component with an empty list of pass names."""
  no_pass_names = []
  return component(input_model, output_model, no_pass_names)
|
|
@@ -1,132 +0,0 @@
|
|
|
1
|
-
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
"""Utility for dynamically importing vendor backends based on ID."""
|
|
17
|
-
|
|
18
|
-
import copy
|
|
19
|
-
import dataclasses
|
|
20
|
-
from typing import Any, Iterable
|
|
21
|
-
|
|
22
|
-
from ai_edge_litert.aot.core import types
|
|
23
|
-
from ai_edge_litert.aot.vendors import fallback_backend
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
@dataclasses.dataclass
class VendorRegistry:
  """A vendor registry entry: the backend class stored for one backend ID."""

  # The class passed to `register_backend`; `import_vendor` returns it.
  backend_class: types.BackendT
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
# Maps a backend ID (the value of `backend_class.id()`) to its registry entry.
_VENDOR_REGISTRY: dict[str, VendorRegistry] = {}
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def register_backend(backend_class: types.BackendT):
  """Stores `backend_class` in the vendor registry under its own ID.

  An existing entry with the same ID is overwritten. The class is returned
  unchanged, so the function can be used as a class decorator.

  Args:
    backend_class: The backend class to register; keyed by its `id()`.

  Returns:
    The same `backend_class` that was passed in.
  """
  _VENDOR_REGISTRY[backend_class.id()] = VendorRegistry(backend_class)
  return backend_class
|
|
43
|
-
|
|
44
|
-
# Register the fallback backend explicitly when this module is imported.
register_backend(fallback_backend.FallbackBackend)
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def import_vendor(backend_id: str) -> types.BackendT:
  """Looks up the backend class registered under `backend_id`.

  Args:
    backend_id: The ID of the backend to look up in the vendor registry.

  Returns:
    The backend class stored for that ID.

  Raises:
    ValueError: If no backend is registered under the given ID.
  """
  if backend_id not in _VENDOR_REGISTRY:
    raise ValueError(f'Unsupported backend id: {backend_id}')

  registry_entry = _VENDOR_REGISTRY[backend_id]
  return registry_entry.backend_class
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
class AllRegisteredTarget(types.Target):
  """A virtual compilation target used by the "all" backend.

  Hashing, equality comparison and repr are not supported and raise
  NotImplementedError; only `backend_id` and `flatten` are usable.
  """

  @classmethod
  def backend_id(cls) -> str:
    """Returns the ID of the backend this target belongs to, 'all'."""
    return 'all'

  def flatten(self) -> dict[str, Any]:
    """Returns the target as a plain dict holding only the backend ID."""
    return {'backend_id': self.backend_id()}

  def __hash__(self) -> int:
    raise NotImplementedError

  def __eq__(self, other) -> bool:
    raise NotImplementedError

  def __repr__(self) -> str:
    raise NotImplementedError
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
@register_backend
class AllRegisteredBackend(types.Backend):
  """A virtual backend that represents all registered backends.

  It cannot run components itself; `specialize` expands it into the
  specializations of every other backend found in the vendor registry.
  """

  # NOTE: Only initialize through "create".
  def __init__(self, config: types.Config):
    self._config = config

  @classmethod
  def create(cls, config: types.Config) -> 'AllRegisteredBackend':
    """Builds the backend from `config`; the backend ID is not validated."""
    # Construct through `cls` so that a subclass gets an instance of itself.
    return cls(config)

  @classmethod
  def id(cls) -> str:
    """Returns the backend identifier, 'all'."""
    return 'all'

  @property
  def target(self) -> AllRegisteredTarget:
    """A new `AllRegisteredTarget` instance."""
    return AllRegisteredTarget()

  @property
  def target_id(self) -> str:
    """The target ID string; empty for this virtual backend."""
    return ''

  @property
  def config(self) -> types.Config:
    """The config this backend was created with."""
    return self._config

  def call_component(
      self,
      input_model: types.Model,
      output_model: types.Model,
      component: types.Component,
  ):
    """Always raises; the virtual backend cannot run any component.

    Raises:
      NotImplementedError: On every call.
    """
    del input_model, output_model, component
    raise NotImplementedError(
        'AllRegisteredBackend does not support any component.'
    )

  def specialize(self) -> Iterable[types.Backend]:
    """Yields the specializations of every registered backend except 'all'.

    Each registered backend is created from a deep copy of this backend's
    config in which 'backend_id' is replaced by that backend's own ID.
    """
    for backend_id, vendor_module in _VENDOR_REGISTRY.items():
      # Skip this backend's own registry entry.
      if backend_id == 'all':
        continue
      # Copy first so the shared config is never mutated.
      config = copy.deepcopy(self.config)
      config['backend_id'] = backend_id
      backend = vendor_module.backend_class.create(config)
      yield from backend.specialize()
|
|
File without changes
|