ai-edge-litert-nightly 2.2.0.dev20260102__cp312-cp312-manylinux_2_27_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_litert/__init__.py +1 -0
- ai_edge_litert/_pywrap_analyzer_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_compiled_model_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_interpreter_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_tensor_buffer_wrapper.so +0 -0
- ai_edge_litert/_pywrap_modify_model_interface.so +0 -0
- ai_edge_litert/_pywrap_string_util.so +0 -0
- ai_edge_litert/_pywrap_tensorflow_lite_calibration_wrapper.so +0 -0
- ai_edge_litert/_pywrap_tensorflow_lite_metrics_wrapper.so +0 -0
- ai_edge_litert/any_pb2.py +37 -0
- ai_edge_litert/aot/__init__.py +0 -0
- ai_edge_litert/aot/ai_pack/__init__.py +0 -0
- ai_edge_litert/aot/ai_pack/export_lib.py +300 -0
- ai_edge_litert/aot/aot_compile.py +153 -0
- ai_edge_litert/aot/core/__init__.py +0 -0
- ai_edge_litert/aot/core/apply_plugin.py +148 -0
- ai_edge_litert/aot/core/common.py +97 -0
- ai_edge_litert/aot/core/components.py +93 -0
- ai_edge_litert/aot/core/mlir_transforms.py +36 -0
- ai_edge_litert/aot/core/tflxx_util.py +30 -0
- ai_edge_litert/aot/core/types.py +374 -0
- ai_edge_litert/aot/prepare_for_npu.py +152 -0
- ai_edge_litert/aot/vendors/__init__.py +22 -0
- ai_edge_litert/aot/vendors/example/__init__.py +0 -0
- ai_edge_litert/aot/vendors/example/example_backend.py +157 -0
- ai_edge_litert/aot/vendors/fallback_backend.py +128 -0
- ai_edge_litert/aot/vendors/google_tensor/__init__.py +0 -0
- ai_edge_litert/aot/vendors/google_tensor/google_tensor_backend.py +168 -0
- ai_edge_litert/aot/vendors/google_tensor/target.py +84 -0
- ai_edge_litert/aot/vendors/import_vendor.py +132 -0
- ai_edge_litert/aot/vendors/mediatek/__init__.py +0 -0
- ai_edge_litert/aot/vendors/mediatek/mediatek_backend.py +196 -0
- ai_edge_litert/aot/vendors/mediatek/target.py +94 -0
- ai_edge_litert/aot/vendors/qualcomm/__init__.py +0 -0
- ai_edge_litert/aot/vendors/qualcomm/qualcomm_backend.py +161 -0
- ai_edge_litert/aot/vendors/qualcomm/target.py +75 -0
- ai_edge_litert/api_pb2.py +43 -0
- ai_edge_litert/compiled_model.py +250 -0
- ai_edge_litert/descriptor_pb2.py +3361 -0
- ai_edge_litert/duration_pb2.py +37 -0
- ai_edge_litert/empty_pb2.py +37 -0
- ai_edge_litert/field_mask_pb2.py +37 -0
- ai_edge_litert/format_converter_wrapper_pybind11.so +0 -0
- ai_edge_litert/hardware_accelerator.py +22 -0
- ai_edge_litert/internal/__init__.py +0 -0
- ai_edge_litert/internal/litertlm_builder.py +584 -0
- ai_edge_litert/internal/litertlm_core.py +58 -0
- ai_edge_litert/internal/litertlm_header_schema_py_generated.py +1596 -0
- ai_edge_litert/internal/llm_metadata_pb2.py +45 -0
- ai_edge_litert/internal/llm_model_type_pb2.py +51 -0
- ai_edge_litert/internal/sampler_params_pb2.py +39 -0
- ai_edge_litert/internal/token_pb2.py +38 -0
- ai_edge_litert/interpreter.py +1039 -0
- ai_edge_litert/libLiteRt.so +0 -0
- ai_edge_litert/libpywrap_litert_common.so +0 -0
- ai_edge_litert/metrics_interface.py +48 -0
- ai_edge_litert/metrics_portable.py +70 -0
- ai_edge_litert/model_runtime_info_pb2.py +66 -0
- ai_edge_litert/plugin_pb2.py +46 -0
- ai_edge_litert/profiling_info_pb2.py +47 -0
- ai_edge_litert/pywrap_genai_ops.so +0 -0
- ai_edge_litert/schema_py_generated.py +19640 -0
- ai_edge_litert/source_context_pb2.py +37 -0
- ai_edge_litert/struct_pb2.py +47 -0
- ai_edge_litert/tensor_buffer.py +167 -0
- ai_edge_litert/timestamp_pb2.py +37 -0
- ai_edge_litert/tools/__init__.py +0 -0
- ai_edge_litert/tools/apply_plugin_main +0 -0
- ai_edge_litert/tools/flatbuffer_utils.py +534 -0
- ai_edge_litert/type_pb2.py +53 -0
- ai_edge_litert/vendors/google_tensor/compiler/libLiteRtCompilerPlugin_google_tensor.so +0 -0
- ai_edge_litert/vendors/mediatek/compiler/libLiteRtCompilerPlugin_MediaTek.so +0 -0
- ai_edge_litert/vendors/qualcomm/compiler/libLiteRtCompilerPlugin_Qualcomm.so +0 -0
- ai_edge_litert/wrappers_pb2.py +53 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/METADATA +52 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/RECORD +78 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/WHEEL +5 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Constants and other small generic utilities."""
|
|
17
|
+
|
|
18
|
+
from importlib import resources
|
|
19
|
+
import os
|
|
20
|
+
import pathlib
|
|
21
|
+
|
|
22
|
+
TFLITE = "tflite"
DOT_TFLITE = f".{TFLITE}"
NPU = "npu"


# Package name tried first when resolving resources (see get_resource).
_WORKSPACE_PREFIX = "litert"
# Fallback package name used by the ai_edge_litert python package.
_AI_EDGE_LITERT_PREFIX = "ai_edge_litert"
_LITERT_ROOT = ""
_PYTHON_ROOT = "python/aot"

# Dotted module path of the AOT python package. Empty components (such as the
# empty _LITERT_ROOT above) are skipped so the result never contains "..",
# which would not be a valid dotted module path.
MODULE_ROOT = ".".join(
    part
    for part in (
        _WORKSPACE_PREFIX,
        _LITERT_ROOT.replace("/", "."),
        _PYTHON_ROOT.replace("/", "."),
    )
    if part
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_resource(
    litert_relative_path: pathlib.Path, is_dir=False
) -> pathlib.Path:
  """Resolves the location of a resource shipped inside the LiteRT package.

  The `litert` package is looked up first. If it cannot be found, the
  `ai_edge_litert` package is used instead.

  Args:
    litert_relative_path: Location of the resource relative to the package
      root.
    is_dir: If True, no existence check is performed and the location is
      returned as-is.

  Returns:
    The resolved resource location as a `pathlib.Path`.

  Raises:
    FileNotFoundError: If `is_dir` is False and the resource is not a file.
    ModuleNotFoundError: If neither package can be found.
  """
  try:
    package_root = resources.files(_WORKSPACE_PREFIX)
  except ModuleNotFoundError:
    package_root = resources.files(_AI_EDGE_LITERT_PREFIX)

  candidate = package_root.joinpath(_LITERT_ROOT, str(litert_relative_path))
  must_be_file = not is_dir
  if must_be_file and not candidate.is_file():
    raise FileNotFoundError(f"Resource {candidate} does not exist.")
  return pathlib.Path(str(candidate))
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def is_tflite(path: pathlib.Path | str) -> bool:
  """Returns True if `path` is an existing regular file with a .tflite suffix.

  Args:
    path: The path to check. Strings are converted to `pathlib.Path`.

  Returns:
    Whether `path` names an existing file whose suffix is `DOT_TFLITE`.
  """
  path = pathlib.Path(path)
  # `is_file()` already returns False for paths that do not exist, so no
  # separate `exists()` check is needed. The cheap suffix test runs first.
  return path.suffix == DOT_TFLITE and path.is_file()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def construct_ld_library_path() -> str:
  """Constructs a string suitable for the LD_LIBRARY_PATH environment variable.

  This function is used in the ai_edge_litert python package, when the shared
  libraries are not in a static location. The result lists the ai_edge_litert
  package directory and all of its subdirectories (absolute, sorted), followed
  by every entry of the current LD_LIBRARY_PATH that is not already listed.

  If the module is built from source, this function returns an empty string.

  Returns:
    A string suitable for the LD_LIBRARY_PATH environment variable.
  """
  try:
    resource_root = resources.files(_AI_EDGE_LITERT_PREFIX)
  except ModuleNotFoundError:
    # Built from source case.
    return ""
  root_package_path = str(resource_root)

  # Include the root explicitly in addition to everything os.walk yields.
  library_paths = {os.path.abspath(root_package_path)}
  for dirpath, _, _ in os.walk(root_package_path):
    library_paths.add(os.path.abspath(dirpath))
  entries = sorted(library_paths)

  current_ld_library_path = os.environ.get("LD_LIBRARY_PATH")
  if current_ld_library_path:
    # Compare whole entries rather than substrings. A substring test would
    # drop the user's existing path whenever it happened to occur inside a
    # package path (e.g. "/usr" vs "/usr/lib/.../ai_edge_litert"). It would
    # also fail to recognise package paths prepended by an earlier call,
    # duplicating them every time. Existing entries keep their order, and
    # empty entries are preserved.
    entries.extend(
        entry
        for entry in current_ld_library_path.split(os.pathsep)
        if entry not in library_paths
    )
  return os.pathsep.join(entries)
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Interfaces for specific components used in the LiteRt AOT flow."""
|
|
17
|
+
|
|
18
|
+
import abc
|
|
19
|
+
import sys
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
from ai_edge_litert.aot.core import types
|
|
23
|
+
|
|
24
|
+
# pylint: disable=g-importing-member
|
|
25
|
+
# pylint: disable=g-import-not-at-top
|
|
26
|
+
# pylint: disable=g-bad-import-order
|
|
27
|
+
if sys.version_info < (3, 10):
|
|
28
|
+
from typing_extensions import TypeAlias
|
|
29
|
+
else:
|
|
30
|
+
from typing import TypeAlias
|
|
31
|
+
# pylint: enable=g-bad-import-order
|
|
32
|
+
# pylint: enable=g-import-not-at-top
|
|
33
|
+
# pylint: enable=g-importing-member
|
|
34
|
+
|
|
35
|
+
# Quantization recipe accepted by `AieQuantizerT.__call__`: either a list of
# string-keyed dicts (presumably one rule per dict — TODO confirm) or a str.
QuantRecipe: TypeAlias = list[dict[str, Any]] | str
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class AieQuantizerT(metaclass=abc.ABCMeta):
  """Abstract interface that AIE quantizer components implement.

  A quantizer is invoked with an input model, an output model and an optional
  quantization recipe.
  """

  @property
  def component_name(self) -> str:
    """Name identifying this kind of component."""
    return "aie_quantizer"

  @abc.abstractmethod
  def __call__(
      self,
      input_model: types.Model,
      output_model: types.Model,
      quantization_recipe: QuantRecipe | None = None,
      *args,
      **kwargs,
  ):
    """Runs the quantizer; concrete subclasses must override this.

    Args:
      input_model: The model to quantize.
      output_model: The output model.
      quantization_recipe: Optional recipe controlling the quantization.
      *args: Implementation specific positional arguments.
      **kwargs: Implementation specific keyword arguments.
    """
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class ApplyPluginT(metaclass=abc.ABCMeta):
  """Abstract interface that apply-plugin components implement."""

  @property
  def default_err(self) -> str:
    """Default error-stream setting used by this component."""
    # NOTE: Capture stderr from underlying binary.
    return "none"

  @property
  def component_name(self) -> str:
    """Name identifying this kind of component."""
    return "apply_plugin"

  @abc.abstractmethod
  def __call__(
      self,
      input_model: types.Model,
      output_model: types.Model,
      soc_manufacturer: str,
      soc_model: str,
      *args,
      **kwargs,
  ):
    """Applies the plugin for the given SoC; subclasses must override this.

    Args:
      input_model: The model to process.
      output_model: The output model.
      soc_manufacturer: SoC manufacturer name.
      soc_model: SoC model name.
      *args: Implementation specific positional arguments.
      **kwargs: Implementation specific keyword arguments.
    """
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class MlirTransformsT(metaclass=abc.ABCMeta):
  """Abstract interface that MLIR-transform components implement."""

  @property
  def component_name(self) -> str:
    """Name identifying this kind of component."""
    return "mlir_transforms"

  @abc.abstractmethod
  def __call__(
      self,
      input_model: types.Model,
      output_model: types.Model,
      *args,
      **kwargs,
  ):
    """Runs the transforms; concrete subclasses must override this.

    Args:
      input_model: The model to transform.
      output_model: The output model.
      *args: Implementation specific positional arguments.
      **kwargs: Implementation specific keyword arguments.
    """
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Wrapper for suite of MLIR passes."""
|
|
17
|
+
|
|
18
|
+
from ai_edge_litert.aot.core import components
|
|
19
|
+
from ai_edge_litert.aot.core import tflxx_util
|
|
20
|
+
from ai_edge_litert.aot.core import types
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MlirTransforms(components.MlirTransformsT):
  """Runs a named MLIR pass over a model through the TFLXX shim."""

  def __call__(
      self,
      input_model: types.Model,
      output_model: types.Model,
      pass_name: str,
  ):
    """Applies `pass_name` to `input_model` and stores bytes in `output_model`.

    Note: if `input_model` is on disk, it is loaded into memory as a side
    effect of this call.

    Args:
      input_model: The model to transform.
      output_model: Receives the transformed model bytes.
      pass_name: Name of the pass handed to `tflxx_util.call_tflxx`.
    """
    already_loaded = input_model.in_memory
    if not already_loaded:
      input_model.load()
    output_model.set_bytes(
        tflxx_util.call_tflxx(input_model.model_bytes, pass_name)
    )
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
# pylint: disable=g-import-not-at-top
|
|
17
|
+
# pytype: disable=import-error
|
|
18
|
+
# pytype: disable=not-callable
|
|
19
|
+
|
|
20
|
+
"""Shim layer for TFLXX while it is in experimental."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
import importlib.util
|
|
24
|
+
from typing import Callable
|
|
25
|
+
|
|
26
|
+
def _passthrough_tflxx(input: bytes, pass_name: str) -> bytes:  # pylint: disable=redefined-builtin
  """Placeholder pass runner: returns `input` unchanged, ignoring `pass_name`."""
  del pass_name  # Unused by the placeholder.
  return input


# Entry point used to run a TFLXX pass on serialized model bytes. It is bound
# to the pass-through placeholder above.
call_tflxx: Callable[[bytes, str], bytes] = _passthrough_tflxx
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def tflxx_enabled() -> bool:
  """Reports whether TFLXX support is active.

  Only the pass-through shim is present here, so this is always False.
  """
  enabled = False
  return enabled
|
|
@@ -0,0 +1,374 @@
|
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Basic types used in the LiteRt AOT flow."""
|
|
17
|
+
|
|
18
|
+
import abc
|
|
19
|
+
from collections.abc import Iterable
|
|
20
|
+
import dataclasses
|
|
21
|
+
import pathlib
|
|
22
|
+
import sys
|
|
23
|
+
from typing import Any, MutableMapping, Protocol, Type
|
|
24
|
+
|
|
25
|
+
# pylint: disable=g-importing-member
|
|
26
|
+
# pylint: disable=g-import-not-at-top
|
|
27
|
+
# pylint: disable=g-bad-import-order
|
|
28
|
+
if sys.version_info < (3, 10):
|
|
29
|
+
from typing_extensions import TypeAlias
|
|
30
|
+
else:
|
|
31
|
+
from typing import TypeAlias
|
|
32
|
+
# pylint: enable=g-bad-import-order
|
|
33
|
+
# pylint: enable=g-import-not-at-top
|
|
34
|
+
# pylint: enable=g-importing-member
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclasses.dataclass(frozen=True)
class SubgraphPartitionStats:
  """Immutable partitioning statistics for a single subgraph."""

  # Index of the subgraph these numbers describe.
  subgraph_index: int
  # Number of ops that were offloaded.
  num_ops_offloaded: int
  # Total number of ops in the subgraph.
  num_total_ops: int
  # Number of partitions the offloaded ops were grouped into.
  num_partitions_offloaded: int

  def __str__(self) -> str:
    # "fully" only when every op of the subgraph was offloaded.
    if self.num_ops_offloaded == self.num_total_ops:
      coverage = 'fully'
    else:
      coverage = 'partially'
    return (
        f'Subgraph {self.subgraph_index} {coverage} compiled:\t'
        f'{self.num_ops_offloaded} / {self.num_total_ops} ops offloaded to '
        f'{self.num_partitions_offloaded} partitions.'
    )
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclasses.dataclass(frozen=True)
class PartitionStats:
  """Partitioning statistics for a whole model, as a list of per-subgraph stats."""

  subgraph_stats: list[SubgraphPartitionStats]

  def __str__(self) -> str:
    # One line per subgraph entry, in stored order.
    return '\n'.join(map(str, self.subgraph_stats))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class Model:
|
|
68
|
+
"""A model.
|
|
69
|
+
|
|
70
|
+
Note: If the model is not in memory, data_ will be a path to a file on disk.
|
|
71
|
+
If the model is in memory, data_ will be the model bytes.
|
|
72
|
+
|
|
73
|
+
However, there's no guarantee that the path will be a valid path to a file
|
|
74
|
+
on disk, and/or that the file are a valid TFLite model.
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
data_: pathlib.Path | bytes
|
|
78
|
+
partition_stats: PartitionStats | None = None
|
|
79
|
+
|
|
80
|
+
def __init__(
|
|
81
|
+
self,
|
|
82
|
+
path: pathlib.Path | str | None = None,
|
|
83
|
+
model_bytes: bytes | None = None,
|
|
84
|
+
):
|
|
85
|
+
if path is not None:
|
|
86
|
+
if isinstance(path, str):
|
|
87
|
+
path = pathlib.Path(path)
|
|
88
|
+
if model_bytes:
|
|
89
|
+
raise ValueError('Cannot specify both path and model_bytes.')
|
|
90
|
+
self.data_ = path
|
|
91
|
+
else:
|
|
92
|
+
if model_bytes is None:
|
|
93
|
+
raise ValueError('Cannot specify neither path nor model_bytes.')
|
|
94
|
+
self.data_ = model_bytes
|
|
95
|
+
|
|
96
|
+
@property
|
|
97
|
+
def in_memory(self) -> bool:
|
|
98
|
+
return isinstance(self.data_, bytes)
|
|
99
|
+
|
|
100
|
+
@property
|
|
101
|
+
def path(self) -> pathlib.Path:
|
|
102
|
+
if not isinstance(self.data_, pathlib.Path):
|
|
103
|
+
raise ValueError('Model is not on disk.')
|
|
104
|
+
return self.data_
|
|
105
|
+
|
|
106
|
+
@property
|
|
107
|
+
def model_bytes(self) -> bytes:
|
|
108
|
+
if not isinstance(self.data_, bytes):
|
|
109
|
+
raise ValueError('Model is not in memory.')
|
|
110
|
+
return self.data_
|
|
111
|
+
|
|
112
|
+
@classmethod
|
|
113
|
+
def create_from_path(cls, path: pathlib.Path) -> 'Model':
|
|
114
|
+
return Model(path=path, model_bytes=None)
|
|
115
|
+
|
|
116
|
+
@classmethod
|
|
117
|
+
def create_from_bytes(cls, model_bytes: bytes) -> 'Model':
|
|
118
|
+
return Model(path=None, model_bytes=model_bytes)
|
|
119
|
+
|
|
120
|
+
def set_path(self, path: pathlib.Path | str):
|
|
121
|
+
if isinstance(path, str):
|
|
122
|
+
path = pathlib.Path(path)
|
|
123
|
+
self.data_ = path
|
|
124
|
+
|
|
125
|
+
def set_bytes(self, model_bytes: bytes):
|
|
126
|
+
self.data_ = model_bytes
|
|
127
|
+
|
|
128
|
+
def load(self):
|
|
129
|
+
"""Loads the model from the given path.
|
|
130
|
+
|
|
131
|
+
Raises:
|
|
132
|
+
ValueError: If the model is already in memory.
|
|
133
|
+
"""
|
|
134
|
+
if not isinstance(self.data_, pathlib.Path):
|
|
135
|
+
raise ValueError('Cannot load a model that is already in memory.')
|
|
136
|
+
self.data_ = self.data_.read_bytes()
|
|
137
|
+
|
|
138
|
+
def save(self, path: pathlib.Path | str, export_only: bool = False):
|
|
139
|
+
"""Saves the model to the given path from the in-memory model content.
|
|
140
|
+
|
|
141
|
+
If export_only is True, the model will be copied to the given path without
|
|
142
|
+
modifying the internal state, regardless of whether the model is already on
|
|
143
|
+
disk or in memory.
|
|
144
|
+
|
|
145
|
+
Args:
|
|
146
|
+
path: The path to save the model to.
|
|
147
|
+
export_only: Whether to only export the model without modifying the
|
|
148
|
+
internal stat (i.e. transfer the in-memory model to disk).
|
|
149
|
+
|
|
150
|
+
Raises:
|
|
151
|
+
ValueError: If export_only is False and the model is not in memory.
|
|
152
|
+
"""
|
|
153
|
+
if isinstance(path, str):
|
|
154
|
+
path = pathlib.Path(path)
|
|
155
|
+
if isinstance(self.data_, pathlib.Path):
|
|
156
|
+
if not export_only:
|
|
157
|
+
raise ValueError(
|
|
158
|
+
'Cannot save a model that is not in memory. Use export_only=True'
|
|
159
|
+
' for copying the model to a new path.'
|
|
160
|
+
)
|
|
161
|
+
with open(self.data_, 'rb') as f:
|
|
162
|
+
model_content = f.read()
|
|
163
|
+
else:
|
|
164
|
+
model_content = self.data_
|
|
165
|
+
path.write_bytes(model_content)
|
|
166
|
+
if not export_only:
|
|
167
|
+
self.data_ = path
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@dataclasses.dataclass
class CompilationResult:
  """Outcome of a compilation: the compiled models plus the failed backends."""

  # (backend, compiled model) pairs for each successful backend.
  models_with_backend: list[tuple['Backend', Model]] = dataclasses.field(
      default_factory=list
  )
  # (backend, error message) pairs for each failed backend.
  failed_backends: list[tuple['Backend', str]] = dataclasses.field(
      default_factory=list
  )

  @property
  def models(self) -> list[Model]:
    """The compiled models, in the order of `models_with_backend`."""
    return [pair[1] for pair in self.models_with_backend]

  def load(self):
    """Loads every compiled model that is not yet in memory."""
    pending = (m for _, m in self.models_with_backend if not m.in_memory)
    for model in pending:
      model.load()

  def export(self, output_dir: pathlib.Path | str, model_name: str = 'model'):
    """Writes each compiled model into `output_dir` without changing its state.

    The directory is created if needed. Each file is named
    `model_name + backend.target_id_suffix + '.tflite'`.
    """
    target_dir = (
        pathlib.Path(output_dir) if isinstance(output_dir, str) else output_dir
    )
    target_dir.mkdir(parents=True, exist_ok=True)
    for backend, model in self.models_with_backend:
      file_name = model_name + backend.target_id_suffix + '.tflite'
      model.save(target_dir / file_name, export_only=True)

  def compilation_report(self) -> str:
    """Returns a human readable compilation report."""
    success_lines = []
    for backend, model in self.models_with_backend:
      success_lines.extend([
          f'{backend.target_id}',
          '==========================',
          f'Partition Stats:\n{model.partition_stats}\n',
      ])

    failure_lines = []
    if self.failed_backends:
      failure_lines = [
          '==========================',
          'COMPILATION FAILURES:',
          '==========================',
      ]
      failure_lines.extend(
          f'{backend.target_id}\t{error}'
          for backend, error in self.failed_backends
      )

    return '\n'.join(['\n'.join(success_lines), '\n'.join(failure_lines)])
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
class Component(Protocol):
  """Structural type for an AOT-flow stage that takes and produces a Model.

  Examples: a quantizer, a graph rewriter, a compiler plugin.
  """

  @property
  def component_name(self) -> str:
    """Name identifying the component."""

  def __call__(self, input_model: Model, output_model: Model, *args, **kwargs):
    """Runs the component with `input_model` and `output_model`."""
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
# A user provided configuration. This will contain all the information needed
|
|
235
|
+
# to select the proper backend and run components (e.g. quant recipe,
|
|
236
|
+
# backend id etc). Backends will validate and resolve configurations and are
|
|
237
|
+
# ultimately responsible for deciding how to configure the components.
|
|
238
|
+
# NOTE: Consider a typed config approach (proto, data class, etc.)
|
|
239
|
+
# A mutable, string-keyed mapping whose values may be of any type.
Config: TypeAlias = MutableMapping[str, Any]
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# Backend specific compilation configuration.
|
|
243
|
+
# A mutable, string-keyed mapping; which keys are meaningful is backend specific.
BackendCompilationConfig: TypeAlias = MutableMapping[str, Any]
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
# The following is experimental and for prototyping only.
|
|
247
|
+
class CompilationConfig:
|
|
248
|
+
"""A typed configuration."""
|
|
249
|
+
|
|
250
|
+
target: 'Target'
|
|
251
|
+
compilation_config: BackendCompilationConfig = dataclasses.field(
|
|
252
|
+
default_factory=dict
|
|
253
|
+
)
|
|
254
|
+
quant_recipe: str | None = None
|
|
255
|
+
|
|
256
|
+
def __init__(self, target: 'Target', **kwargs: Any):
|
|
257
|
+
self.target = target
|
|
258
|
+
self.quant_recipe = kwargs.pop('quantize_recipe', None)
|
|
259
|
+
self.compilation_config = kwargs
|
|
260
|
+
|
|
261
|
+
def to_dict(self) -> dict[str, Any]:
|
|
262
|
+
ret = self.target.flatten()
|
|
263
|
+
ret['compilation_config'] = self.compilation_config
|
|
264
|
+
if self.quant_recipe is not None:
|
|
265
|
+
ret['quantize_recipe'] = self.quant_recipe
|
|
266
|
+
return ret
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
class Backend(metaclass=abc.ABCMeta):
  """Base class for a backend tied to one SoC vendor.

  A backend resolves configurations and manages vendor specific resources
  (e.g. .so files).
  """

  # NOTE: Construct instances via `create`, not by calling this directly.
  def __init__(self, config: Config):
    self._config = config

  @classmethod
  @abc.abstractmethod
  def create(cls, config: Config) -> 'Backend':
    """Builds a backend instance from `config`.

    When no target is specified, the backend represents all targets.

    Args:
      config: The compilation configuration.

    Returns:
      The backend instance.
    """

  @classmethod
  @abc.abstractmethod
  def id(cls) -> str:
    """Identifier of this backend."""

  @property
  @abc.abstractmethod
  def target(self) -> 'Target':
    """The compilation target of this backend."""

  @property
  @abc.abstractmethod
  def target_id(self) -> str:
    """String identifier of the target; may be empty."""

  @property
  def target_id_suffix(self) -> str:
    """`'_' + target_id`, or '' when `target_id` is empty."""
    if not self.target_id:
      return ''
    return '_' + self.target_id

  @property
  def config(self) -> Config:
    """The configuration this backend was created with."""
    return self._config

  @property
  def soc_manufacturer(self) -> str:
    """Manufacturer name or enum; subclasses must override."""
    raise NotImplementedError()

  @property
  def soc_model(self) -> str:
    """Model name or enum; subclasses must override."""
    raise NotImplementedError()

  @property
  def shared_pass_names(self) -> list[str]:
    """Names of shared passes; subclasses must override."""
    raise NotImplementedError()

  @property
  def quantize_recipe(self) -> str | None:
    """Optional quantization recipe; None unless overridden."""
    return None

  @abc.abstractmethod
  def call_component(
      self, input_model: Model, output_model: Model, component: Component
  ):
    """Invokes `component` with the given input and output models."""

  def specialize(self) -> Iterable['Backend']:
    """Yields the specialized backends; by default just this backend."""
    yield self
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
# The type of a `Backend` class object itself (not of a `Backend` instance).
BackendT: TypeAlias = Type[Backend]
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
class Target(metaclass=abc.ABCMeta):
  """Abstract description of a compilation target.

  Concrete targets must be hashable and comparable, provide a repr, report
  their backend id, and flatten themselves into a dict.
  """

  @abc.abstractmethod
  def __hash__(self) -> int:
    """Hash of the target; should agree with `__eq__`."""

  @abc.abstractmethod
  def __eq__(self, other) -> bool:
    """Whether `other` is the same target."""

  @abc.abstractmethod
  def __repr__(self) -> str:
    """Readable representation of the target."""

  @classmethod
  @abc.abstractmethod
  def backend_id(cls) -> str:
    """Backend id associated with this target class."""

  @abc.abstractmethod
  def flatten(self) -> dict[str, Any]:
    """Serializes the target into a flat dict.

    Subclasses must override this, but may call `super().flatten()` to get
    the base dict holding only the backend id.
    """
    return dict(backend_id=self.backend_id())
|