ai-edge-litert-nightly 2.2.0.dev20260102__cp312-cp312-manylinux_2_27_x86_64.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (78)
  1. ai_edge_litert/__init__.py +1 -0
  2. ai_edge_litert/_pywrap_analyzer_wrapper.so +0 -0
  3. ai_edge_litert/_pywrap_litert_compiled_model_wrapper.so +0 -0
  4. ai_edge_litert/_pywrap_litert_interpreter_wrapper.so +0 -0
  5. ai_edge_litert/_pywrap_litert_tensor_buffer_wrapper.so +0 -0
  6. ai_edge_litert/_pywrap_modify_model_interface.so +0 -0
  7. ai_edge_litert/_pywrap_string_util.so +0 -0
  8. ai_edge_litert/_pywrap_tensorflow_lite_calibration_wrapper.so +0 -0
  9. ai_edge_litert/_pywrap_tensorflow_lite_metrics_wrapper.so +0 -0
  10. ai_edge_litert/any_pb2.py +37 -0
  11. ai_edge_litert/aot/__init__.py +0 -0
  12. ai_edge_litert/aot/ai_pack/__init__.py +0 -0
  13. ai_edge_litert/aot/ai_pack/export_lib.py +300 -0
  14. ai_edge_litert/aot/aot_compile.py +153 -0
  15. ai_edge_litert/aot/core/__init__.py +0 -0
  16. ai_edge_litert/aot/core/apply_plugin.py +148 -0
  17. ai_edge_litert/aot/core/common.py +97 -0
  18. ai_edge_litert/aot/core/components.py +93 -0
  19. ai_edge_litert/aot/core/mlir_transforms.py +36 -0
  20. ai_edge_litert/aot/core/tflxx_util.py +30 -0
  21. ai_edge_litert/aot/core/types.py +374 -0
  22. ai_edge_litert/aot/prepare_for_npu.py +152 -0
  23. ai_edge_litert/aot/vendors/__init__.py +22 -0
  24. ai_edge_litert/aot/vendors/example/__init__.py +0 -0
  25. ai_edge_litert/aot/vendors/example/example_backend.py +157 -0
  26. ai_edge_litert/aot/vendors/fallback_backend.py +128 -0
  27. ai_edge_litert/aot/vendors/google_tensor/__init__.py +0 -0
  28. ai_edge_litert/aot/vendors/google_tensor/google_tensor_backend.py +168 -0
  29. ai_edge_litert/aot/vendors/google_tensor/target.py +84 -0
  30. ai_edge_litert/aot/vendors/import_vendor.py +132 -0
  31. ai_edge_litert/aot/vendors/mediatek/__init__.py +0 -0
  32. ai_edge_litert/aot/vendors/mediatek/mediatek_backend.py +196 -0
  33. ai_edge_litert/aot/vendors/mediatek/target.py +94 -0
  34. ai_edge_litert/aot/vendors/qualcomm/__init__.py +0 -0
  35. ai_edge_litert/aot/vendors/qualcomm/qualcomm_backend.py +161 -0
  36. ai_edge_litert/aot/vendors/qualcomm/target.py +75 -0
  37. ai_edge_litert/api_pb2.py +43 -0
  38. ai_edge_litert/compiled_model.py +250 -0
  39. ai_edge_litert/descriptor_pb2.py +3361 -0
  40. ai_edge_litert/duration_pb2.py +37 -0
  41. ai_edge_litert/empty_pb2.py +37 -0
  42. ai_edge_litert/field_mask_pb2.py +37 -0
  43. ai_edge_litert/format_converter_wrapper_pybind11.so +0 -0
  44. ai_edge_litert/hardware_accelerator.py +22 -0
  45. ai_edge_litert/internal/__init__.py +0 -0
  46. ai_edge_litert/internal/litertlm_builder.py +584 -0
  47. ai_edge_litert/internal/litertlm_core.py +58 -0
  48. ai_edge_litert/internal/litertlm_header_schema_py_generated.py +1596 -0
  49. ai_edge_litert/internal/llm_metadata_pb2.py +45 -0
  50. ai_edge_litert/internal/llm_model_type_pb2.py +51 -0
  51. ai_edge_litert/internal/sampler_params_pb2.py +39 -0
  52. ai_edge_litert/internal/token_pb2.py +38 -0
  53. ai_edge_litert/interpreter.py +1039 -0
  54. ai_edge_litert/libLiteRt.so +0 -0
  55. ai_edge_litert/libpywrap_litert_common.so +0 -0
  56. ai_edge_litert/metrics_interface.py +48 -0
  57. ai_edge_litert/metrics_portable.py +70 -0
  58. ai_edge_litert/model_runtime_info_pb2.py +66 -0
  59. ai_edge_litert/plugin_pb2.py +46 -0
  60. ai_edge_litert/profiling_info_pb2.py +47 -0
  61. ai_edge_litert/pywrap_genai_ops.so +0 -0
  62. ai_edge_litert/schema_py_generated.py +19640 -0
  63. ai_edge_litert/source_context_pb2.py +37 -0
  64. ai_edge_litert/struct_pb2.py +47 -0
  65. ai_edge_litert/tensor_buffer.py +167 -0
  66. ai_edge_litert/timestamp_pb2.py +37 -0
  67. ai_edge_litert/tools/__init__.py +0 -0
  68. ai_edge_litert/tools/apply_plugin_main +0 -0
  69. ai_edge_litert/tools/flatbuffer_utils.py +534 -0
  70. ai_edge_litert/type_pb2.py +53 -0
  71. ai_edge_litert/vendors/google_tensor/compiler/libLiteRtCompilerPlugin_google_tensor.so +0 -0
  72. ai_edge_litert/vendors/mediatek/compiler/libLiteRtCompilerPlugin_MediaTek.so +0 -0
  73. ai_edge_litert/vendors/qualcomm/compiler/libLiteRtCompilerPlugin_Qualcomm.so +0 -0
  74. ai_edge_litert/wrappers_pb2.py +53 -0
  75. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/METADATA +52 -0
  76. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/RECORD +78 -0
  77. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/WHEEL +5 -0
  78. ai_edge_litert_nightly-2.2.0.dev20260102.dist-info/top_level.txt +1 -0
ai_edge_litert/aot/prepare_for_npu.py
@@ -0,0 +1,152 @@
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Implementations for the main public API functionalities."""
+
+ import pathlib
+ from typing import cast
+
+ # pylint: disable=g-import-not-at-top
+ # pytype: disable=import-error
+ try:
+   from tqdm import auto as autotqdm
+ except ImportError:
+   from tqdm.tqdm import auto as autotqdm
+ # pytype: enable=import-error
+
+ from ai_edge_litert.aot.core import common
+ from ai_edge_litert.aot.core import components
+ from ai_edge_litert.aot.core import types
+ from ai_edge_litert.aot.vendors import import_vendor
+
+ # pylint: enable=g-import-not-at-top
+
+
+ def resolve_backend(config: types.Config) -> types.BackendT:
+   # Import the backend based on the ID.
+   backend_id = config.get("backend_id", None)
+   if backend_id is None:
+     raise ValueError("Backend ID is required.")
+   return import_vendor.import_vendor(backend_id)
+
+
+ def prepare_for_npu_multiple_configs(
+     flatbuffer: types.Model,
+     output_dir: pathlib.Path,
+     configs: list[tuple[types.BackendT, types.Config]],
+     plugin: components.ApplyPluginT,
+     transforms: components.MlirTransformsT | None = None,
+     quantizer: components.AieQuantizerT | None = None,
+     keep_going: bool = False,
+ ) -> types.CompilationResult:
+   """Prepares a TFLite model for NPU execution."""
+   backends = []
+   for backend_class, config in configs:
+     backend = backend_class.create(config)
+     backends += list(backend.specialize())
+
+   pipeline: list[types.Component] = [
+       c for c in [transforms, quantizer, plugin] if c is not None
+   ]
+   return compile_model(flatbuffer, output_dir, backends, pipeline, keep_going)
+
+
+ def prepare_for_npu(
+     flatbuffer: types.Model,
+     output_dir: pathlib.Path,
+     backend_class: types.BackendT,
+     config: types.Config,
+     plugin: components.ApplyPluginT,
+     transforms: components.MlirTransformsT | None = None,
+     quantizer: components.AieQuantizerT | None = None,
+     keep_going: bool = False,
+ ) -> types.CompilationResult:
+   """Prepares a TFLite model for NPU execution.
+
+   High level command that performs various backend specific pre-processing steps
+   and then applies an NPU compiler to the given model.
+
+   Args:
+     flatbuffer: Path to the input flatbuffer file.
+     output_dir: Directory to write the output flatbuffer file.
+     backend_class: The backend to prepare the model for.
+     config: The configuration for the backend.
+     plugin: The plugin to apply to the model.
+     transforms: The transforms to apply to the model.
+     quantizer: The quantizer to apply to the model.
+     keep_going: Whether to keep going if some backends fail.
+
+   Returns:
+     The compilation result containing the compiled models per backend.
+
+   Raises:
+     ValueError: If the given path is not a valid flatbuffer file.
+   """
+
+   backend = backend_class.create(config)
+
+   pipeline: list[types.Component] = [
+       c for c in [transforms, quantizer, plugin] if c is not None
+   ]
+   backends = list(backend.specialize())
+   return compile_model(flatbuffer, output_dir, backends, pipeline, keep_going)
+
+
+ def compile_model(
+     flatbuffer: types.Model,
+     output_dir: pathlib.Path,
+     backends: list[types.Backend],
+     pipeline: list[types.Component],
+     keep_going: bool = False,
+ ) -> types.CompilationResult:
+   """Compiles a TFLite model for NPU execution."""
+   if flatbuffer.in_memory:
+     base_name = "model"
+   else:
+     base_name = flatbuffer.path.name.removesuffix(common.DOT_TFLITE)
+   compile_models = types.CompilationResult()
+   with autotqdm.tqdm(backends, desc="Backend") as t_backends:
+     for backend in t_backends:
+       component_input = flatbuffer
+       backend = cast(types.Backend, backend)
+       input_name_pref = base_name + backend.target_id_suffix
+       t_backends.set_description(f"Compiling {backend.target_id}")
+       try:
+         for component in pipeline:
+           component = cast(types.Component, component)
+           t_backends.set_description(
+               f"Compiling {backend.target_id}: {component.component_name}"
+           )
+           component_output = types.Model.create_from_path(
+               output_dir
+               / f"{input_name_pref}_{component.component_name}{common.DOT_TFLITE}"
+           )
+           backend.call_component(component_input, component_output, component)
+           if not component_output.in_memory and not common.is_tflite(
+               component_output.path
+           ):
+             raise ValueError(
+                 f"{component.component_name} failed to produce a TFLite model."
+             )
+           component_input = component_output
+         compile_models.models_with_backend.append((backend, component_input))
+       except ValueError as e:
+         if keep_going:
+           print(f"Skipping failed compilation for {backend.target}. Error: {e}")
+           compile_models.failed_backends.append((backend, str(e)))
+         else:
+           raise
+
+   return compile_models
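
prepare_for_npu above resolves a backend class from the config, fans it out into one backend per concrete target via specialize(), and then pushes the model through the transforms, quantizer, and plugin pipeline. Below is a minimal usage sketch, not part of the package; it assumes the config is a plain dict carrying a "backend_id" key (as the .get() calls suggest), and make_plugin() is a hypothetical placeholder for constructing the ApplyPluginT component, whose real helper lives in ai_edge_litert/aot/core/apply_plugin.py and is not shown in this diff.

import pathlib

from ai_edge_litert.aot import prepare_for_npu as prep
from ai_edge_litert.aot.core import types


def make_plugin():
  """Placeholder: build an ApplyPluginT component here (API not shown in this diff)."""
  raise NotImplementedError


# Hypothetical sketch: compile model.tflite for whichever registered backend
# the config names, keeping going past per-target failures.
config: types.Config = {"backend_id": "example"}  # any registered backend id
backend_class = prep.resolve_backend(config)

result = prep.prepare_for_npu(
    flatbuffer=types.Model.create_from_path(pathlib.Path("model.tflite")),
    output_dir=pathlib.Path("compiled"),
    backend_class=backend_class,
    config=config,
    plugin=make_plugin(),  # placeholder, see note above
    keep_going=True,
)
for backend, model in result.models_with_backend:
  print(backend.target_id, model.path)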
ai_edge_litert/aot/vendors/__init__.py
@@ -0,0 +1,22 @@
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Vendor backends for LiteRt."""
+ import os
+
+ from ai_edge_litert.aot.vendors.mediatek import mediatek_backend as _
+ from ai_edge_litert.aot.vendors.qualcomm import qualcomm_backend as _
+
+ if os.environ.get("GOOGLE_TENSOR_COMPILER_LIB") is not None:
+   from ai_edge_litert.aot.vendors.google_tensor import google_tensor_backend as _  # pylint: disable=g-import-not-at-top
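
Importing this package registers the MediaTek and Qualcomm backends unconditionally, while the Google Tensor backend is only registered when the GOOGLE_TENSOR_COMPILER_LIB environment variable is set. A short sketch of opting in (the SDK path below is a placeholder, not a path shipped with the wheel):

import os

# Placeholder path; point this at the Google Tensor compiler SDK libraries
# before the vendors package is imported.
os.environ["GOOGLE_TENSOR_COMPILER_LIB"] = "/opt/google_tensor/lib"

# Importing the package for its side effects now also registers the
# Google Tensor backend with import_vendor.
from ai_edge_litert.aot import vendors as _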
ai_edge_litert/aot/vendors/example/__init__.py
File without changes
ai_edge_litert/aot/vendors/example/example_backend.py
@@ -0,0 +1,157 @@
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Backend implementation for the example compiler plugin."""
+
+ import functools
+ from typing import Any
+
+ from ai_edge_litert.aot.core import components
+ from ai_edge_litert.aot.core import types
+ from ai_edge_litert.aot.vendors import import_vendor
+
+
+ class ExampleTarget(types.Target):
+   """Compilation target for the example backend."""
+
+   def __init__(self, soc_manufacturer: str, soc_model: str):
+     self.soc_manufacturer = soc_manufacturer
+     self.soc_model = soc_model
+
+   def __hash__(self) -> int:
+     return hash((self.soc_manufacturer, self.soc_model))
+
+   def __eq__(self, other) -> bool:
+     return (
+         self.soc_manufacturer == other.soc_manufacturer
+         and self.soc_model == other.soc_model
+     )
+
+   def __repr__(self) -> str:
+     return f"{self.soc_manufacturer}_{self.soc_model}"
+
+   def flatten(self) -> dict[str, Any]:
+     return {
+         "soc_manufacturer": self.soc_manufacturer,
+         "soc_model": self.soc_model,
+     }
+
+   @classmethod
+   def backend_id(cls) -> str:
+     return "example"
+
+
+ # Note this is not a real target so not auto-registered unless the module is
+ # imported.
+ @import_vendor.register_backend
+ class ExampleBackend(types.Backend):
+   """Backend implementation for the example compiler plugin."""
+
+   def __init__(self, config: types.Config):
+     super().__init__(config)
+     self._compilation_config = config.get("compilation_config", None)
+
+   @classmethod
+   def target_(cls) -> ExampleTarget:
+     return ExampleTarget("ExampleSocManufacturer", "ExampleSocModel")
+
+   @property
+   def target(self) -> ExampleTarget:
+     return self.target_()
+
+   @classmethod
+   def soc_manufacturer(cls) -> str:
+     return cls.target_().soc_manufacturer
+
+   @classmethod
+   def soc_model(cls) -> str:
+     return cls.target_().soc_model
+
+   @classmethod
+   def id(cls) -> str:
+     return "example"
+
+   @property
+   def target_id(self) -> str:
+     return ""
+
+   @property
+   def shared_pass_names(self) -> list[str]:
+     return ["example-pass"]
+
+   @classmethod
+   def create(cls, config: types.Config) -> "ExampleBackend":
+     if config.get("backend_id", "") != cls.id():
+       raise ValueError("Invalid backend id")
+     return cls(config)
+
+   def call_component(
+       self,
+       input_model: types.Model,
+       output_model: types.Model,
+       component: types.Component,
+   ):
+     return _call_component(component, self, input_model, output_model)
+
+
+ @functools.singledispatch
+ def _call_component(
+     component: types.Component,
+     backend: ExampleBackend,
+     unused_input_model: types.Model,
+     unused_output_model: types.Model,
+ ):
+   raise NotImplementedError(
+       f"{backend.id()} backend does not support"
+       f" {component.component_name} component."
+   )
+
+
+ @_call_component.register
+ def _apply_plugin(
+     component: components.ApplyPluginT,
+     backend: ExampleBackend,
+     input_model: types.Model,
+     output_model: types.Model,
+ ):
+   return component(
+       input_model,
+       output_model,
+       backend.soc_manufacturer,
+       backend.soc_model,
+   )
+
+
+ @_call_component.register
+ def _aie_quantizer(
+     component: components.AieQuantizerT,
+     unused_backend: ExampleBackend,
+     input_model: types.Model,
+     output_model: types.Model,
+ ):
+   return component(
+       input_model,
+       output_model,
+   )
+
+
+ @_call_component.register
+ def _mlir_transforms(
+     component: components.MlirTransformsT,
+     backend: ExampleBackend,
+     input_model: types.Model,
+     output_model: types.Model,
+ ):
+   return component(input_model, output_model, backend.shared_pass_names)
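
The dispatch idiom used by every backend in this package is functools.singledispatch keyed on the component's type: the undecorated base overload rejects unknown components, and each @_call_component.register overload handles one component protocol. A self-contained toy illustration of the pattern (the classes below are stand-ins, not the real component types):

import functools


class Quantize:
  """Stand-in for a component protocol such as AieQuantizerT."""
  component_name = "quantize"


class Compile:
  """Stand-in for a component with no registered overload."""
  component_name = "compile"


@functools.singledispatch
def run(component, backend: str):
  # Base overload: anything without a registered overload is unsupported.
  raise NotImplementedError(f"{backend} does not support {component.component_name}")


@run.register
def _(component: Quantize, backend: str):
  return f"{backend}: running {component.component_name}"


print(run(Quantize(), "toy_backend"))  # dispatches on the first argument's type
try:
  run(Compile(), "toy_backend")
except NotImplementedError as err:
  print(err)  # falls through to the base overload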
ai_edge_litert/aot/vendors/fallback_backend.py
@@ -0,0 +1,128 @@
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """A fallback backend for LITERT."""
+
+ import functools
+ from typing import Any
+
+ from ai_edge_litert.aot.core import components
+ from ai_edge_litert.aot.core import types
+
+
+ class FallbackTarget(types.Target):
+   """A virtual compilation target."""
+
+   def __hash__(self) -> int:
+     return hash(self.backend_id())
+
+   def __eq__(self, other: types.Target) -> bool:
+     return self.backend_id() == other.backend_id()
+
+   def __repr__(self) -> str:
+     return f"{self.backend_id()}"
+
+   @classmethod
+   def backend_id(cls) -> str:
+     return "fallback"
+
+   def flatten(self) -> dict[str, Any]:
+     return {"backend_id": self.backend_id()}
+
+
+ class FallbackBackend(types.Backend):
+   """Fallback backend for LITERT."""
+
+   @property
+   def target(self) -> FallbackTarget:
+     return FallbackTarget()
+
+   @property
+   def target_id(self) -> str:
+     return repr(self.target)
+
+   @classmethod
+   def id(cls) -> str:
+     return "fallback"
+
+   @classmethod
+   def create(cls, config: types.Config) -> "FallbackBackend":
+     if config.get("backend_id", "") != cls.id():
+       raise ValueError("Invalid backend id")
+     return cls(config)
+
+   @property
+   def quantize_recipe(self) -> str | None:
+     return self.config.get("quantize_recipe", None)
+
+   def call_component(
+       self,
+       input_model: types.Model,
+       output_model: types.Model,
+       component: types.Component,
+   ):
+     return _call_component(component, self, input_model, output_model)
+
+
+ @functools.singledispatch
+ def _call_component(
+     component: types.Component,
+     backend: FallbackBackend,
+     unused_input_model: types.Model,
+     unused_output_model: types.Model,
+ ):
+   raise NotImplementedError(
+       f"{backend.id()} backend does not support"
+       f" {component.component_name} component."
+   )
+
+
+ @_call_component.register
+ def _apply_plugin(
+     component: components.ApplyPluginT,
+     backend: FallbackBackend,
+     input_model: types.Model,
+     output_model: types.Model,
+ ):
+   """A no-op component that just copies the input model to the output model."""
+   del component, backend
+   if input_model.in_memory:
+     output_model.set_bytes(input_model.model_bytes)
+   else:
+     output_model.set_path(input_model.path)
+
+
+ @_call_component.register
+ def _aie_quantizer(
+     component: components.AieQuantizerT,
+     backend: FallbackBackend,
+     input_model: types.Model,
+     output_model: types.Model,
+ ):
+   return component(
+       input_model,
+       output_model,
+       quantization_recipe=backend.quantize_recipe,
+   )
+
+
+ @_call_component.register
+ def _mlir_transforms(
+     component: components.MlirTransformsT,
+     unused_backend: FallbackBackend,
+     input_model: types.Model,
+     output_model: types.Model,
+ ):
+   return component(input_model, output_model, [])
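
In effect the fallback backend hands the model back unchanged: its apply-plugin overload copies input to output and its MLIR transform runs with an empty pass list, so it is the path taken when no NPU compilation applies. A brief hedged sketch of constructing it (assuming the config is a plain dict, as the .get() calls suggest):

from ai_edge_litert.aot.vendors import fallback_backend

# The create() classmethod requires the config to name this backend explicitly.
backend = fallback_backend.FallbackBackend.create({"backend_id": "fallback"})
print(backend.target_id)        # "fallback"
print(backend.quantize_recipe)  # None unless the config sets "quantize_recipe"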
ai_edge_litert/aot/vendors/google_tensor/__init__.py
File without changes
ai_edge_litert/aot/vendors/google_tensor/google_tensor_backend.py
@@ -0,0 +1,168 @@
+ # Copyright 2025 Google LLC.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Backend implementation for the Google Tensor compiler plugin."""
+
+ import copy
+ import functools
+ import os
+ import pathlib
+ from typing import Iterable
+
+ from ai_edge_litert.aot.core import common
+ from ai_edge_litert.aot.core import components
+ from ai_edge_litert.aot.core import types
+ from ai_edge_litert.aot.vendors import import_vendor
+ from ai_edge_litert.aot.vendors.google_tensor import target as target_lib
+
+ COMPILER_PLUGIN_LIB_PATH = pathlib.Path(
+     "vendors/google_tensor/compiler/libLiteRtCompilerPlugin_google_tensor.so"
+ )
+
+
+ # Returns true if the flag is a google_tensor flag.
+ def _is_google_tensor_flag(flag: str) -> bool:
+   return flag.startswith("google_tensor_")
+
+
+ @import_vendor.register_backend
+ class GoogleTensorBackend(types.Backend):
+   """Backend implementation for the Google Tensor compiler plugin."""
+
+   def __init__(self, config: types.Config):
+     super().__init__(config)
+     self._compilation_config = config.get("compilation_config", None)
+
+   @property
+   def soc_manufacturer(self) -> target_lib.SocManufacturer:
+     return target_lib.SocManufacturer.GOOGLE
+
+   @property
+   def soc_model(self) -> target_lib.SocModel:
+     return target_lib.SocModel(self.config.get("soc_model", "ALL"))
+
+   @property
+   def target(self) -> target_lib.Target:
+     return target_lib.Target(self.soc_model, self.soc_manufacturer)
+
+   @property
+   def target_id(self) -> str:
+     return repr(self.target)
+
+   def specialize(self) -> Iterable["GoogleTensorBackend"]:
+     if self.soc_model != target_lib.SocModel.ALL:
+       yield self
+     else:
+       for soc_model in target_lib.SocModel:
+         if soc_model != target_lib.SocModel.ALL:
+           new_config = copy.deepcopy(self.config)
+           new_config["soc_model"] = soc_model.value
+           yield self.create(new_config)
+
+   @classmethod
+   def id(cls) -> str:
+     return target_lib._GOOGLE_TENSOR_BACKEND_ID  # pylint: disable=protected-access
+
+   @classmethod
+   def create(cls, config: types.Config) -> "GoogleTensorBackend":
+     if config.get("backend_id", "") != cls.id():
+       raise ValueError("Invalid backend id")
+     return cls(config)
+
+   @property
+   def quantize_recipe(self) -> str | None:
+     return self.config.get("quantize_recipe", None)
+
+   def call_component(
+       self,
+       input_model: types.Model,
+       output_model: types.Model,
+       component: types.Component,
+   ):
+     return _call_component(component, self, input_model, output_model)
+
+
+ @functools.singledispatch
+ def _call_component(
+     component: types.Component,
+     backend: GoogleTensorBackend,
+     unused_input_model: types.Model,
+     unused_output_model: types.Model,
+ ):
+   raise NotImplementedError(
+       f"{backend.id()} backend does not support"
+       f" {component.component_name} component."
+   )
+
+
+ @_call_component.register
+ def _apply_plugin(
+     component: components.ApplyPluginT,
+     backend: GoogleTensorBackend,
+     input_model: types.Model,
+     output_model: types.Model,
+ ):
+   """Calls the apply plugin component."""
+   try:
+     # If the plugin is not built from source (i.e. using ai_edge_litert wheel),
+     # we find the plugin library directory from the package path.
+     # Otherwise we use the default library path.
+     plugin_path = common.get_resource(COMPILER_PLUGIN_LIB_PATH)
+     lib_dir = os.path.dirname(plugin_path)
+     sdk_libs_dir = os.environ.get("GOOGLE_TENSOR_COMPILER_LIB", None)
+
+     extra_kwargs = {"libs": lib_dir, "sdk_libs_path": sdk_libs_dir}
+   except FileNotFoundError:
+     extra_kwargs = {}
+
+   # Add google_tensor specific flags from the backend config.
+   for flag, value in backend.config.items():
+     if _is_google_tensor_flag(flag):
+       extra_kwargs[flag] = value
+
+   for flag, value in backend.config.get("compilation_config", {}).items():
+     if _is_google_tensor_flag(flag):
+       extra_kwargs[flag] = value
+
+   return component(
+       input_model,
+       output_model,
+       backend.soc_manufacturer,
+       backend.soc_model,
+       **extra_kwargs,
+   )
+
+
+ @_call_component.register
+ def _aie_quantizer(
+     component: components.AieQuantizerT,
+     backend: GoogleTensorBackend,
+     input_model: types.Model,
+     output_model: types.Model,
+ ):
+   return component(
+       input_model,
+       output_model,
+       quantization_recipe=backend.quantize_recipe,
+   )
+
+
+ @_call_component.register
+ def _mlir_transforms(
+     component: components.MlirTransformsT,
+     unused_backend: GoogleTensorBackend,
+     input_model: types.Model,
+     output_model: types.Model,
+ ):
+   return component(input_model, output_model, [])
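
With soc_model left at its "ALL" default, specialize() deep-copies the config once per concrete SocModel, so a single compilation request yields one artifact per supported Tensor SoC; setting a specific soc_model keeps a single target. A hedged sketch (assuming the config is a plain dict; the backend id string and the concrete SocModel names are defined in the target module, which this diff does not show, so they are obtained programmatically rather than hard-coded):

from ai_edge_litert.aot.vendors.google_tensor import google_tensor_backend as gt

config = {
    "backend_id": gt.GoogleTensorBackend.id(),  # resolved from target.py
    "soc_model": "ALL",
}
backend = gt.GoogleTensorBackend.create(config)

# One specialized backend per concrete SocModel value; ALL itself is skipped.
for specialized in backend.specialize():
  print(specialized.target_id, specialized.soc_model)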