ai-edge-litert-nightly 2.2.0.dev20260102__cp312-cp312-manylinux_2_27_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. ai_edge_litert/__init__.py +1 -0
  2. ai_edge_litert/_pywrap_analyzer_wrapper.so +0 -0
  3. ai_edge_litert/_pywrap_litert_compiled_model_wrapper.so +0 -0
  4. ai_edge_litert/_pywrap_litert_interpreter_wrapper.so +0 -0
  5. ai_edge_litert/_pywrap_litert_tensor_buffer_wrapper.so +0 -0
  6. ai_edge_litert/_pywrap_modify_model_interface.so +0 -0
  7. ai_edge_litert/_pywrap_string_util.so +0 -0
  8. ai_edge_litert/_pywrap_tensorflow_lite_calibration_wrapper.so +0 -0
  9. ai_edge_litert/_pywrap_tensorflow_lite_metrics_wrapper.so +0 -0
  10. ai_edge_litert/any_pb2.py +37 -0
  11. ai_edge_litert/aot/__init__.py +0 -0
  12. ai_edge_litert/aot/ai_pack/__init__.py +0 -0
  13. ai_edge_litert/aot/ai_pack/export_lib.py +300 -0
  14. ai_edge_litert/aot/aot_compile.py +153 -0
  15. ai_edge_litert/aot/core/__init__.py +0 -0
  16. ai_edge_litert/aot/core/apply_plugin.py +148 -0
  17. ai_edge_litert/aot/core/common.py +97 -0
  18. ai_edge_litert/aot/core/components.py +93 -0
  19. ai_edge_litert/aot/core/mlir_transforms.py +36 -0
  20. ai_edge_litert/aot/core/tflxx_util.py +30 -0
  21. ai_edge_litert/aot/core/types.py +374 -0
  22. ai_edge_litert/aot/prepare_for_npu.py +152 -0
  23. ai_edge_litert/aot/vendors/__init__.py +22 -0
  24. ai_edge_litert/aot/vendors/example/__init__.py +0 -0
  25. ai_edge_litert/aot/vendors/example/example_backend.py +157 -0
  26. ai_edge_litert/aot/vendors/fallback_backend.py +128 -0
  27. ai_edge_litert/aot/vendors/google_tensor/__init__.py +0 -0
  28. ai_edge_litert/aot/vendors/google_tensor/google_tensor_backend.py +168 -0
  29. ai_edge_litert/aot/vendors/google_tensor/target.py +84 -0
  30. ai_edge_litert/aot/vendors/import_vendor.py +132 -0
  31. ai_edge_litert/aot/vendors/mediatek/__init__.py +0 -0
  32. ai_edge_litert/aot/vendors/mediatek/mediatek_backend.py +196 -0
  33. ai_edge_litert/aot/vendors/mediatek/target.py +94 -0
  34. ai_edge_litert/aot/vendors/qualcomm/__init__.py +0 -0
  35. ai_edge_litert/aot/vendors/qualcomm/qualcomm_backend.py +161 -0
  36. ai_edge_litert/aot/vendors/qualcomm/target.py +75 -0
  37. ai_edge_litert/api_pb2.py +43 -0
  38. ai_edge_litert/compiled_model.py +250 -0
  39. ai_edge_litert/descriptor_pb2.py +3361 -0
  40. ai_edge_litert/duration_pb2.py +37 -0
  41. ai_edge_litert/empty_pb2.py +37 -0
  42. ai_edge_litert/field_mask_pb2.py +37 -0
  43. ai_edge_litert/format_converter_wrapper_pybind11.so +0 -0
  44. ai_edge_litert/hardware_accelerator.py +22 -0
  45. ai_edge_litert/internal/__init__.py +0 -0
  46. ai_edge_litert/internal/litertlm_builder.py +584 -0
  47. ai_edge_litert/internal/litertlm_core.py +58 -0
  48. ai_edge_litert/internal/litertlm_header_schema_py_generated.py +1596 -0
  49. ai_edge_litert/internal/llm_metadata_pb2.py +45 -0
  50. ai_edge_litert/internal/llm_model_type_pb2.py +51 -0
  51. ai_edge_litert/internal/sampler_params_pb2.py +39 -0
  52. ai_edge_litert/internal/token_pb2.py +38 -0
  53. ai_edge_litert/interpreter.py +1039 -0
  54. ai_edge_litert/libLiteRt.so +0 -0
  55. ai_edge_litert/libpywrap_litert_common.so +0 -0
  56. ai_edge_litert/metrics_interface.py +48 -0
  57. ai_edge_litert/metrics_portable.py +70 -0
  58. ai_edge_litert/model_runtime_info_pb2.py +66 -0
  59. ai_edge_litert/plugin_pb2.py +46 -0
  60. ai_edge_litert/profiling_info_pb2.py +47 -0
  61. ai_edge_litert/pywrap_genai_ops.so +0 -0
  62. ai_edge_litert/schema_py_generated.py +19640 -0
  63. ai_edge_litert/source_context_pb2.py +37 -0
  64. ai_edge_litert/struct_pb2.py +47 -0
  65. ai_edge_litert/tensor_buffer.py +167 -0
  66. ai_edge_litert/timestamp_pb2.py +37 -0
  67. ai_edge_litert/tools/__init__.py +0 -0
  68. ai_edge_litert/tools/apply_plugin_main +0 -0
  69. ai_edge_litert/tools/flatbuffer_utils.py +534 -0
  70. ai_edge_litert/type_pb2.py +53 -0
  71. ai_edge_litert/vendors/google_tensor/compiler/libLiteRtCompilerPlugin_google_tensor.so +0 -0
  72. ai_edge_litert/vendors/mediatek/compiler/libLiteRtCompilerPlugin_MediaTek.so +0 -0
  73. ai_edge_litert/vendors/qualcomm/compiler/libLiteRtCompilerPlugin_Qualcomm.so +0 -0
  74. ai_edge_litert/wrappers_pb2.py +53 -0
  75. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/METADATA +52 -0
  76. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/RECORD +78 -0
  77. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/WHEEL +5 -0
  78. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/top_level.txt +1 -0
@@ -0,0 +1,84 @@
1
+ # Copyright 2025 Google LLC.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Compilation target for Google Tensor SOCs."""
16
+
17
+ import dataclasses
18
+ import sys
19
+ from typing import Any
20
+
21
+ from ai_edge_litert.aot.core import types
22
+
23
+ # pylint: disable=g-importing-member
24
+ # pylint: disable=g-import-not-at-top
25
+ # pylint: disable=g-bad-import-order
26
+ if sys.version_info >= (3, 11):
27
+ from enum import StrEnum # pylint: disable=g-importing-member
28
+ else:
29
+ from backports.strenum import StrEnum # pylint: disable=g-importing-member
30
+ # pylint: enable=g-bad-import-order
31
+ # pylint: enable=g-import-not-at-top
32
+ # pylint: enable=g-importing-member
33
+
34
+
35
# Backend id reported by `Target.backend_id()` for Google Tensor targets.
_GOOGLE_TENSOR_BACKEND_ID: str = "GOOGLE"
36
+
37
+
38
class SocModel(StrEnum):
  """Google Tensor SOC model.

  Member values are the strings emitted by `Target.__repr__` and
  `Target.flatten`.
  """

  # NOTE(review): presumably a wildcard meaning "every Tensor SOC" (as `ALL`
  # is in the MediaTek/Qualcomm backends) — confirm in the Google backend.
  ALL = "ALL"

  TENSOR_G3 = "Tensor_G3"
  TENSOR_G4 = "Tensor_G4"
  TENSOR_G5 = "Tensor_G5"
  TENSOR_G6 = "Tensor_G6"
47
+
48
+
49
class SocManufacturer(StrEnum):
  """Google Tensor SOC manufacturer."""

  # Sole member; its value is the prefix of `Target.__repr__`.
  GOOGLE = "Google"
53
+
54
+
55
@dataclasses.dataclass
class Target(types.Target):
  """Compilation target for Google Tensor SOCs.

  Equality and hashing consider only the SOC manufacturer and SOC model.
  `repr` renders the target as ``<manufacturer>_<model>`` using the enum
  values.
  """

  soc_model: SocModel
  soc_manufacturer: SocManufacturer = SocManufacturer.GOOGLE

  @classmethod
  def backend_id(cls) -> str:
    return _GOOGLE_TENSOR_BACKEND_ID

  def __hash__(self) -> int:
    # Must stay consistent with `__eq__`: hash exactly the compared fields.
    return hash((self.soc_manufacturer, self.soc_model))

  def __eq__(self, other: object) -> bool:
    # Return NotImplemented for unrelated operands (e.g. `target == None`) so
    # Python falls back to its default comparison instead of this method
    # raising AttributeError on the missing fields.
    if not isinstance(other, Target):
      return NotImplemented
    return (
        self.soc_manufacturer == other.soc_manufacturer
        and self.soc_model == other.soc_model
    )

  def __repr__(self) -> str:
    return f"{self.soc_manufacturer.value}_{self.soc_model.value}"

  def flatten(self) -> dict[str, Any]:
    """Returns the base class's flattened dict extended with the SOC fields.

    Returns:
      The dict from `types.Target.flatten()` updated with the
      ``"soc_manufacturer"`` and ``"soc_model"`` enum values.
    """
    flattened_target = super().flatten()
    flattened_target.update({
        "soc_manufacturer": self.soc_manufacturer.value,
        "soc_model": self.soc_model.value,
    })
    return flattened_target
@@ -0,0 +1,132 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utility for dynamically importing vendor backends based on ID."""
17
+
18
+ import copy
19
+ import dataclasses
20
+ from typing import Any, Iterable
21
+
22
+ from ai_edge_litert.aot.core import types
23
+ from ai_edge_litert.aot.vendors import fallback_backend
24
+
25
+
26
@dataclasses.dataclass
class VendorRegistry:
  """Registry entry describing one registered backend."""

  # The backend class stored under its `id()` in `_VENDOR_REGISTRY`.
  backend_class: types.BackendT
31
+
32
+
33
# Backend id (the value of `backend_class.id()`) -> registry entry.
# Filled in by `register_backend`.
_VENDOR_REGISTRY: dict[str, VendorRegistry] = dict()
34
+
35
+
36
def register_backend(
    backend_class: types.BackendT,
):
  """Registers a backend class under its id and returns it unchanged.

  Because it returns its argument, this can be used as a class decorator.
  Registering another class with the same id replaces the earlier entry.

  Args:
    backend_class: The backend class to register; its `id()` is the key.

  Returns:
    `backend_class` itself.
  """
  registry_key = backend_class.id()
  _VENDOR_REGISTRY[registry_key] = VendorRegistry(backend_class=backend_class)
  return backend_class
43
+
44
# Registered at import time of this module, so the fallback backend is
# available without importing any vendor module.
register_backend(backend_class=fallback_backend.FallbackBackend)
45
+
46
+
47
def import_vendor(backend_id: str) -> types.BackendT:
  """Returns the backend class registered under `backend_id`.

  Despite its name, this only looks the id up in the registry. Vendor
  backends appear there once the module defining them has been imported and
  has run `register_backend`.

  Args:
    backend_id: The ID of the backend to look up.

  Returns:
    The registered backend class.

  Raises:
    ValueError: If no backend is registered under `backend_id`.
  """
  if backend_id not in _VENDOR_REGISTRY:
    raise ValueError(f'Unsupported backend id: {backend_id}')
  return _VENDOR_REGISTRY[backend_id].backend_class
64
+
65
+
66
class AllRegisteredTarget(types.Target):
  """Placeholder target that stands in for every registered backend.

  It is not a concrete compilation target: hashing, equality and repr all
  raise NotImplementedError, and only `backend_id` and `flatten` work.
  """

  def __hash__(self) -> int:
    raise NotImplementedError

  def __eq__(self, other) -> bool:
    raise NotImplementedError

  def __repr__(self) -> str:
    raise NotImplementedError

  @classmethod
  def backend_id(cls) -> str:
    return 'all'

  def flatten(self) -> dict[str, Any]:
    # Only the backend id describes this placeholder target.
    backend_id = self.backend_id()
    return dict(backend_id=backend_id)
84
+
85
+
86
@register_backend
class AllRegisteredBackend(types.Backend):
  """A virtual backend that represents all registered backends.

  It cannot run any component itself. `specialize` expands it into the
  specialized backends of every other registered backend.
  """

  # NOTE: Only initialize through "create".
  def __init__(self, config: types.Config):
    self._config = config

  @classmethod
  def create(cls, config: types.Config) -> 'AllRegisteredBackend':
    # Use `cls` (as the vendor backends do) so subclasses build themselves.
    return cls(config)

  @classmethod
  def id(cls) -> str:
    return 'all'

  @property
  def target(self) -> AllRegisteredTarget:
    return AllRegisteredTarget()

  @property
  def target_id(self) -> str:
    return ''

  @property
  def config(self) -> types.Config:
    return self._config

  def call_component(
      self,
      input_model: types.Model,
      output_model: types.Model,
      component: types.Component,
  ):
    """Always fails: the virtual backend implements no components.

    Raises:
      NotImplementedError: Always.
    """
    del input_model, output_model, component
    raise NotImplementedError(
        'AllRegisteredBackend does not support any component.'
    )

  def specialize(self) -> Iterable[types.Backend]:
    """Yields the specialized backends of every other registered backend.

    For each registered id other than this backend's own, a deep copy of
    this backend's config is created with ``backend_id`` set to that id. The
    registered class is then created from it, and everything its
    `specialize` yields is re-yielded.
    """
    # Iterate over a snapshot. This is a generator, and a backend registered
    # while it is suspended (e.g. a vendor module imported by the consumer)
    # would otherwise raise "dictionary changed size during iteration".
    for backend_id, vendor_module in list(_VENDOR_REGISTRY.items()):
      if backend_id == self.id():
        continue
      config = copy.deepcopy(self.config)
      config['backend_id'] = backend_id
      backend = vendor_module.backend_class.create(config)
      yield from backend.specialize()
File without changes
@@ -0,0 +1,196 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Backend implementation for the example compiler plugin.."""
17
+
18
+ import copy
19
+ import functools
20
+ import itertools
21
+ import os
22
+ import pathlib
23
+ from typing import Iterable
24
+
25
+ from ai_edge_litert.aot.core import common
26
+ from ai_edge_litert.aot.core import components
27
+ from ai_edge_litert.aot.core import types
28
+ from ai_edge_litert.aot.vendors import import_vendor
29
+ from ai_edge_litert.aot.vendors.mediatek import target as target_lib
30
+
31
# MediaTek compiler plugin shared library. It is resolved with
# `common.get_resource` (presumably relative to the ai_edge_litert package).
COMPILER_PLUGIN_LIB_PATH = pathlib.Path(
    "vendors/mediatek/compiler/libLiteRtCompilerPlugin_MediaTek.so")
34
+
35
+
36
@import_vendor.register_backend
class MediaTekBackend(types.Backend):
  """Backend implementation for the MediaTek compiler plugin."""

  def __init__(self, config: types.Config):
    super().__init__(config)
    self._compilation_config = config.get("compilation_config", None)

  @property
  def soc_manufacturer(self) -> target_lib.SocManufacturer:
    return target_lib.SocManufacturer.MEDIATEK

  @property
  def soc_model(self) -> target_lib.SocModel:
    """SOC model from config ``"soc_model"`` (``"ALL"`` when absent)."""
    return target_lib.SocModel(self.config.get("soc_model", "ALL"))

  @property
  def android_os_version(self) -> target_lib.AndroidOsVersion:
    """OS version from config ``"android_os_version"`` (``"ALL"`` if absent)."""
    return target_lib.AndroidOsVersion(
        self.config.get("android_os_version", "ALL")
    )

  @property
  def target(self) -> target_lib.Target:
    return target_lib.Target(
        self.soc_model, self.soc_manufacturer, self.android_os_version
    )

  @property
  def target_id(self) -> str:
    return repr(self.target)

  def specialize(self) -> Iterable["MediaTekBackend"]:
    """Expands ``ALL`` SOC models / OS versions into concrete backends.

    Yields ``self`` when both the SOC model and the OS version are concrete.
    Otherwise yields one new backend per (SOC model, OS version) pair. Any
    ``ALL`` axis is replaced by every other member of its enum, in
    declaration order.
    """
    if (
        self.soc_model != target_lib.SocModel.ALL
        and self.android_os_version != target_lib.AndroidOsVersion.ALL
    ):
      yield self
    else:
      if self.soc_model == target_lib.SocModel.ALL:
        soc_models = [
            x for x in target_lib.SocModel if x != target_lib.SocModel.ALL
        ]
      else:
        soc_models = [self.soc_model]
      if self.android_os_version == target_lib.AndroidOsVersion.ALL:
        android_os_versions = [
            x
            for x in target_lib.AndroidOsVersion
            if x != target_lib.AndroidOsVersion.ALL
        ]
      else:
        android_os_versions = [self.android_os_version]
      for soc_model, android_os_version in itertools.product(
          soc_models, android_os_versions
      ):
        new_config = copy.deepcopy(self.config)
        new_config["soc_model"] = soc_model.value
        new_config["android_os_version"] = android_os_version.value
        yield self.create(new_config)

  @classmethod
  def id(cls) -> str:
    return target_lib._MEDIATEK_BACKEND_ID  # pylint: disable=protected-access

  @classmethod
  def create(cls, config: types.Config) -> "MediaTekBackend":
    """Creates a backend from `config`.

    Raises:
      ValueError: If ``config["backend_id"]`` is not this backend's id.
    """
    if config.get("backend_id", "") != cls.id():
      raise ValueError("Invalid backend id")
    return cls(config)

  @property
  def quantize_recipe(self) -> str | None:
    """Config ``"quantize_recipe"`` entry, or None when absent."""
    return self.config.get("quantize_recipe", None)

  def call_component(
      self,
      input_model: types.Model,
      output_model: types.Model,
      component: types.Component,
  ):
    """Dispatches `component` to the handler registered for its type."""
    return _call_component(component, self, input_model, output_model)
118
+
119
+
120
@functools.singledispatch
def _call_component(
    component: types.Component,
    backend: MediaTekBackend,
    unused_input_model: types.Model,
    unused_output_model: types.Model,
):
  """Default handler for components the MediaTek backend does not support.

  Supported component types get their own handler via
  ``@_call_component.register``. Any other component type lands here.

  Raises:
    NotImplementedError: Always.
  """
  backend_name = backend.id()
  component_name = component.component_name
  raise NotImplementedError(
      f"{backend_name} backend does not support {component_name} component."
  )
131
+
132
+
133
# TODO(toribiosteven): Translate SOC | OS version to the corresponding
# MediaTek SDK version and pass to the plugin.
@_call_component.register
def _apply_plugin(
    component: components.ApplyPluginT,
    backend: MediaTekBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Calls the apply plugin component for a MediaTek target.

  If ``common.get_resource`` locates the compiler plugin library (the
  ai_edge_litert wheel case), two keyword arguments are passed:
  ``libs`` (the library's directory) and ``sdk_libs_path`` (from
  ``ai_edge_litert_sdk_mediatek`` when importable, else None). A
  FileNotFoundError while resolving these paths drops both keyword
  arguments, so the component uses its default library path. The SOC model
  is passed lower-cased.
  """
  try:
    plugin_dir = os.path.dirname(common.get_resource(COMPILER_PLUGIN_LIB_PATH))
    try:
      # pytype: disable=import-error
      import ai_edge_litert_sdk_mediatek  # pylint: disable=g-import-not-at-top
      # pytype: enable=import-error

      # TODO(weiyiw): Translate SOC | OS version to the corresponding
      # MediaTek SDK version and pass to the plugin.
      mediatek_sdk_version = "v8"
      sdk_dir = str(
          ai_edge_litert_sdk_mediatek.path_to_sdk_libs(mediatek_sdk_version)
      )
    except ImportError:
      # The SDK package is optional.
      sdk_dir = None
    plugin_kwargs = dict(libs=plugin_dir, sdk_libs_path=sdk_dir)
  except FileNotFoundError:
    plugin_kwargs = {}
  return component(
      input_model,
      output_model,
      backend.soc_manufacturer,
      backend.soc_model.lower(),
      **plugin_kwargs,
  )
173
+
174
+
175
@_call_component.register
def _aie_quantizer(
    component: components.AieQuantizerT,
    backend: MediaTekBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Runs the quantizer component with the backend's `quantize_recipe`."""
  recipe = backend.quantize_recipe
  return component(input_model, output_model, quantization_recipe=recipe)
187
+
188
+
189
@_call_component.register
def _mlir_transforms(
    component: components.MlirTransformsT,
    unused_backend: MediaTekBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Runs the MLIR transforms component; the backend is not consulted."""
  # NOTE(review): the third positional argument is an empty list, presumably
  # "no extra transforms" — confirm against components.MlirTransformsT.
  transform_args = []
  return component(input_model, output_model, transform_args)
@@ -0,0 +1,94 @@
1
+ """Compilation target for MediaTek SOCs."""
2
+
3
+ import dataclasses
4
+ import sys
5
+ from typing import Any
6
+
7
+ from ai_edge_litert.aot.core import types
8
+
9
+ # pylint: disable=g-importing-member
10
+ # pylint: disable=g-import-not-at-top
11
+ # pylint: disable=g-bad-import-order
12
+ if sys.version_info >= (3, 11):
13
+ from enum import StrEnum # pylint: disable=g-importing-member
14
+ else:
15
+ from backports.strenum import StrEnum # pylint: disable=g-importing-member
16
+ # pylint: enable=g-bad-import-order
17
+ # pylint: enable=g-import-not-at-top
18
+ # pylint: enable=g-importing-member
19
+
20
+
21
# Backend id shared by `Target.backend_id()` and the MediaTek backend's `id()`.
_MEDIATEK_BACKEND_ID: str = "mediatek"
22
+
23
+
24
class SocModel(StrEnum):
  """MediaTek SOC model.

  `MediaTekBackend.specialize` expands ``ALL`` into every other member, in
  declaration order. Member values are the strings stored in backend configs
  (``"soc_model"``) and emitted by `Target.__repr__` / `Target.flatten`.
  """

  # Wildcard: "every concrete SOC model below".
  ALL = "ALL"

  MT6853 = "MT6853"
  MT6877 = "MT6877"
  MT6878 = "MT6878"
  MT6879 = "MT6879"
  MT6886 = "MT6886"
  MT6893 = "MT6893"
  MT6895 = "MT6895"
  MT6897 = "MT6897"
  MT6983 = "MT6983"
  MT6985 = "MT6985"
  MT6989 = "MT6989"
  MT6991 = "MT6991"
  MT8171 = "MT8171"
  MT8188 = "MT8188"
  MT8189 = "MT8189"
44
+
45
+
46
class SocManufacturer(StrEnum):
  """MediaTek SOC manufacturer."""

  # Sole member; its value is the prefix of `Target.__repr__`.
  MEDIATEK = "MediaTek"
50
+
51
+
52
class AndroidOsVersion(StrEnum):
  """Android OS version.

  `MediaTekBackend.specialize` expands ``ALL`` into every other member.
  Values are stored in backend configs (``"android_os_version"``).
  """

  # Wildcard: "every concrete OS version below".
  ALL = "ALL"

  ANDROID_15 = "ANDROID_15"
58
+
59
+
60
@dataclasses.dataclass
class Target(types.Target):
  """Compilation target for MediaTek SOCs.

  Equality and hashing use the SOC manufacturer, SOC model and Android OS
  version. `repr` renders the target as
  ``<manufacturer>_<model>_<os_version>`` using the enum values.
  """

  soc_model: SocModel
  soc_manufacturer: SocManufacturer = SocManufacturer.MEDIATEK
  android_os_version: AndroidOsVersion = AndroidOsVersion.ANDROID_15

  @classmethod
  def backend_id(cls) -> str:
    return _MEDIATEK_BACKEND_ID

  def __hash__(self) -> int:
    # Must stay consistent with `__eq__`: hash exactly the compared fields.
    return hash(
        (self.soc_manufacturer, self.soc_model, self.android_os_version)
    )

  def __eq__(self, other: object) -> bool:
    # Return NotImplemented for unrelated operands (e.g. `target == None`) so
    # Python falls back to its default comparison instead of this method
    # raising AttributeError on the missing fields.
    if not isinstance(other, Target):
      return NotImplemented
    return (
        self.soc_manufacturer == other.soc_manufacturer
        and self.soc_model == other.soc_model
        and self.android_os_version == other.android_os_version
    )

  def __repr__(self) -> str:
    return (
        f"{self.soc_manufacturer.value}_{self.soc_model.value}_"
        f"{self.android_os_version.value}"
    )

  def flatten(self) -> dict[str, Any]:
    """Returns the base class's flattened dict extended with target fields.

    Returns:
      The dict from `types.Target.flatten()` updated with the
      ``"soc_manufacturer"``, ``"soc_model"`` and ``"android_os_version"``
      enum values.
    """
    flattened_target = super().flatten()
    flattened_target.update({
        "soc_manufacturer": self.soc_manufacturer.value,
        "soc_model": self.soc_model.value,
        "android_os_version": self.android_os_version.value,
    })
    return flattened_target
File without changes
@@ -0,0 +1,161 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Backend implementation for the example compiler plugin.."""
17
+
18
+ import copy
19
+ import functools
20
+ import os
21
+ import pathlib
22
+ from typing import Iterable
23
+
24
+ from ai_edge_litert.aot.core import common
25
+ from ai_edge_litert.aot.core import components
26
+ from ai_edge_litert.aot.core import types
27
+ from ai_edge_litert.aot.vendors import import_vendor
28
+ from ai_edge_litert.aot.vendors.qualcomm import target as target_lib
29
+
30
# Qualcomm compiler plugin shared library. It is resolved with
# `common.get_resource` (presumably relative to the ai_edge_litert package).
COMPILER_PLUGIN_LIB_PATH = pathlib.Path(
    "vendors/qualcomm/compiler/libLiteRtCompilerPlugin_Qualcomm.so")
33
+
34
+
35
@import_vendor.register_backend
class QualcommBackend(types.Backend):
  """Backend implementation for the Qualcomm compiler plugin."""

  def __init__(self, config: types.Config):
    super().__init__(config)
    self._compilation_config = config.get("compilation_config", None)

  @property
  def soc_manufacturer(self) -> target_lib.SocManufacturer:
    return target_lib.SocManufacturer.QUALCOMM

  @property
  def soc_model(self) -> target_lib.SocModel:
    """SOC model from config ``"soc_model"`` (``"ALL"`` when absent)."""
    return target_lib.SocModel(self.config.get("soc_model", "ALL"))

  @property
  def target(self) -> target_lib.Target:
    return target_lib.Target(self.soc_model, self.soc_manufacturer)

  @property
  def target_id(self) -> str:
    return repr(self.target)

  def specialize(self) -> Iterable["QualcommBackend"]:
    """Expands an ``ALL`` SOC model into one backend per concrete SOC model.

    Yields ``self`` if the SOC model is already concrete. Otherwise yields a
    new backend for every other `SocModel` member, in declaration order.
    """
    if self.soc_model != target_lib.SocModel.ALL:
      yield self
    else:
      for soc_model in target_lib.SocModel:
        if soc_model != target_lib.SocModel.ALL:
          new_config = copy.deepcopy(self.config)
          new_config["soc_model"] = soc_model.value
          yield self.create(new_config)

  @classmethod
  def id(cls) -> str:
    return target_lib._QUALCOMM_BACKEND_ID  # pylint: disable=protected-access

  @classmethod
  def create(cls, config: types.Config) -> "QualcommBackend":
    """Creates a backend from `config`.

    Raises:
      ValueError: If ``config["backend_id"]`` is not this backend's id.
    """
    if config.get("backend_id", "") != cls.id():
      raise ValueError("Invalid backend id")
    return cls(config)

  @property
  def quantize_recipe(self) -> str | None:
    """Config ``"quantize_recipe"`` entry, or None when absent."""
    return self.config.get("quantize_recipe", None)

  def call_component(
      self,
      input_model: types.Model,
      output_model: types.Model,
      component: types.Component,
  ):
    """Dispatches `component` to the handler registered for its type."""
    return _call_component(component, self, input_model, output_model)
90
+
91
+
92
@functools.singledispatch
def _call_component(
    component: types.Component,
    backend: QualcommBackend,
    unused_input_model: types.Model,
    unused_output_model: types.Model,
):
  """Default handler for components the Qualcomm backend does not support.

  Supported component types get their own handler via
  ``@_call_component.register``. Any other component type lands here.

  Raises:
    NotImplementedError: Always.
  """
  backend_name = backend.id()
  component_name = component.component_name
  raise NotImplementedError(
      f"{backend_name} backend does not support {component_name} component."
  )
103
+
104
+
105
@_call_component.register
def _apply_plugin(
    component: components.ApplyPluginT,
    backend: QualcommBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Calls the apply plugin component for a Qualcomm target.

  If ``common.get_resource`` locates the compiler plugin library (the
  ai_edge_litert wheel case), two keyword arguments are passed:
  ``libs`` (the library's directory) and ``sdk_libs_path`` (from
  ``ai_edge_litert_sdk_qualcomm`` when importable, else None). A
  FileNotFoundError while resolving these paths drops both keyword
  arguments, so the component uses its default library path.
  """
  try:
    plugin_dir = os.path.dirname(common.get_resource(COMPILER_PLUGIN_LIB_PATH))
    try:
      # pytype: disable=import-error
      import ai_edge_litert_sdk_qualcomm  # pylint: disable=g-import-not-at-top
      # pytype: enable=import-error

      sdk_dir = str(ai_edge_litert_sdk_qualcomm.path_to_sdk_libs())
    except ImportError:
      # The SDK package is optional.
      sdk_dir = None
    plugin_kwargs = dict(libs=plugin_dir, sdk_libs_path=sdk_dir)
  except FileNotFoundError:
    plugin_kwargs = {}
  return component(
      input_model,
      output_model,
      backend.soc_manufacturer,
      backend.soc_model,
      **plugin_kwargs,
  )
138
+
139
+
140
@_call_component.register
def _aie_quantizer(
    component: components.AieQuantizerT,
    backend: QualcommBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Runs the quantizer component with the backend's `quantize_recipe`."""
  recipe = backend.quantize_recipe
  return component(input_model, output_model, quantization_recipe=recipe)
152
+
153
+
154
@_call_component.register
def _mlir_transforms(
    component: components.MlirTransformsT,
    unused_backend: QualcommBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Runs the MLIR transforms component; the backend is not consulted."""
  # NOTE(review): the third positional argument is an empty list, presumably
  # "no extra transforms" — confirm against components.MlirTransformsT.
  transform_args = []
  return component(input_model, output_model, transform_args)