ai-edge-litert-nightly 2.2.0.dev20260102__cp312-cp312-manylinux_2_27_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_litert/__init__.py +1 -0
- ai_edge_litert/_pywrap_analyzer_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_compiled_model_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_interpreter_wrapper.so +0 -0
- ai_edge_litert/_pywrap_litert_tensor_buffer_wrapper.so +0 -0
- ai_edge_litert/_pywrap_modify_model_interface.so +0 -0
- ai_edge_litert/_pywrap_string_util.so +0 -0
- ai_edge_litert/_pywrap_tensorflow_lite_calibration_wrapper.so +0 -0
- ai_edge_litert/_pywrap_tensorflow_lite_metrics_wrapper.so +0 -0
- ai_edge_litert/any_pb2.py +37 -0
- ai_edge_litert/aot/__init__.py +0 -0
- ai_edge_litert/aot/ai_pack/__init__.py +0 -0
- ai_edge_litert/aot/ai_pack/export_lib.py +300 -0
- ai_edge_litert/aot/aot_compile.py +153 -0
- ai_edge_litert/aot/core/__init__.py +0 -0
- ai_edge_litert/aot/core/apply_plugin.py +148 -0
- ai_edge_litert/aot/core/common.py +97 -0
- ai_edge_litert/aot/core/components.py +93 -0
- ai_edge_litert/aot/core/mlir_transforms.py +36 -0
- ai_edge_litert/aot/core/tflxx_util.py +30 -0
- ai_edge_litert/aot/core/types.py +374 -0
- ai_edge_litert/aot/prepare_for_npu.py +152 -0
- ai_edge_litert/aot/vendors/__init__.py +22 -0
- ai_edge_litert/aot/vendors/example/__init__.py +0 -0
- ai_edge_litert/aot/vendors/example/example_backend.py +157 -0
- ai_edge_litert/aot/vendors/fallback_backend.py +128 -0
- ai_edge_litert/aot/vendors/google_tensor/__init__.py +0 -0
- ai_edge_litert/aot/vendors/google_tensor/google_tensor_backend.py +168 -0
- ai_edge_litert/aot/vendors/google_tensor/target.py +84 -0
- ai_edge_litert/aot/vendors/import_vendor.py +132 -0
- ai_edge_litert/aot/vendors/mediatek/__init__.py +0 -0
- ai_edge_litert/aot/vendors/mediatek/mediatek_backend.py +196 -0
- ai_edge_litert/aot/vendors/mediatek/target.py +94 -0
- ai_edge_litert/aot/vendors/qualcomm/__init__.py +0 -0
- ai_edge_litert/aot/vendors/qualcomm/qualcomm_backend.py +161 -0
- ai_edge_litert/aot/vendors/qualcomm/target.py +75 -0
- ai_edge_litert/api_pb2.py +43 -0
- ai_edge_litert/compiled_model.py +250 -0
- ai_edge_litert/descriptor_pb2.py +3361 -0
- ai_edge_litert/duration_pb2.py +37 -0
- ai_edge_litert/empty_pb2.py +37 -0
- ai_edge_litert/field_mask_pb2.py +37 -0
- ai_edge_litert/format_converter_wrapper_pybind11.so +0 -0
- ai_edge_litert/hardware_accelerator.py +22 -0
- ai_edge_litert/internal/__init__.py +0 -0
- ai_edge_litert/internal/litertlm_builder.py +584 -0
- ai_edge_litert/internal/litertlm_core.py +58 -0
- ai_edge_litert/internal/litertlm_header_schema_py_generated.py +1596 -0
- ai_edge_litert/internal/llm_metadata_pb2.py +45 -0
- ai_edge_litert/internal/llm_model_type_pb2.py +51 -0
- ai_edge_litert/internal/sampler_params_pb2.py +39 -0
- ai_edge_litert/internal/token_pb2.py +38 -0
- ai_edge_litert/interpreter.py +1039 -0
- ai_edge_litert/libLiteRt.so +0 -0
- ai_edge_litert/libpywrap_litert_common.so +0 -0
- ai_edge_litert/metrics_interface.py +48 -0
- ai_edge_litert/metrics_portable.py +70 -0
- ai_edge_litert/model_runtime_info_pb2.py +66 -0
- ai_edge_litert/plugin_pb2.py +46 -0
- ai_edge_litert/profiling_info_pb2.py +47 -0
- ai_edge_litert/pywrap_genai_ops.so +0 -0
- ai_edge_litert/schema_py_generated.py +19640 -0
- ai_edge_litert/source_context_pb2.py +37 -0
- ai_edge_litert/struct_pb2.py +47 -0
- ai_edge_litert/tensor_buffer.py +167 -0
- ai_edge_litert/timestamp_pb2.py +37 -0
- ai_edge_litert/tools/__init__.py +0 -0
- ai_edge_litert/tools/apply_plugin_main +0 -0
- ai_edge_litert/tools/flatbuffer_utils.py +534 -0
- ai_edge_litert/type_pb2.py +53 -0
- ai_edge_litert/vendors/google_tensor/compiler/libLiteRtCompilerPlugin_google_tensor.so +0 -0
- ai_edge_litert/vendors/mediatek/compiler/libLiteRtCompilerPlugin_MediaTek.so +0 -0
- ai_edge_litert/vendors/qualcomm/compiler/libLiteRtCompilerPlugin_Qualcomm.so +0 -0
- ai_edge_litert/wrappers_pb2.py +53 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/METADATA +52 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/RECORD +78 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/WHEEL +5 -0
- ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Compilation target for Google Tensor SOCs."""
|
|
16
|
+
|
|
17
|
+
import dataclasses
|
|
18
|
+
import sys
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from ai_edge_litert.aot.core import types
|
|
22
|
+
|
|
23
|
+
# pylint: disable=g-importing-member
|
|
24
|
+
# pylint: disable=g-import-not-at-top
|
|
25
|
+
# pylint: disable=g-bad-import-order
|
|
26
|
+
if sys.version_info >= (3, 11):
|
|
27
|
+
from enum import StrEnum # pylint: disable=g-importing-member
|
|
28
|
+
else:
|
|
29
|
+
from backports.strenum import StrEnum # pylint: disable=g-importing-member
|
|
30
|
+
# pylint: enable=g-bad-import-order
|
|
31
|
+
# pylint: enable=g-import-not-at-top
|
|
32
|
+
# pylint: enable=g-importing-member
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Backend ID returned by `Target.backend_id()` for Google Tensor targets.
# NOTE(review): upper-case, unlike the MediaTek ID ("mediatek"); confirm it
# matches the id() of the Google Tensor backend class.
_GOOGLE_TENSOR_BACKEND_ID = "GOOGLE"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class SocModel(StrEnum):
|
|
39
|
+
"""Google Tensor SOC model."""
|
|
40
|
+
|
|
41
|
+
ALL = "ALL"
|
|
42
|
+
|
|
43
|
+
TENSOR_G3 = "Tensor_G3"
|
|
44
|
+
TENSOR_G4 = "Tensor_G4"
|
|
45
|
+
TENSOR_G5 = "Tensor_G5"
|
|
46
|
+
TENSOR_G6 = "Tensor_G6"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class SocManufacturer(StrEnum):
|
|
50
|
+
"""Google Tensor SOC manufacturer."""
|
|
51
|
+
|
|
52
|
+
GOOGLE = "Google"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@dataclasses.dataclass
class Target(types.Target):
  """Compilation target for Google Tensor SOCs.

  Attributes:
    soc_model: The Google Tensor SOC model to compile for.
    soc_manufacturer: The SOC manufacturer; defaults to Google.
  """

  soc_model: SocModel
  soc_manufacturer: SocManufacturer = SocManufacturer.GOOGLE

  @classmethod
  def backend_id(cls) -> str:
    """Returns the ID of the backend that compiles for this target."""
    return _GOOGLE_TENSOR_BACKEND_ID

  def __hash__(self) -> int:
    # Hashes exactly the fields compared in __eq__ so the two stay consistent.
    return hash((self.soc_manufacturer, self.soc_model))

  def __eq__(self, other: object) -> bool:
    # For operands that are not Google Tensor targets (None, strings, other
    # vendors' targets) defer to Python's default comparison handling instead
    # of raising AttributeError on the missing fields.
    if not isinstance(other, Target):
      return NotImplemented
    return (
        self.soc_manufacturer == other.soc_manufacturer
        and self.soc_model == other.soc_model
    )

  def __repr__(self) -> str:
    # Used as the target ID, e.g. "Google_Tensor_G3".
    return f"{self.soc_manufacturer.value}_{self.soc_model.value}"

  def flatten(self) -> dict[str, Any]:
    """Returns the base class's flattened dict extended with SOC fields."""
    flattened_target = super().flatten()
    flattened_target.update({
        "soc_manufacturer": self.soc_manufacturer.value,
        "soc_model": self.soc_model.value,
    })
    return flattened_target
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Utility for dynamically importing vendor backends based on ID."""
|
|
17
|
+
|
|
18
|
+
import copy
|
|
19
|
+
import dataclasses
|
|
20
|
+
from typing import Any, Iterable
|
|
21
|
+
|
|
22
|
+
from ai_edge_litert.aot.core import types
|
|
23
|
+
from ai_edge_litert.aot.vendors import fallback_backend
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclasses.dataclass
class VendorRegistry:
  """One entry of the vendor registry.

  Attributes:
    backend_class: The backend class registered under its `id()`.
  """

  backend_class: types.BackendT
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Maps a backend class's `id()` to its registry entry. Filled in by
# `register_backend`; read by `import_vendor` and
# `AllRegisteredBackend.specialize`.
_VENDOR_REGISTRY: dict[str, VendorRegistry] = {}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def register_backend(
    backend_class: types.BackendT,
):
  """Records `backend_class` in the vendor registry under its `id()`.

  Returns the class unchanged, so this can be used as a class decorator.
  Registering a second class with the same ID replaces the earlier entry.

  Args:
    backend_class: The backend class to register.

  Returns:
    `backend_class` itself.
  """
  _VENDOR_REGISTRY[backend_class.id()] = VendorRegistry(backend_class)
  return backend_class
|
|
43
|
+
|
|
44
|
+
# Registered at import time so the fallback backend is always available from
# the registry, independent of which vendor modules get imported.
register_backend(fallback_backend.FallbackBackend)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def import_vendor(backend_id: str) -> types.BackendT:
  """Looks up a registered vendor backend class by its ID.

  Args:
    backend_id: ID under which the backend class was registered.

  Returns:
    The backend class registered for `backend_id`.

  Raises:
    ValueError: If no backend is registered under `backend_id`.
  """
  entry = _VENDOR_REGISTRY.get(backend_id)
  if entry is not None:
    return entry.backend_class
  raise ValueError(f'Unsupported backend id: {backend_id}')
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class AllRegisteredTarget(types.Target):
  """Virtual compilation target that stands for every registered backend.

  Only `backend_id` and `flatten` are usable; hashing, equality and repr all
  raise NotImplementedError.
  """

  @classmethod
  def backend_id(cls) -> str:
    return 'all'

  def flatten(self) -> dict[str, Any]:
    flattened = {'backend_id': self.backend_id()}
    return flattened

  def __hash__(self) -> int:
    raise NotImplementedError()

  def __eq__(self, other) -> bool:
    raise NotImplementedError()

  def __repr__(self) -> str:
    raise NotImplementedError()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@register_backend
class AllRegisteredBackend(types.Backend):
  """A virtual backend that represents all registered backends.

  It cannot run components itself; `specialize` expands it into the
  specializations of every other registered backend.
  """

  # NOTE: Only initialize through "create".
  def __init__(self, config: types.Config):
    self._config = config

  @classmethod
  def create(cls, config: types.Config) -> 'AllRegisteredBackend':
    # Use `cls` (not the hard-coded class) so subclasses construct themselves.
    return cls(config)

  @classmethod
  def id(cls) -> str:
    return 'all'

  @property
  def target(self) -> AllRegisteredTarget:
    return AllRegisteredTarget()

  @property
  def target_id(self) -> str:
    return ''

  @property
  def config(self) -> types.Config:
    return self._config

  def call_component(
      self,
      input_model: types.Model,
      output_model: types.Model,
      component: types.Component,
  ):
    """Always raises: the virtual backend cannot run any component."""
    del input_model, output_model, component
    raise NotImplementedError(
        'AllRegisteredBackend does not support any component.'
    )

  def specialize(self) -> Iterable[types.Backend]:
    """Yields the specializations of every other registered backend.

    Each backend is created from a deep copy of this backend's config with
    `backend_id` set to that backend's ID, so this config is never mutated.
    """
    own_id = self.id()
    # Iterate over a snapshot: this is a generator, and a backend registered
    # while the caller consumes it would otherwise raise "dictionary changed
    # size during iteration".
    for backend_id, vendor_module in list(_VENDOR_REGISTRY.items()):
      if backend_id == own_id:
        continue
      config = copy.deepcopy(self.config)
      config['backend_id'] = backend_id
      backend = vendor_module.backend_class.create(config)
      yield from backend.specialize()
|
|
File without changes
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Backend implementation for the example compiler plugin.."""
|
|
17
|
+
|
|
18
|
+
import copy
|
|
19
|
+
import functools
|
|
20
|
+
import itertools
|
|
21
|
+
import os
|
|
22
|
+
import pathlib
|
|
23
|
+
from typing import Iterable
|
|
24
|
+
|
|
25
|
+
from ai_edge_litert.aot.core import common
|
|
26
|
+
from ai_edge_litert.aot.core import components
|
|
27
|
+
from ai_edge_litert.aot.core import types
|
|
28
|
+
from ai_edge_litert.aot.vendors import import_vendor
|
|
29
|
+
from ai_edge_litert.aot.vendors.mediatek import target as target_lib
|
|
30
|
+
|
|
31
|
+
# MediaTek compiler plugin library, relative to the ai_edge_litert package
# root; resolved at run time through `common.get_resource`.
COMPILER_PLUGIN_LIB_PATH = (
    pathlib.Path("vendors")
    / "mediatek"
    / "compiler"
    / "libLiteRtCompilerPlugin_MediaTek.so"
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@import_vendor.register_backend
class MediaTekBackend(types.Backend):
  """Backend implementation for the MediaTek compiler plugin."""

  def __init__(self, config: types.Config):
    super().__init__(config)
    self._compilation_config = config.get("compilation_config", None)

  @property
  def soc_manufacturer(self) -> target_lib.SocManufacturer:
    return target_lib.SocManufacturer.MEDIATEK

  @property
  def soc_model(self) -> target_lib.SocModel:
    # Defaults to the "ALL" wildcard; an unknown model raises ValueError.
    return target_lib.SocModel(self.config.get("soc_model", "ALL"))

  @property
  def android_os_version(self) -> target_lib.AndroidOsVersion:
    # Defaults to the "ALL" wildcard; an unknown version raises ValueError.
    return target_lib.AndroidOsVersion(
        self.config.get("android_os_version", "ALL")
    )

  @property
  def target(self) -> target_lib.Target:
    # Keyword arguments keep this correct if the dataclass fields are
    # reordered.
    return target_lib.Target(
        soc_model=self.soc_model,
        soc_manufacturer=self.soc_manufacturer,
        android_os_version=self.android_os_version,
    )

  @property
  def target_id(self) -> str:
    return repr(self.target)

  def specialize(self) -> Iterable["MediaTekBackend"]:
    """Expands "ALL" wildcards into one backend per concrete target.

    Yields this backend unchanged when both the SOC model and the Android OS
    version are concrete. Otherwise yields a new backend, built from a deep
    copy of this config, for every combination of the selected SOC models
    and OS versions, where "ALL" selects every non-"ALL" member.
    """
    if (
        self.soc_model != target_lib.SocModel.ALL
        and self.android_os_version != target_lib.AndroidOsVersion.ALL
    ):
      yield self
      return

    if self.soc_model == target_lib.SocModel.ALL:
      soc_models = [
          x for x in target_lib.SocModel if x != target_lib.SocModel.ALL
      ]
    else:
      soc_models = [self.soc_model]

    if self.android_os_version == target_lib.AndroidOsVersion.ALL:
      android_os_versions = [
          x
          for x in target_lib.AndroidOsVersion
          if x != target_lib.AndroidOsVersion.ALL
      ]
    else:
      android_os_versions = [self.android_os_version]

    for soc_model, android_os_version in itertools.product(
        soc_models, android_os_versions
    ):
      new_config = copy.deepcopy(self.config)
      new_config["soc_model"] = soc_model.value
      new_config["android_os_version"] = android_os_version.value
      yield self.create(new_config)

  @classmethod
  def id(cls) -> str:
    return target_lib._MEDIATEK_BACKEND_ID  # pylint: disable=protected-access

  @classmethod
  def create(cls, config: types.Config) -> "MediaTekBackend":
    """Creates a backend from `config`, whose backend_id must match `id()`."""
    if config.get("backend_id", "") != cls.id():
      raise ValueError("Invalid backend id")
    return cls(config)

  @property
  def quantize_recipe(self) -> str | None:
    return self.config.get("quantize_recipe", None)

  def call_component(
      self,
      input_model: types.Model,
      output_model: types.Model,
      component: types.Component,
  ):
    # Dispatches on the component's type via the `_call_component`
    # singledispatch function.
    return _call_component(component, self, input_model, output_model)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@functools.singledispatch
def _call_component(
    component: types.Component,
    backend: MediaTekBackend,
    unused_input_model: types.Model,
    unused_output_model: types.Model,
):
  """Fallback for component types with no MediaTek-specific overload.

  Raises:
    NotImplementedError: Always, naming the backend and the component.
  """
  message = (
      f"{backend.id()} backend does not support"
      f" {component.component_name} component."
  )
  raise NotImplementedError(message)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# TODO(toribiosteven): Translate SOC | OS version to the corresponding
# MediaTek SDK version and pass to the plugin.
@_call_component.register
def _apply_plugin(
    component: components.ApplyPluginT,
    backend: MediaTekBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Calls the apply plugin component.

  If the compiler plugin library is found among the package resources:
  - its directory is passed as `libs`;
  - `sdk_libs_path` is the MediaTek SDK library path when the optional
    `ai_edge_litert_sdk_mediatek` package is importable, and None otherwise.

  If the plugin library is not found, no library paths are passed.

  Args:
    component: The apply-plugin component to invoke.
    backend: Supplies the SOC manufacturer and the (lower-cased) SOC model.
    input_model: Passed through to `component`.
    output_model: Passed through to `component`.

  Returns:
    Whatever `component` returns.
  """
  try:
    # If the plugin is not built from source (i.e. using ai_edge_litert wheel),
    # we find the plugin library directory from the package path.
    # Otherwise we use the default library path.
    plugin_path = common.get_resource(COMPILER_PLUGIN_LIB_PATH)
  except FileNotFoundError:
    extra_kwargs = {}
  else:
    # Only `get_resource` is guarded above. Otherwise a FileNotFoundError from
    # the SDK lookup would be mistaken for a missing plugin and `libs` would
    # be silently dropped.
    try:
      # pytype: disable=import-error
      import ai_edge_litert_sdk_mediatek  # pylint: disable=g-import-not-at-top
      # pytype: enable=import-error
    except ImportError:
      sdk_libs_path = None
    else:
      # TODO(weiyiw): Translate SOC | OS version to the corresponding
      # MediaTek SDK version and pass to the plugin.
      sdk_version = "v8"
      sdk_libs_path = str(
          ai_edge_litert_sdk_mediatek.path_to_sdk_libs(sdk_version)
      )
    extra_kwargs = {
        "libs": os.path.dirname(plugin_path),
        "sdk_libs_path": sdk_libs_path,
    }
  return component(
      input_model,
      output_model,
      backend.soc_manufacturer,
      backend.soc_model.lower(),
      **extra_kwargs,
  )
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@_call_component.register
def _aie_quantizer(
    component: components.AieQuantizerT,
    backend: MediaTekBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Runs the AIE quantizer with the backend's `quantize_recipe` setting."""
  recipe = backend.quantize_recipe  # None when the config sets no recipe.
  return component(input_model, output_model, quantization_recipe=recipe)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@_call_component.register
def _mlir_transforms(
    component: components.MlirTransformsT,
    unused_backend: MediaTekBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Invokes the MLIR transforms component with an empty third argument."""
  # NOTE(review): presumably the list of passes to run; confirm against
  # mlir_transforms.py.
  passes = []
  return component(input_model, output_model, passes)
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""Compilation target for MediaTek SOCs."""
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
import sys
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from ai_edge_litert.aot.core import types
|
|
8
|
+
|
|
9
|
+
# pylint: disable=g-importing-member
|
|
10
|
+
# pylint: disable=g-import-not-at-top
|
|
11
|
+
# pylint: disable=g-bad-import-order
|
|
12
|
+
if sys.version_info >= (3, 11):
|
|
13
|
+
from enum import StrEnum # pylint: disable=g-importing-member
|
|
14
|
+
else:
|
|
15
|
+
from backports.strenum import StrEnum # pylint: disable=g-importing-member
|
|
16
|
+
# pylint: enable=g-bad-import-order
|
|
17
|
+
# pylint: enable=g-import-not-at-top
|
|
18
|
+
# pylint: enable=g-importing-member
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Backend ID returned by `Target.backend_id()`; the MediaTek backend class
# reuses it as its registry `id()`.
_MEDIATEK_BACKEND_ID = "mediatek"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SocModel(StrEnum):
|
|
25
|
+
"""MediaTek SOC model."""
|
|
26
|
+
|
|
27
|
+
ALL = "ALL"
|
|
28
|
+
|
|
29
|
+
MT6853 = "MT6853"
|
|
30
|
+
MT6877 = "MT6877"
|
|
31
|
+
MT6878 = "MT6878"
|
|
32
|
+
MT6879 = "MT6879"
|
|
33
|
+
MT6886 = "MT6886"
|
|
34
|
+
MT6893 = "MT6893"
|
|
35
|
+
MT6895 = "MT6895"
|
|
36
|
+
MT6897 = "MT6897"
|
|
37
|
+
MT6983 = "MT6983"
|
|
38
|
+
MT6985 = "MT6985"
|
|
39
|
+
MT6989 = "MT6989"
|
|
40
|
+
MT6991 = "MT6991"
|
|
41
|
+
MT8171 = "MT8171"
|
|
42
|
+
MT8188 = "MT8188"
|
|
43
|
+
MT8189 = "MT8189"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class SocManufacturer(StrEnum):
|
|
47
|
+
"""MediaTek SOC manufacturer."""
|
|
48
|
+
|
|
49
|
+
MEDIATEK = "MediaTek"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class AndroidOsVersion(StrEnum):
|
|
53
|
+
"""Android OS version."""
|
|
54
|
+
|
|
55
|
+
ALL = "ALL"
|
|
56
|
+
|
|
57
|
+
ANDROID_15 = "ANDROID_15"
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclasses.dataclass
class Target(types.Target):
  """Compilation target for MediaTek SOCs.

  Attributes:
    soc_model: The MediaTek SOC model to compile for.
    soc_manufacturer: The SOC manufacturer; defaults to MediaTek.
    android_os_version: The Android OS version; defaults to Android 15.
  """

  soc_model: SocModel
  soc_manufacturer: SocManufacturer = SocManufacturer.MEDIATEK
  android_os_version: AndroidOsVersion = AndroidOsVersion.ANDROID_15

  @classmethod
  def backend_id(cls) -> str:
    """Returns the ID of the backend that compiles for this target."""
    return _MEDIATEK_BACKEND_ID

  def __hash__(self) -> int:
    # Hashes exactly the fields compared in __eq__ so the two stay consistent.
    return hash(
        (self.soc_manufacturer, self.soc_model, self.android_os_version)
    )

  def __eq__(self, other: object) -> bool:
    # For operands that are not MediaTek targets (None, strings, other
    # vendors' targets) defer to Python's default comparison handling instead
    # of raising AttributeError on the missing fields.
    if not isinstance(other, Target):
      return NotImplemented
    return (
        self.soc_manufacturer == other.soc_manufacturer
        and self.soc_model == other.soc_model
        and self.android_os_version == other.android_os_version
    )

  def __repr__(self) -> str:
    # Used as the target ID, e.g. "MediaTek_MT6989_ANDROID_15".
    return (
        f"{self.soc_manufacturer.value}_{self.soc_model.value}"
        f"_{self.android_os_version.value}"
    )

  def flatten(self) -> dict[str, Any]:
    """Returns the base class's flattened dict extended with target fields."""
    flattened_target = super().flatten()
    flattened_target.update({
        "soc_manufacturer": self.soc_manufacturer.value,
        "soc_model": self.soc_model.value,
        "android_os_version": self.android_os_version.value,
    })
    return flattened_target
|
|
File without changes
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Backend implementation for the example compiler plugin.."""
|
|
17
|
+
|
|
18
|
+
import copy
|
|
19
|
+
import functools
|
|
20
|
+
import os
|
|
21
|
+
import pathlib
|
|
22
|
+
from typing import Iterable
|
|
23
|
+
|
|
24
|
+
from ai_edge_litert.aot.core import common
|
|
25
|
+
from ai_edge_litert.aot.core import components
|
|
26
|
+
from ai_edge_litert.aot.core import types
|
|
27
|
+
from ai_edge_litert.aot.vendors import import_vendor
|
|
28
|
+
from ai_edge_litert.aot.vendors.qualcomm import target as target_lib
|
|
29
|
+
|
|
30
|
+
# Qualcomm compiler plugin library, relative to the ai_edge_litert package
# root; resolved at run time through `common.get_resource`.
COMPILER_PLUGIN_LIB_PATH = (
    pathlib.Path("vendors")
    / "qualcomm"
    / "compiler"
    / "libLiteRtCompilerPlugin_Qualcomm.so"
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@import_vendor.register_backend
class QualcommBackend(types.Backend):
  """Backend implementation for the Qualcomm compiler plugin."""

  def __init__(self, config: types.Config):
    super().__init__(config)
    self._compilation_config = config.get("compilation_config", None)

  @property
  def soc_manufacturer(self) -> target_lib.SocManufacturer:
    return target_lib.SocManufacturer.QUALCOMM

  @property
  def soc_model(self) -> target_lib.SocModel:
    # Defaults to the "ALL" wildcard; an unknown model raises ValueError.
    return target_lib.SocModel(self.config.get("soc_model", "ALL"))

  @property
  def target(self) -> target_lib.Target:
    return target_lib.Target(self.soc_model, self.soc_manufacturer)

  @property
  def target_id(self) -> str:
    return repr(self.target)

  def specialize(self) -> Iterable["QualcommBackend"]:
    """Expands an "ALL" SOC model into one backend per concrete model.

    Yields this backend unchanged when the SOC model is concrete. Otherwise
    yields a new backend, built from a deep copy of this config, for every
    non-"ALL" SOC model.
    """
    if self.soc_model != target_lib.SocModel.ALL:
      yield self
      return
    for soc_model in target_lib.SocModel:
      if soc_model == target_lib.SocModel.ALL:
        continue
      new_config = copy.deepcopy(self.config)
      new_config["soc_model"] = soc_model.value
      yield self.create(new_config)

  @classmethod
  def id(cls) -> str:
    return target_lib._QUALCOMM_BACKEND_ID  # pylint: disable=protected-access

  @classmethod
  def create(cls, config: types.Config) -> "QualcommBackend":
    """Creates a backend from `config`, whose backend_id must match `id()`."""
    if config.get("backend_id", "") != cls.id():
      raise ValueError("Invalid backend id")
    return cls(config)

  @property
  def quantize_recipe(self) -> str | None:
    return self.config.get("quantize_recipe", None)

  def call_component(
      self,
      input_model: types.Model,
      output_model: types.Model,
      component: types.Component,
  ):
    # Dispatches on the component's type via the `_call_component`
    # singledispatch function.
    return _call_component(component, self, input_model, output_model)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@functools.singledispatch
def _call_component(
    component: types.Component,
    backend: QualcommBackend,
    unused_input_model: types.Model,
    unused_output_model: types.Model,
):
  """Fallback for component types with no Qualcomm-specific overload.

  Raises:
    NotImplementedError: Always, naming the backend and the component.
  """
  message = (
      f"{backend.id()} backend does not support"
      f" {component.component_name} component."
  )
  raise NotImplementedError(message)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@_call_component.register
def _apply_plugin(
    component: components.ApplyPluginT,
    backend: QualcommBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Calls the apply plugin component.

  If the compiler plugin library is found among the package resources:
  - its directory is passed as `libs`;
  - `sdk_libs_path` is the Qualcomm SDK library path when the optional
    `ai_edge_litert_sdk_qualcomm` package is importable, and None otherwise.

  If the plugin library is not found, no library paths are passed.

  Args:
    component: The apply-plugin component to invoke.
    backend: Supplies the SOC manufacturer and SOC model.
    input_model: Passed through to `component`.
    output_model: Passed through to `component`.

  Returns:
    Whatever `component` returns.
  """
  try:
    # If the plugin is not built from source (i.e. using ai_edge_litert wheel),
    # we find the plugin library directory from the package path.
    # Otherwise we use the default library path.
    plugin_path = common.get_resource(COMPILER_PLUGIN_LIB_PATH)
  except FileNotFoundError:
    extra_kwargs = {}
  else:
    # Only `get_resource` is guarded above. Otherwise a FileNotFoundError from
    # the SDK lookup would be mistaken for a missing plugin and `libs` would
    # be silently dropped.
    try:
      # pytype: disable=import-error
      import ai_edge_litert_sdk_qualcomm  # pylint: disable=g-import-not-at-top
      # pytype: enable=import-error
    except ImportError:
      sdk_libs_path = None
    else:
      sdk_libs_path = str(ai_edge_litert_sdk_qualcomm.path_to_sdk_libs())
    extra_kwargs = {
        "libs": os.path.dirname(plugin_path),
        "sdk_libs_path": sdk_libs_path,
    }
  return component(
      input_model,
      output_model,
      backend.soc_manufacturer,
      backend.soc_model,
      **extra_kwargs,
  )
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@_call_component.register
def _aie_quantizer(
    component: components.AieQuantizerT,
    backend: QualcommBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Runs the AIE quantizer with the backend's `quantize_recipe` setting."""
  recipe = backend.quantize_recipe  # None when the config sets no recipe.
  return component(input_model, output_model, quantization_recipe=recipe)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
@_call_component.register
def _mlir_transforms(
    component: components.MlirTransformsT,
    unused_backend: QualcommBackend,
    input_model: types.Model,
    output_model: types.Model,
):
  """Invokes the MLIR transforms component with an empty third argument."""
  # NOTE(review): presumably the list of passes to run; confirm against
  # mlir_transforms.py.
  passes = []
  return component(input_model, output_model, passes)
|