litert-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/litert_cli.ipynb +313 -0
- examples/models/presets/default.py +19 -0
- examples/run_cli_demo.sh +38 -0
- examples/run_cli_npu.sh +89 -0
- examples/run_commands.sh +67 -0
- examples/run_models.sh +63 -0
- examples/run_smoke_tests.sh +58 -0
- examples/utils.ps1 +163 -0
- examples/utils.sh +184 -0
- litert_cli/__init__.py +15 -0
- litert_cli/commands/benchmark/__init__.py +16 -0
- litert_cli/commands/benchmark/android.py +212 -0
- litert_cli/commands/benchmark/cli.py +294 -0
- litert_cli/commands/benchmark/desktop.py +228 -0
- litert_cli/commands/benchmark/gcp.py +336 -0
- litert_cli/commands/clean.py +73 -0
- litert_cli/commands/compile.py +211 -0
- litert_cli/commands/convert/__init__.py +20 -0
- litert_cli/commands/convert/cli.py +255 -0
- litert_cli/commands/convert/generic.py +211 -0
- litert_cli/commands/convert/huggingface.py +175 -0
- litert_cli/commands/delete.py +56 -0
- litert_cli/commands/download.py +274 -0
- litert_cli/commands/import.py +124 -0
- litert_cli/commands/list.py +132 -0
- litert_cli/commands/lm.py +74 -0
- litert_cli/commands/quantize.py +193 -0
- litert_cli/commands/run/__init__.py +16 -0
- litert_cli/commands/run/android.py +394 -0
- litert_cli/commands/run/cli.py +297 -0
- litert_cli/commands/run/desktop.py +340 -0
- litert_cli/commands/visualize.py +234 -0
- litert_cli/core/android_utils.py +304 -0
- litert_cli/core/android_utils_test.py +236 -0
- litert_cli/core/constants.py +131 -0
- litert_cli/core/deps.py +180 -0
- litert_cli/core/deps_test.py +101 -0
- litert_cli/core/inputs.py +203 -0
- litert_cli/core/inputs_test.py +176 -0
- litert_cli/core/log_filters.py +50 -0
- litert_cli/core/models.py +96 -0
- litert_cli/core/npu_utils.py +382 -0
- litert_cli/core/targets_manager.py +192 -0
- litert_cli/core/utils.py +58 -0
- litert_cli/litert.py +119 -0
- litert_cli/litert_help_test.py +51 -0
- litert_cli/litert_test.py +88 -0
- litert_cli/models/__init__.py +145 -0
- litert_cli/models/asr/__init__.py +15 -0
- litert_cli/models/asr/asr_model.py +108 -0
- litert_cli/models/asr/parakeet_ctc.py +165 -0
- litert_cli/models/asr/runner.py +482 -0
- litert_cli/models/base.py +57 -0
- litert_cli/test_data/dummy_calib_data.py +26 -0
- litert_cli/test_data/dummy_cv_model.py +52 -0
- litert_cli/test_data/dummy_cv_model.tflite +0 -0
- litert_cli/test_data/generate_test_inputs.py +51 -0
- litert_cli/test_data/mobilenet_v3_calib_data.py +25 -0
- litert_cli/test_data/quantize_recipe.json +16 -0
- litert_cli/test_data/resnet18.py +31 -0
- litert_cli-0.1.0.dist-info/METADATA +38 -0
- litert_cli-0.1.0.dist-info/RECORD +67 -0
- litert_cli-0.1.0.dist-info/WHEEL +5 -0
- litert_cli-0.1.0.dist-info/entry_points.txt +2 -0
- litert_cli-0.1.0.dist-info/licenses/LICENSE +202 -0
- litert_cli-0.1.0.dist-info/top_level.txt +3 -0
- tools/build_wheels.py +122 -0
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
# Copyright 2026 The LiteRT CLI Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Benchmarking on Google AI Edge Portal in GCP."""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import json
|
|
21
|
+
import os
|
|
22
|
+
import pathlib
|
|
23
|
+
import subprocess
|
|
24
|
+
import time
|
|
25
|
+
import urllib.error
|
|
26
|
+
import urllib.request
|
|
27
|
+
import uuid
|
|
28
|
+
|
|
29
|
+
import click
|
|
30
|
+
|
|
31
|
+
# Fallback GCP project ID, read once from the environment when this module is
# imported; None when LITERT_GCP_PROJECT is unset.
_DEFAULT_GCP_PROJECT = os.environ.get("LITERT_GCP_PROJECT")
# Location used for the Portal API resource path and for buckets created here.
_DEFAULT_GCP_LOCATION = "us-central1"
# Fallback GCS bucket name, read once from the environment at import time;
# None when LITERT_GCP_BUCKET is unset.
_GCP_BUCKET = os.environ.get("LITERT_GCP_BUCKET")
# Base URL of the AI Edge Portal API; run_gcp lets the AI_EDGE_PORTAL_ENDPOINT
# environment variable override it.
_DEFAULT_PORTAL_ENDPOINT = "https://aiedgeportal.googleapis.com/v1alpha"
# NOTE: Keep in sync with Google AI Edge Portal runtime versions.
_DEFAULT_PORTAL_LITERT_RUNTIME_VERSION = "litert-v2.0.3"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _get_submission_url(
    portal_endpoint: str, gcp_project: str, gcp_location: str, job_id: str
) -> str:
  """Builds the URL that a new benchmark job is submitted to.

  Args:
    portal_endpoint: Base URL of the portal API, without a trailing slash.
    gcp_project: GCP project ID placed in the resource path.
    gcp_location: GCP location placed in the resource path.
    job_id: Identifier passed as the `benchmarkId` query parameter.

  Returns:
    The benchmarks collection URL for the project/location, carrying the job
    ID as a query parameter. Values are inserted verbatim (no URL escaping).
  """
  template = "{}/projects/{}/locations/{}/benchmarks?benchmarkId={}"
  return template.format(portal_endpoint, gcp_project, gcp_location, job_id)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _get_operation_url(portal_endpoint: str, op_name: str) -> str:
  """Builds the URL used to poll an operation's status.

  Args:
    portal_endpoint: Base URL of the portal API, without a trailing slash.
    op_name: Operation resource name, appended verbatim after a single '/'.

  Returns:
    The endpoint and the operation name joined by '/'.
  """
  return "{}/{}".format(portal_endpoint, op_name)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _get_console_url(
    gcp_project: str, gcp_location: str, benchmark_id: str
) -> str:
  """Builds the Google Cloud Console link for a benchmark's details page.

  Args:
    gcp_project: GCP project ID, passed as the `project` query parameter.
    gcp_location: GCP location segment of the details path.
    benchmark_id: Benchmark identifier segment of the details path.

  Returns:
    The console URL; values are inserted verbatim (no URL escaping).
  """
  details_page = (
      "https://console.cloud.google.com/ai-edge-portal/benchmarks/details"
  )
  return f"{details_page}/{gcp_location}/{benchmark_id}?project={gcp_project}"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def run_gcp(
    model_path_str: str,
    accelerator: str,
    devices: list[str],
    gcp_project: str | None = None,
    gcp_bucket: str | None = None,
    compilation_mode: str | None = None,
    soc_model: str | None = None,
) -> None:
  """Runs the model on GCP via AI Edge Portal Cloud API.

  Uploads the model to GCS when a local path is given (creating the bucket if
  listing it fails), submits a benchmark job to the AI Edge Portal Cloud API,
  then polls the returned operation every 15 seconds until it reports done or
  the user interrupts with Ctrl-C.

  A missing local model file, a failing `gcloud` helper command and a rejected
  submission are reported with `click.secho` and the function then returns
  normally; only argument validation raises.

  Args:
    model_path_str: Path to the LiteRT model file (local or gs://).
    accelerator: Hardware accelerator to use (cpu, gpu, npu). Compared
      case-insensitively.
    devices: Target device model(s) (e.g., 'pixel 7', 'sm-s931u1'). A single
      string is also accepted, and each entry may be a comma-separated list.
    gcp_project: GCP project ID for benchmarking. Falls back to the
      module-level default project when empty.
    gcp_bucket: GCS bucket name for uploading model. Falls back to the
      module-level default bucket, then to '<gcp_project>-litert-models'.
    compilation_mode: Compilation mode for NPU (jit, aot). Defaults to jit;
      aot is currently rejected.
    soc_model: Target SoC model. Required whenever the accelerator is npu.

  Raises:
    click.ClickException: If no device or GCP project is given, if NPU AOT
      mode is requested, or if soc_model is missing for an NPU run.
  """

  # Flatten the device argument: tolerate a bare string, split comma-separated
  # entries, and drop empty fragments.
  items = [devices] if isinstance(devices, str) else devices
  device_list = [
      part.strip()
      for item in items
      if item
      for part in item.split(",")
      if part.strip()
  ]

  if not device_list:
    raise click.ClickException(
        "Error: --device is required for running GCP benchmark tests."
    )

  if not gcp_project:
    gcp_project = _DEFAULT_GCP_PROJECT

  if not gcp_project:
    raise click.ClickException(
        "Missing GCP project. You must specify a GCP project by passing"
        " '--gcp-project <PROJECT_ID>' or by setting the 'LITERT_GCP_PROJECT'"
        " environment variable."
    )
  model_path = model_path_str
  # Upload model to GCS if it's not already there.
  if not model_path.startswith("gs://"):
    local_model = pathlib.Path(model_path)
    if not local_model.exists():
      click.secho(f"Error: Local model file not found: {model_path}", fg="red")
      return

    target_bucket = gcp_bucket or _GCP_BUCKET
    if not target_bucket:
      target_bucket = f"{gcp_project}-litert-models"
      click.secho(
          "Note: GCS bucket not specified via '--gcp-bucket' or"
          " 'LITERT_GCP_BUCKET' environment variable. Using default"
          f" project-bound bucket 'gs://{target_bucket}'.",
          fg="yellow",
      )
    else:
      click.echo(f"Using specified GCS bucket 'gs://{target_bucket}'.")

    # Check if bucket exists, create if not
    click.echo(
        f"Ensuring GCS bucket 'gs://{target_bucket}' exists for project"
        f" '{gcp_project}'..."
    )
    try:
      # A non-zero exit from `ls` is treated as "bucket does not exist".
      check_res = subprocess.run(
          ["gcloud", "storage", "ls", f"gs://{target_bucket}"],
          check=False,
          stdout=subprocess.DEVNULL,
          stderr=subprocess.DEVNULL,
      )
      if check_res.returncode != 0:
        click.secho(
            f"Creating GCS bucket 'gs://{target_bucket}' in location"
            f" '{_DEFAULT_GCP_LOCATION}'...",
            fg="cyan",
        )
        subprocess.run(
            [
                "gcloud",
                "storage",
                "buckets",
                "create",
                f"gs://{target_bucket}",
                f"--project={gcp_project}",
                f"--location={_DEFAULT_GCP_LOCATION}",
            ],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    # OSError covers a missing or non-executable `gcloud` binary, which would
    # otherwise surface as a raw traceback.
    except (subprocess.CalledProcessError, OSError) as e:
      click.secho(
          f"Error: Failed to ensure GCS bucket 'gs://{target_bucket}': {e}",
          fg="red",
      )
      return

    click.secho(
        f"Uploading local model '{model_path}' to gs://{target_bucket}/...",
        fg="cyan",
    )
    try:
      subprocess.run(
          [
              "gcloud",
              "storage",
              "cp",
              str(local_model),
              f"gs://{target_bucket}/",
          ],
          check=True,
      )
      model_path = f"gs://{target_bucket}/{local_model.name}"
    except (subprocess.CalledProcessError, OSError) as e:
      # model_path still holds the local path here, because it is only
      # reassigned after a successful copy.
      click.secho(
          f"Error: Failed to upload '{model_path}' to Google Cloud Storage:"
          f" {e}",
          fg="red",
      )
      return

  job_id = f"litert-cli-benchmark-{uuid.uuid4().hex[:8]}"

  click.echo("Fetching GCP access token...")
  try:
    token = subprocess.check_output(
        ["gcloud", "auth", "print-access-token"], text=True
    ).strip()
  except (subprocess.CalledProcessError, OSError) as e:
    click.secho(
        "Error: Failed to get gcloud access token. Please run 'gcloud auth"
        f" login' first. Details: {e}",
        fg="red",
    )
    return

  portal_endpoint = os.environ.get(
      "AI_EDGE_PORTAL_ENDPOINT", _DEFAULT_PORTAL_ENDPOINT
  ).rstrip("/")
  url = _get_submission_url(
      portal_endpoint, gcp_project, _DEFAULT_GCP_LOCATION, job_id
  )
  headers = {
      "Authorization": f"Bearer {token}",
      "Content-Type": "application/json",
      "X-Goog-User-Project": gcp_project,
  }

  accel_name = accelerator.upper()

  # The model path is sent without its 'gs://' scheme for every accelerator.
  run_spec: dict[str, object] = {
      "accelerator": accel_name,
      "id": accelerator.lower(),
      "displayName": f"{accelerator.lower()}_test",
      "runtimeVersion": _DEFAULT_PORTAL_LITERT_RUNTIME_VERSION,
      "modelPath": model_path.replace("gs://", ""),
  }

  if accel_name == "NPU":
    comp_mode = (compilation_mode or "jit").upper()
    if comp_mode == "AOT":
      raise click.ClickException(
          "Error: NPU AOT compilation mode is temporarily disabled for GCP"
          " benchmarking. Please use JIT compilation mode (--jit) instead."
      )

    if not soc_model:
      raise click.ClickException(
          "Error: --soc-model is required when using NPU JIT compilation mode."
      )

    run_spec["npuConfig"] = {
        "npuCompilationMode": "JIT",
        "socConfigs": [{
            "socModel": soc_model,
            "aotModelPath": "",
        }],
        "cpuFallbackConfig": {"threadCount": 4},
    }

  body = {
      "displayName": job_id,
      "modelPaths": [],
      "deviceConfigs": [{"deviceModel": d} for d in device_list],
      "runSpecs": [run_spec],
  }

  # Submit the benchmark job via http requests to AI Edge Portal Cloud API.
  req = urllib.request.Request(
      url, data=json.dumps(body).encode("utf-8"), headers=headers, method="POST"
  )
  click.echo(
      f"Submitting '{accelerator}' benchmark job '{job_id}' to AI Edge Portal"
      f" (Project: {gcp_project}, Location:"
      f" {_DEFAULT_GCP_LOCATION})..."
  )

  try:
    with urllib.request.urlopen(req) as response:
      resp_data = json.loads(response.read().decode())
    click.secho("Benchmark job submitted successfully!", fg="green")

    op_name = resp_data.get("name", "")
    if op_name and "operations" in op_name:
      op_url = _get_operation_url(portal_endpoint, op_name)
      console_url = _get_console_url(
          gcp_project, _DEFAULT_GCP_LOCATION, job_id
      )
      click.echo(
          f"Waiting for benchmark to complete (Operation: {op_name}). This"
          " may take a few minutes..."
      )
      click.secho(
          f"View progress on the Cloud Console: {console_url}", fg="cyan"
      )

      try:
        while True:
          time.sleep(15)
          click.echo(".", nl=False)
          req_op = urllib.request.Request(op_url, headers=headers)
          try:
            with urllib.request.urlopen(req_op) as res_op:
              op_data = json.loads(res_op.read().decode())
          except urllib.error.HTTPError as e:
            # HTTPError wraps the response; `with` closes it before retrying.
            with e:
              click.secho(f"\nError polling operation: {e}", fg="yellow")
            continue

          if op_data.get("done"):
            click.echo("")  # Print a newline after the dots
            if "error" in op_data:
              click.secho(
                  "Benchmark failed:"
                  f" {json.dumps(op_data['error'], indent=2)}",
                  fg="red",
              )
            else:
              click.secho(
                  "Benchmark operation completed successfully!", fg="green"
              )
              click.echo(json.dumps(op_data.get("response", {}), indent=2))
            break
      except KeyboardInterrupt:
        click.echo("")
        click.secho(
            "\nPolling interrupted. The benchmark job is still running.",
            fg="yellow",
        )
        # The f-prefix must sit on the literal containing the placeholder,
        # otherwise the text '{console_url}' is printed verbatim.
        click.echo(
            "You can check its status later by viewing it in the console:"
            f" {console_url}"
        )
    else:
      # No operation to poll: show the raw submission response instead.
      click.echo(json.dumps(resp_data, indent=2))
  except urllib.error.HTTPError as e:
    err_body = e.read().decode()
    click.secho(f"Failed to submit benchmark: {e.code} {e.reason}", fg="red")
    click.secho(f"Details: {err_body}", fg="red")
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# Copyright 2026 The LiteRT CLI Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Clean command for LiteRT CLI."""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import pathlib
|
|
21
|
+
import shlex
|
|
22
|
+
import shutil
|
|
23
|
+
import subprocess
|
|
24
|
+
|
|
25
|
+
import click
|
|
26
|
+
from litert_cli.core import android_utils
|
|
27
|
+
from litert_cli.core import constants
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@click.command(name="clean")
def clean_cmd() -> None:
  """Cleans up local caches, downloaded files, and remote Android directories."""
  # NOTE: no `help=` is passed to click.command, so the docstring above is the
  # text click shows for this command; keep it user-facing.

  click.echo("Cleaning LiteRT CLI workspace...")

  # Step 1: the downloaded "qairt" directory under the CLI root.
  qairt_dir = pathlib.Path(constants.LITERT_CLI_ROOT) / "qairt"
  try:
    shutil.rmtree(qairt_dir)
    click.echo(f"Removing local qairt workspace: {qairt_dir}")
  except OSError as err:
    # A directory that is already absent is not worth a warning.
    if not isinstance(err, FileNotFoundError):
      click.secho(f"Warning: Failed to remove {qairt_dir}: {err}", fg="yellow")

  # Step 2: the local cache directory (e.g., ~/.cache/litert-cli).
  cache_dir = pathlib.Path(constants.LITERT_CLI_CACHE_DIR)
  try:
    shutil.rmtree(cache_dir)
    click.echo(f"Removing local cache: {cache_dir}")
  except OSError as err:
    if not isinstance(err, FileNotFoundError):
      click.secho(f"Warning: Failed to remove {cache_dir}: {err}", fg="yellow")

  # Step 3: the workspace on the attached Android device. A ClickException
  # from the adb check is taken to mean there is no device to clean.
  try:
    android_utils.check_adb()
    android_root = constants.LITERT_CLI_ANDROID_ROOT
    click.echo(f"Removing remote Android workspace via adb: {android_root}")

    # Best effort: the exit status and all output of the remote rm are ignored.
    adb_command = ["adb", "shell", f"rm -rf {shlex.quote(android_root)}"]
    subprocess.run(
        adb_command,
        check=False,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
  except click.ClickException:
    click.echo(
        "No active Android device found via adb. Skipping remote cleanup."
    )

  click.secho("Cleanup complete!", fg="green")
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
# Copyright 2026 The LiteRT CLI Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""CLI module for the `litert compile` command.
|
|
17
|
+
|
|
18
|
+
This command applies NPU AOT compilation to a standard LiteRT (.tflite) model.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
from collections.abc import Sequence
|
|
24
|
+
import importlib
|
|
25
|
+
import pathlib
|
|
26
|
+
import shutil
|
|
27
|
+
import textwrap
|
|
28
|
+
|
|
29
|
+
import click
|
|
30
|
+
from litert_cli.core import constants
|
|
31
|
+
from litert_cli.core import deps
|
|
32
|
+
from litert_cli.core import npu_utils
|
|
33
|
+
from litert_cli.core import utils
|
|
34
|
+
from litert_cli.core.targets_manager import TargetsManager
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@click.command(
|
|
38
|
+
"compile",
|
|
39
|
+
help=textwrap.dedent("""\
|
|
40
|
+
Apply AOT (Ahead-of-Time) compilation for NPUs to a TFLite model.
|
|
41
|
+
|
|
42
|
+
model_path: Path to a valid .tflite model.
|
|
43
|
+
|
|
44
|
+
Examples:
|
|
45
|
+
|
|
46
|
+
Basic Compilation for specific NPU:
|
|
47
|
+
|
|
48
|
+
$ litert compile my_model.tflite --target sm8450
|
|
49
|
+
|
|
50
|
+
Compile for multiple targets and export AI Pack for Android:
|
|
51
|
+
|
|
52
|
+
$ litert compile my_model.tflite --target sm8550 --target mt6989 \
|
|
53
|
+
--export-aipack my_npu_models
|
|
54
|
+
"""),
|
|
55
|
+
)
|
|
56
|
+
@click.argument("model_path", type=str)
|
|
57
|
+
@click.option(
|
|
58
|
+
"--update-targets",
|
|
59
|
+
type=str,
|
|
60
|
+
required=False,
|
|
61
|
+
default=None,
|
|
62
|
+
help=(
|
|
63
|
+
"Update SoC target lists from GitHub. Pass 'main' for latest, or a"
|
|
64
|
+
" version tag like 'v2.1.4'."
|
|
65
|
+
),
|
|
66
|
+
)
|
|
67
|
+
@click.option(
|
|
68
|
+
"--target",
|
|
69
|
+
type=str,
|
|
70
|
+
multiple=True,
|
|
71
|
+
required=True,
|
|
72
|
+
help="One or more NPU target codenames (e.g., sm8450).",
|
|
73
|
+
)
|
|
74
|
+
@click.option(
|
|
75
|
+
"--export-aipack",
|
|
76
|
+
type=click.Path(
|
|
77
|
+
file_okay=False,
|
|
78
|
+
dir_okay=True,
|
|
79
|
+
resolve_path=True,
|
|
80
|
+
path_type=pathlib.Path,
|
|
81
|
+
),
|
|
82
|
+
required=False,
|
|
83
|
+
default=None,
|
|
84
|
+
help=(
|
|
85
|
+
"If specified, exports an AI Pack directory for PODAI instead of"
|
|
86
|
+
" standard .tflite."
|
|
87
|
+
),
|
|
88
|
+
)
|
|
89
|
+
@click.option(
|
|
90
|
+
"--output-dir",
|
|
91
|
+
type=click.Path(
|
|
92
|
+
file_okay=False,
|
|
93
|
+
dir_okay=True,
|
|
94
|
+
resolve_path=True,
|
|
95
|
+
path_type=pathlib.Path,
|
|
96
|
+
),
|
|
97
|
+
required=False,
|
|
98
|
+
default=None,
|
|
99
|
+
help=(
|
|
100
|
+
"Directory to save the compiled TFLite model. Defaults to current"
|
|
101
|
+
" directory."
|
|
102
|
+
),
|
|
103
|
+
)
|
|
104
|
+
@deps.require_extra("compile")
|
|
105
|
+
def compile_cmd(
    model_path: str,
    target: Sequence[str],
    update_targets: str | None,
    export_aipack: pathlib.Path | None,
    output_dir: pathlib.Path | None,
) -> None:
  """Runs NPU AOT compilation on a tflite model and writes the result.

  Args:
    model_path: Path to the input tflite model or Model Reference.
    target: List of target SoCs or acceleration backends.
    update_targets: When set, the version of the SoC target lists to download
      before compiling (passed to the targets manager as `version`).
    export_aipack: Directory to export the compiled model to as an AI Pack.
      Anything already at this path is deleted first.
    output_dir: Directory to save the compiled TFLite model; the current
      working directory is used when omitted. Ignored for AI Pack export.

  Raises:
    click.ClickException: If compilation or export fails.
  """
  # Imported here rather than at module level, as in the original.
  from ai_edge_litert.aot import aot_compile as aot_lib
  from ai_edge_litert.aot.ai_pack import export_lib as ai_pack_export
  from litert_cli.core import models as core_models

  # Quiet if default is true
  if constants.DEFAULT_QUIET:
    utils.enable_quiet_mode()

  targets_manager = TargetsManager()

  # Refresh on request; otherwise download only when nothing is cached yet.
  # `constants` is reloaded after every successful download.
  if update_targets:
    targets_manager.download_targets(version=update_targets)
    importlib.reload(constants)
  elif not targets_manager.load_targets():
    click.echo("No target cache found. Downloading default target lists...")
    try:
      targets_manager.download_targets(version="main")
      importlib.reload(constants)
    except Exception as e:
      # First-run download is best effort; compilation continues regardless.
      click.echo(f"Warning: Failed to download default targets: {e}")
      click.echo("Falling back to built-in static target lists.")

  resolved_ref, _ = core_models.resolve_model_reference(model_path)
  if str(resolved_ref) != str(model_path):
    click.echo(f"Resolved model '{model_path}' to '{resolved_ref}'")

  model_file = pathlib.Path(resolved_ref)

  click.echo(f"Compiling model {model_file} for targets: {', '.join(target)}")

  aot_targets = [npu_utils.get_aot_target(name) for name in target]

  try:
    compiled_models = aot_lib.aot_compile(
        str(model_file),
        target=aot_targets,
        keep_going=False,
    )
    destination_dir = output_dir or pathlib.Path.cwd()

    base_name = model_file.stem

    if not export_aipack:
      click.echo(f"Exporting compiled TFLite to: {destination_dir}")
      try:
        compiled_models.export(str(destination_dir), model_name=base_name)
      except Exception as e:
        raise click.ClickException(f"Failed to export models: {e!r}") from e
    else:
      click.echo(f"Exporting AI Pack to: {export_aipack}")
      # Start from an empty directory: remove whatever is at the path now.
      if export_aipack.is_dir():
        shutil.rmtree(export_aipack)
      elif export_aipack.exists():
        export_aipack.unlink()
      export_aipack.mkdir(parents=True, exist_ok=True)
      try:
        ai_pack_export.export(
            compiled_models, str(export_aipack), base_name, "model"
        )
      except Exception as e:
        raise click.ClickException(f"Failed to export AI Pack: {e!r}") from e
  except click.ClickException:
    # Export failures are already user-facing; pass them through unchanged.
    raise
  except Exception as e:
    # Targets whose name contains none of these markers are reported as
    # unsupported instead of surfacing the raw failure.
    qualcomm_markers = ("sm", "qnn", "qualcomm")
    unsupported = []
    for name in target:
      lowered = name.lower()
      if not any(marker in lowered for marker in qualcomm_markers):
        unsupported.append(name)

    if unsupported:
      raise click.ClickException(
          f"AOT Compilation failed for target(s) {', '.join(unsupported)}:"
          " Currently, only the Qualcomm platform is fully supported for"
          " offline AOT compilation."
      ) from e

    raise click.ClickException(
        f"AOT Compilation of '{model_file}' for targets"
        f" {', '.join(target)} failed: {e!r}"
    ) from e

  click.secho("AOT Compilation Completed Successfully!", fg="green")
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright 2026 The LiteRT CLI Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""CLI command for converting models to LiteRT format."""
|
|
17
|
+
|
|
18
|
+
from litert_cli.commands.convert.cli import convert_cmd
|
|
19
|
+
|
|
20
|
+
# Public API of this package: only the command imported above from the cli
# module is re-exported.
__all__ = ["convert_cmd"]
|