ai-edge-litert-nightly 2.2.0.dev20260102__cp312-cp312-manylinux_2_27_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. ai_edge_litert/__init__.py +1 -0
  2. ai_edge_litert/_pywrap_analyzer_wrapper.so +0 -0
  3. ai_edge_litert/_pywrap_litert_compiled_model_wrapper.so +0 -0
  4. ai_edge_litert/_pywrap_litert_interpreter_wrapper.so +0 -0
  5. ai_edge_litert/_pywrap_litert_tensor_buffer_wrapper.so +0 -0
  6. ai_edge_litert/_pywrap_modify_model_interface.so +0 -0
  7. ai_edge_litert/_pywrap_string_util.so +0 -0
  8. ai_edge_litert/_pywrap_tensorflow_lite_calibration_wrapper.so +0 -0
  9. ai_edge_litert/_pywrap_tensorflow_lite_metrics_wrapper.so +0 -0
  10. ai_edge_litert/any_pb2.py +37 -0
  11. ai_edge_litert/aot/__init__.py +0 -0
  12. ai_edge_litert/aot/ai_pack/__init__.py +0 -0
  13. ai_edge_litert/aot/ai_pack/export_lib.py +300 -0
  14. ai_edge_litert/aot/aot_compile.py +153 -0
  15. ai_edge_litert/aot/core/__init__.py +0 -0
  16. ai_edge_litert/aot/core/apply_plugin.py +148 -0
  17. ai_edge_litert/aot/core/common.py +97 -0
  18. ai_edge_litert/aot/core/components.py +93 -0
  19. ai_edge_litert/aot/core/mlir_transforms.py +36 -0
  20. ai_edge_litert/aot/core/tflxx_util.py +30 -0
  21. ai_edge_litert/aot/core/types.py +374 -0
  22. ai_edge_litert/aot/prepare_for_npu.py +152 -0
  23. ai_edge_litert/aot/vendors/__init__.py +22 -0
  24. ai_edge_litert/aot/vendors/example/__init__.py +0 -0
  25. ai_edge_litert/aot/vendors/example/example_backend.py +157 -0
  26. ai_edge_litert/aot/vendors/fallback_backend.py +128 -0
  27. ai_edge_litert/aot/vendors/google_tensor/__init__.py +0 -0
  28. ai_edge_litert/aot/vendors/google_tensor/google_tensor_backend.py +168 -0
  29. ai_edge_litert/aot/vendors/google_tensor/target.py +84 -0
  30. ai_edge_litert/aot/vendors/import_vendor.py +132 -0
  31. ai_edge_litert/aot/vendors/mediatek/__init__.py +0 -0
  32. ai_edge_litert/aot/vendors/mediatek/mediatek_backend.py +196 -0
  33. ai_edge_litert/aot/vendors/mediatek/target.py +94 -0
  34. ai_edge_litert/aot/vendors/qualcomm/__init__.py +0 -0
  35. ai_edge_litert/aot/vendors/qualcomm/qualcomm_backend.py +161 -0
  36. ai_edge_litert/aot/vendors/qualcomm/target.py +75 -0
  37. ai_edge_litert/api_pb2.py +43 -0
  38. ai_edge_litert/compiled_model.py +250 -0
  39. ai_edge_litert/descriptor_pb2.py +3361 -0
  40. ai_edge_litert/duration_pb2.py +37 -0
  41. ai_edge_litert/empty_pb2.py +37 -0
  42. ai_edge_litert/field_mask_pb2.py +37 -0
  43. ai_edge_litert/format_converter_wrapper_pybind11.so +0 -0
  44. ai_edge_litert/hardware_accelerator.py +22 -0
  45. ai_edge_litert/internal/__init__.py +0 -0
  46. ai_edge_litert/internal/litertlm_builder.py +584 -0
  47. ai_edge_litert/internal/litertlm_core.py +58 -0
  48. ai_edge_litert/internal/litertlm_header_schema_py_generated.py +1596 -0
  49. ai_edge_litert/internal/llm_metadata_pb2.py +45 -0
  50. ai_edge_litert/internal/llm_model_type_pb2.py +51 -0
  51. ai_edge_litert/internal/sampler_params_pb2.py +39 -0
  52. ai_edge_litert/internal/token_pb2.py +38 -0
  53. ai_edge_litert/interpreter.py +1039 -0
  54. ai_edge_litert/libLiteRt.so +0 -0
  55. ai_edge_litert/libpywrap_litert_common.so +0 -0
  56. ai_edge_litert/metrics_interface.py +48 -0
  57. ai_edge_litert/metrics_portable.py +70 -0
  58. ai_edge_litert/model_runtime_info_pb2.py +66 -0
  59. ai_edge_litert/plugin_pb2.py +46 -0
  60. ai_edge_litert/profiling_info_pb2.py +47 -0
  61. ai_edge_litert/pywrap_genai_ops.so +0 -0
  62. ai_edge_litert/schema_py_generated.py +19640 -0
  63. ai_edge_litert/source_context_pb2.py +37 -0
  64. ai_edge_litert/struct_pb2.py +47 -0
  65. ai_edge_litert/tensor_buffer.py +167 -0
  66. ai_edge_litert/timestamp_pb2.py +37 -0
  67. ai_edge_litert/tools/__init__.py +0 -0
  68. ai_edge_litert/tools/apply_plugin_main +0 -0
  69. ai_edge_litert/tools/flatbuffer_utils.py +534 -0
  70. ai_edge_litert/type_pb2.py +53 -0
  71. ai_edge_litert/vendors/google_tensor/compiler/libLiteRtCompilerPlugin_google_tensor.so +0 -0
  72. ai_edge_litert/vendors/mediatek/compiler/libLiteRtCompilerPlugin_MediaTek.so +0 -0
  73. ai_edge_litert/vendors/qualcomm/compiler/libLiteRtCompilerPlugin_Qualcomm.so +0 -0
  74. ai_edge_litert/wrappers_pb2.py +53 -0
  75. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/METADATA +52 -0
  76. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/RECORD +78 -0
  77. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/WHEEL +5 -0
  78. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/top_level.txt +1 -0
@@ -0,0 +1,97 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Constants and other small generic utilities."""
17
+
18
+ from importlib import resources
19
+ import os
20
+ import pathlib
21
+
22
# Short string constants shared by the AOT flow.
TFLITE = "tflite"  # Bare file extension of TFLite models (no leading dot).
DOT_TFLITE = "." + TFLITE  # The same extension with its leading dot.
NPU = "npu"
25
+
26
+
27
_WORKSPACE_PREFIX = "litert"
_AI_EDGE_LITERT_PREFIX = "ai_edge_litert"
_LITERT_ROOT = ""
_PYTHON_ROOT = "python/aot"

# Dotted module path of the AOT python package inside the LiteRT workspace.
# Each root is split on "/" and empty segments are dropped. This matters because
# `_LITERT_ROOT` is "": a plain join would produce "litert..python.aot", which is
# not a valid module name.
MODULE_ROOT = ".".join(
    segment
    for root in (_WORKSPACE_PREFIX, _LITERT_ROOT, _PYTHON_ROOT)
    for segment in root.split("/")
    if segment
)
37
+
38
+
39
def get_resource(
    litert_relative_path: pathlib.Path, is_dir=False
) -> pathlib.Path:
  """Resolves a path relative to the installed LiteRT package root.

  The "litert" package is tried first. If it cannot be imported, the
  "ai_edge_litert" package is used instead.

  Args:
    litert_relative_path: Location of the resource relative to the package
      root.
    is_dir: When True, the existence check below is skipped entirely.

  Returns:
    The resolved location as a `pathlib.Path`.

  Raises:
    FileNotFoundError: If `is_dir` is False and no file exists at the resolved
      location.
  """
  try:
    root = resources.files(_WORKSPACE_PREFIX)
  except ModuleNotFoundError:
    # Workspace package unavailable; fall back to the pip package name.
    root = resources.files(_AI_EDGE_LITERT_PREFIX)
  candidate = root.joinpath(_LITERT_ROOT, str(litert_relative_path))
  must_be_file = not is_dir
  if must_be_file and not candidate.is_file():
    raise FileNotFoundError(f"Resource {candidate} does not exist.")
  return pathlib.Path(str(candidate))
53
+
54
+
55
def is_tflite(path: pathlib.Path | str) -> bool:
  """Returns whether `path` is an existing regular file with a .tflite suffix.

  Args:
    path: The path to check. A `str` is converted to a `pathlib.Path`.

  Returns:
    True if `path` is a file and its suffix equals `DOT_TFLITE`.
  """
  path = pathlib.Path(path)
  # `is_file()` already returns False for missing paths, so no separate
  # `exists()` check is needed.
  return path.is_file() and path.suffix == DOT_TFLITE
57
+
58
+
59
def construct_ld_library_path() -> str:
  """Constructs a string suitable for the LD_LIBRARY_PATH environment variable.

  This function is used in the ai_edge_litert python package, where the shared
  libraries are not in a static location. The result lists:
  - the ai_edge_litert package directory and all of its subdirectories, sorted;
  - then every entry of the current LD_LIBRARY_PATH that is not already among
    those directories, in its original order.

  If the ai_edge_litert package cannot be found (the module is built from
  source), an empty string is returned.

  Returns:
    A string suitable for the LD_LIBRARY_PATH environment variable.
  """
  try:
    resource_root = resources.files(_AI_EDGE_LITERT_PREFIX)
  except ModuleNotFoundError:
    # Built from source case.
    return ""
  root_package_path = str(resource_root)

  # The package root plus every directory found by walking it.
  library_paths = {os.path.abspath(root_package_path)}
  library_paths.update(
      os.path.abspath(dirpath) for dirpath, _, _ in os.walk(root_package_path)
  )
  package_paths = sorted(library_paths)

  # Keep the caller's existing entries, de-duplicating entry by entry. A
  # substring test on the joined string would wrongly discard the whole
  # variable, e.g. "/usr/lib" whenever a package directory lives under it.
  current_ld_library_path = os.environ.get("LD_LIBRARY_PATH")
  extra_paths = []
  if current_ld_library_path:
    known = set(package_paths)
    extra_paths = [
        entry
        for entry in current_ld_library_path.split(os.pathsep)
        if entry not in known
    ]
  return os.pathsep.join(package_paths + extra_paths)
@@ -0,0 +1,93 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Interfaces for specific components used in the LiteRt AOT flow."""
17
+
18
+ import abc
19
+ import sys
20
+ from typing import Any
21
+
22
+ from ai_edge_litert.aot.core import types
23
+
24
+ # pylint: disable=g-importing-member
25
+ # pylint: disable=g-import-not-at-top
26
+ # pylint: disable=g-bad-import-order
27
+ if sys.version_info < (3, 10):
28
+ from typing_extensions import TypeAlias
29
+ else:
30
+ from typing import TypeAlias
31
+ # pylint: enable=g-bad-import-order
32
+ # pylint: enable=g-import-not-at-top
33
+ # pylint: enable=g-importing-member
34
+
35
# Quantization recipe accepted by `AieQuantizerT`: either a list of dicts with
# string keys, or a single string (presumably a recipe name or file path —
# TODO confirm against the quantizer implementation).
# NOTE(review): `list[...] | str` is evaluated at import time and needs
# Python 3.10+, so the typing_extensions fallback above does not by itself make
# this module importable on 3.9.
QuantRecipe: TypeAlias = list[dict[str, Any]] | str
36
+
37
+
38
class AieQuantizerT(metaclass=abc.ABCMeta):
  """Abstract interface that every AIE quantizer component implements."""

  @property
  def component_name(self) -> str:
    """Fixed name identifying this kind of component."""
    name = "aie_quantizer"
    return name

  @abc.abstractmethod
  def __call__(
      self,
      input_model: types.Model,
      output_model: types.Model,
      quantization_recipe: QuantRecipe | None = None,
      *args,
      **kwargs,
  ):
    """Runs the quantizer; concrete behavior is defined by subclasses.

    Implementations presumably read `input_model` and populate `output_model`,
    optionally guided by `quantization_recipe`.
    """
55
+
56
+
57
class ApplyPluginT(metaclass=abc.ABCMeta):
  """Abstract interface that every apply-plugin component implements."""

  @property
  def component_name(self) -> str:
    """Fixed name identifying this kind of component."""
    name = "apply_plugin"
    return name

  @property
  def default_err(self) -> str:
    """Default error-capture setting ("none")."""
    # NOTE: Capture stderr from underlying binary.
    setting = "none"
    return setting

  @abc.abstractmethod
  def __call__(
      self,
      input_model: types.Model,
      output_model: types.Model,
      soc_manufacturer: str,
      soc_model: str,
      *args,
      **kwargs,
  ):
    """Applies the plugin for the given SoC; defined by subclasses."""
80
+
81
+
82
class MlirTransformsT(metaclass=abc.ABCMeta):
  """Abstract interface that every MLIR-transforms component implements."""

  @property
  def component_name(self) -> str:
    """Fixed name identifying this kind of component."""
    name = "mlir_transforms"
    return name

  @abc.abstractmethod
  def __call__(
      self,
      input_model: types.Model,
      output_model: types.Model,
      *args,
      **kwargs,
  ):
    """Runs the transforms; concrete behavior is defined by subclasses."""
@@ -0,0 +1,36 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Wrapper for suite of MLIR passes."""
17
+
18
+ from ai_edge_litert.aot.core import components
19
+ from ai_edge_litert.aot.core import tflxx_util
20
+ from ai_edge_litert.aot.core import types
21
+
22
+
23
class MlirTransforms(components.MlirTransformsT):
  """Runs a named suite of MLIR passes through the TFLXX shim."""

  def __call__(
      self,
      input_model: types.Model,
      output_model: types.Model,
      pass_name: str,
  ):
    """Transforms `input_model` with `pass_name` and stores the result.

    If `input_model` is on disk, it is loaded into memory first. This mutates
    `input_model` itself. The transformed bytes are written into
    `output_model` via `set_bytes`.
    """
    needs_load = not input_model.in_memory
    if needs_load:
      input_model.load()
    transformed = tflxx_util.call_tflxx(input_model.model_bytes, pass_name)
    output_model.set_bytes(transformed)
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # pylint: disable=g-import-not-at-top
17
+ # pytype: disable=import-error
18
+ # pytype: disable=not-callable
19
+
20
+ """Shim layer for TFLXX while it is in experimental."""
21
+
22
+
23
+ import importlib.util
24
+ from typing import Callable
25
+
26
def _passthrough(input: bytes, pass_name: str) -> bytes:  # pylint: disable=redefined-builtin
  """Returns `input` unchanged; `pass_name` is ignored by this shim."""
  del pass_name  # Unused while TFLXX is not wired in.
  return input


# No-op stand-in for the TFLXX entry point: model bytes pass through as-is.
call_tflxx: Callable[[bytes, str], bytes] = _passthrough


def tflxx_enabled() -> bool:
  """Reports whether a real TFLXX backend is available (never, in this shim)."""
  enabled = False
  return enabled
@@ -0,0 +1,374 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Basic types used in the LiteRt AOT flow."""
17
+
18
+ import abc
19
+ from collections.abc import Iterable
20
+ import dataclasses
21
+ import pathlib
22
+ import sys
23
+ from typing import Any, MutableMapping, Protocol, Type
24
+
25
+ # pylint: disable=g-importing-member
26
+ # pylint: disable=g-import-not-at-top
27
+ # pylint: disable=g-bad-import-order
28
+ if sys.version_info < (3, 10):
29
+ from typing_extensions import TypeAlias
30
+ else:
31
+ from typing import TypeAlias
32
+ # pylint: enable=g-bad-import-order
33
+ # pylint: enable=g-import-not-at-top
34
+ # pylint: enable=g-importing-member
35
+
36
+
37
@dataclasses.dataclass(frozen=True)
class SubgraphPartitionStats:
  """Per-subgraph counts of offloaded ops and partitions after compilation."""

  subgraph_index: int
  num_ops_offloaded: int
  num_total_ops: int
  num_partitions_offloaded: int

  def __str__(self) -> str:
    # "fully" only when every op of the subgraph was offloaded.
    coverage = (
        'fully' if self.num_ops_offloaded == self.num_total_ops else 'partially'
    )
    return (
        f'Subgraph {self.subgraph_index} {coverage} compiled:\t'
        f'{self.num_ops_offloaded} / {self.num_total_ops} ops offloaded to '
        f'{self.num_partitions_offloaded} partitions.'
    )
55
+
56
+
57
@dataclasses.dataclass(frozen=True)
class PartitionStats:
  """Partition statistics for a whole model, one entry per subgraph."""

  subgraph_stats: list[SubgraphPartitionStats]

  def __str__(self) -> str:
    # One line per subgraph, in the stored order.
    lines = [str(stats) for stats in self.subgraph_stats]
    return '\n'.join(lines)
65
+
66
+
67
class Model:
  """A model, stored either as a path on disk or as in-memory bytes.

  Note: If the model is not in memory, data_ will be a path to a file on disk.
  If the model is in memory, data_ will be the model bytes.

  However, there's no guarantee that the path will be a valid path to a file
  on disk, and/or that the file is a valid TFLite model.
  """

  # Either a `pathlib.Path` (on disk) or `bytes` (in memory).
  data_: pathlib.Path | bytes
  # Optional partitioning statistics; None unless assigned by a caller.
  partition_stats: PartitionStats | None = None

  def __init__(
      self,
      path: pathlib.Path | str | None = None,
      model_bytes: bytes | None = None,
  ):
    """Initializes the model from exactly one of `path` or `model_bytes`.

    Args:
      path: Location of the model on disk; a `str` is converted to a
        `pathlib.Path`.
      model_bytes: The model content held in memory.

    Raises:
      ValueError: If both or neither of `path` and `model_bytes` are given.
    """
    if path is not None:
      if isinstance(path, str):
        path = pathlib.Path(path)
      # Explicit None check, matching the branch below. An empty `bytes`
      # payload counts as "specified" and is rejected, not silently ignored.
      if model_bytes is not None:
        raise ValueError('Cannot specify both path and model_bytes.')
      self.data_ = path
    else:
      if model_bytes is None:
        raise ValueError('Cannot specify neither path nor model_bytes.')
      self.data_ = model_bytes

  @property
  def in_memory(self) -> bool:
    """Whether the model content is held in memory as bytes."""
    return isinstance(self.data_, bytes)

  @property
  def path(self) -> pathlib.Path:
    """The on-disk path of the model.

    Raises:
      ValueError: If the model is not stored as a path.
    """
    if not isinstance(self.data_, pathlib.Path):
      raise ValueError('Model is not on disk.')
    return self.data_

  @property
  def model_bytes(self) -> bytes:
    """The in-memory model content.

    Raises:
      ValueError: If the model is not held in memory.
    """
    if not isinstance(self.data_, bytes):
      raise ValueError('Model is not in memory.')
    return self.data_

  @classmethod
  def create_from_path(cls, path: pathlib.Path) -> 'Model':
    """Creates an on-disk model (an instance of `cls`, so subclasses work)."""
    return cls(path=path, model_bytes=None)

  @classmethod
  def create_from_bytes(cls, model_bytes: bytes) -> 'Model':
    """Creates an in-memory model (an instance of `cls`, so subclasses work)."""
    return cls(path=None, model_bytes=model_bytes)

  def set_path(self, path: pathlib.Path | str):
    """Points the model at `path`, replacing any in-memory content."""
    if isinstance(path, str):
      path = pathlib.Path(path)
    self.data_ = path

  def set_bytes(self, model_bytes: bytes):
    """Replaces the stored model with the in-memory `model_bytes`."""
    self.data_ = model_bytes

  def load(self):
    """Loads the model from its path into memory.

    Raises:
      ValueError: If the model is already in memory.
    """
    if not isinstance(self.data_, pathlib.Path):
      raise ValueError('Cannot load a model that is already in memory.')
    self.data_ = self.data_.read_bytes()

  def save(self, path: pathlib.Path | str, export_only: bool = False):
    """Saves the model to the given path from the in-memory model content.

    If export_only is True, the model will be copied to the given path without
    modifying the internal state, regardless of whether the model is already on
    disk or in memory.

    Args:
      path: The path to save the model to.
      export_only: Whether to only export the model without modifying the
        internal state (i.e. transfer the in-memory model to disk).

    Raises:
      ValueError: If export_only is False and the model is not in memory.
    """
    if isinstance(path, str):
      path = pathlib.Path(path)
    if isinstance(self.data_, pathlib.Path):
      if not export_only:
        raise ValueError(
            'Cannot save a model that is not in memory. Use export_only=True'
            ' for copying the model to a new path.'
        )
      # Same read mechanism as `load`.
      model_content = self.data_.read_bytes()
    else:
      model_content = self.data_
    path.write_bytes(model_content)
    if not export_only:
      # The model now lives on disk at `path`.
      self.data_ = path
168
+
169
+
170
@dataclasses.dataclass()
class CompilationResult:
  """Outcome of compiling for one or more backends.

  Successful compilations are kept as (backend, model) pairs. Failed ones are
  kept as (backend, error message) pairs.
  """

  models_with_backend: list[tuple['Backend', Model]] = dataclasses.field(
      default_factory=list
  )
  failed_backends: list[tuple['Backend', str]] = dataclasses.field(
      default_factory=list
  )

  @property
  def models(self) -> list[Model]:
    """The compiled models, without their backends, in stored order."""
    return [pair[1] for pair in self.models_with_backend]

  def load(self):
    """Reads every on-disk model into memory; in-memory models are skipped."""
    on_disk = (m for _, m in self.models_with_backend if not m.in_memory)
    for model in on_disk:
      model.load()

  def export(self, output_dir: pathlib.Path | str, model_name: str = 'model'):
    """Copies each compiled model into `output_dir` without changing its state.

    Files are named `model_name + backend.target_id_suffix + '.tflite'`. The
    directory and any missing parents are created first.
    """
    out = pathlib.Path(output_dir) if isinstance(output_dir, str) else output_dir
    out.mkdir(parents=True, exist_ok=True)
    for backend, model in self.models_with_backend:
      file_name = model_name + backend.target_id_suffix + '.tflite'
      model.save(out / file_name, export_only=True)

  def compilation_report(self) -> str:
    """Returns a human readable compilation report."""
    sections = []
    for backend, model in self.models_with_backend:
      sections.extend([
          f'{backend.target_id}',
          '==========================',
          f'Partition Stats:\n{model.partition_stats}\n',
      ])
    success_text = '\n'.join(sections)

    failure_lines = []
    if self.failed_backends:
      failure_lines += [
          '==========================',
          'COMPILATION FAILURES:',
          '==========================',
      ]
      failure_lines += [
          f'{failed.target_id}\t{err}' for failed, err in self.failed_backends
      ]
    return '\n'.join([success_text, '\n'.join(failure_lines)])
218
+
219
+
220
class Component(Protocol):
  """Structural type for any AOT-flow step that takes and produces a Model.

  Quantizers, graph rewriters and compiler plugins are examples.
  """

  @property
  def component_name(self) -> str:
    """Name identifying the component."""

  def __call__(self, input_model: Model, output_model: Model, *args, **kwargs):
    """Runs the component with `input_model` and `output_model`."""
232
+
233
+
234
# A user-provided configuration. This will contain all the information needed
# to select the proper backend and run components (e.g. quant recipe,
# backend id, etc.). Backends will validate and resolve configurations and are
# ultimately responsible for deciding how to configure the components.
# NOTE: Consider a typed config approach (proto, data class, etc.)
Config: TypeAlias = MutableMapping[str, Any]


# Backend-specific compilation configuration: a mutable mapping from string
# keys to arbitrary values.
BackendCompilationConfig: TypeAlias = MutableMapping[str, Any]
244
+
245
+
246
# The following is experimental and for prototyping only.
class CompilationConfig:
  """A typed configuration: a target plus backend-specific options.

  Attributes:
    target: The compilation target.
    compilation_config: All keyword arguments given to the constructor except
      `quantize_recipe`.
    quant_recipe: The `quantize_recipe` keyword argument, or None if absent.
  """

  target: 'Target'
  # Annotation only. This class is not a dataclass, so a
  # `dataclasses.field(...)` default would merely leave a `Field` object as
  # the class attribute. The real value is always set per instance in
  # `__init__`.
  compilation_config: BackendCompilationConfig
  quant_recipe: str | None = None

  def __init__(self, target: 'Target', **kwargs: Any):
    self.target = target
    self.quant_recipe = kwargs.pop('quantize_recipe', None)
    # `kwargs` is a fresh dict for each call, so it is safe to keep as-is.
    self.compilation_config = kwargs

  def to_dict(self) -> dict[str, Any]:
    """Flattens the config into a dict.

    Starts from `target.flatten()` and adds `compilation_config`. The
    `quantize_recipe` entry is added only when a recipe was given.
    """
    ret = self.target.flatten()
    ret['compilation_config'] = self.compilation_config
    if self.quant_recipe is not None:
      ret['quantize_recipe'] = self.quant_recipe
    return ret
267
+
268
+
269
class Backend(metaclass=abc.ABCMeta):
  """Abstract base for a backend tied to one SoC vendor.

  A backend resolves configurations and manages vendor-specific resources
  (e.g. shared libraries).
  """

  # NOTE: Construct instances via `create`, not directly.
  def __init__(self, config: Config):
    self._config = config

  @classmethod
  @abc.abstractmethod
  def create(cls, config: Config) -> 'Backend':
    """Builds a backend instance from `config`.

    When `config` names no target, the backend stands for all targets.

    Args:
      config: The compilation configuration.

    Returns:
      The backend instance.
    """

  @classmethod
  @abc.abstractmethod
  def id(cls) -> str:
    """Identifier of this backend class."""

  @property
  @abc.abstractmethod
  def target(self) -> 'Target':
    """The compilation target handled by this backend."""

  @property
  @abc.abstractmethod
  def target_id(self) -> str:
    """String identifier of the target."""

  @property
  def target_id_suffix(self) -> str:
    """`target_id` prefixed with '_', or '' when `target_id` is empty."""
    return '_' + self.target_id if self.target_id else ''

  @property
  def config(self) -> Config:
    """The configuration this backend was constructed with."""
    return self._config

  @property
  def soc_manufacturer(self) -> str:
    """Manufacturer name or enum; subclasses must override to use it."""
    raise NotImplementedError()

  @property
  def soc_model(self) -> str:
    """Model name or enum; subclasses must override to use it."""
    raise NotImplementedError()

  @property
  def shared_pass_names(self) -> list[str]:
    """Names of shared passes; subclasses must override to use it."""
    raise NotImplementedError()

  @property
  def quantize_recipe(self) -> str | None:
    """Optional quantization recipe; None unless a subclass overrides it."""
    return None

  @abc.abstractmethod
  def call_component(
      self, input_model: Model, output_model: Model, component: Component
  ):
    """Runs `component` on the models on behalf of this backend."""

  def specialize(self) -> Iterable['Backend']:
    """Yields concrete backends; the base implementation yields only itself."""
    yield self
347
+
348
+
349
# Alias for a `Backend` class object (the class itself, not an instance).
BackendT: TypeAlias = Type[Backend]
350
+
351
+
352
class Target(metaclass=abc.ABCMeta):
  """Abstract compilation target.

  Concrete targets must implement hashing, equality and repr, and must report
  the id of the backend they belong to.
  """

  @abc.abstractmethod
  def __hash__(self) -> int:
    """Hash of the target; must be implemented by subclasses."""

  @abc.abstractmethod
  def __eq__(self, other) -> bool:
    """Equality with `other`; must be implemented by subclasses."""

  @abc.abstractmethod
  def __repr__(self) -> str:
    """Debug representation; must be implemented by subclasses."""

  @classmethod
  @abc.abstractmethod
  def backend_id(cls) -> str:
    """Id of the backend this target belongs to."""

  @abc.abstractmethod
  def flatten(self) -> dict[str, Any]:
    """Returns the target as a flat dict.

    The base implementation provides only the `backend_id` entry. Overrides
    can call `super().flatten()` and extend the result.
    """
    flat = {'backend_id': self.backend_id()}
    return flat