litert-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/litert_cli.ipynb +313 -0
- examples/models/presets/default.py +19 -0
- examples/run_cli_demo.sh +38 -0
- examples/run_cli_npu.sh +89 -0
- examples/run_commands.sh +67 -0
- examples/run_models.sh +63 -0
- examples/run_smoke_tests.sh +58 -0
- examples/utils.ps1 +163 -0
- examples/utils.sh +184 -0
- litert_cli/__init__.py +15 -0
- litert_cli/commands/benchmark/__init__.py +16 -0
- litert_cli/commands/benchmark/android.py +212 -0
- litert_cli/commands/benchmark/cli.py +294 -0
- litert_cli/commands/benchmark/desktop.py +228 -0
- litert_cli/commands/benchmark/gcp.py +336 -0
- litert_cli/commands/clean.py +73 -0
- litert_cli/commands/compile.py +211 -0
- litert_cli/commands/convert/__init__.py +20 -0
- litert_cli/commands/convert/cli.py +255 -0
- litert_cli/commands/convert/generic.py +211 -0
- litert_cli/commands/convert/huggingface.py +175 -0
- litert_cli/commands/delete.py +56 -0
- litert_cli/commands/download.py +274 -0
- litert_cli/commands/import.py +124 -0
- litert_cli/commands/list.py +132 -0
- litert_cli/commands/lm.py +74 -0
- litert_cli/commands/quantize.py +193 -0
- litert_cli/commands/run/__init__.py +16 -0
- litert_cli/commands/run/android.py +394 -0
- litert_cli/commands/run/cli.py +297 -0
- litert_cli/commands/run/desktop.py +340 -0
- litert_cli/commands/visualize.py +234 -0
- litert_cli/core/android_utils.py +304 -0
- litert_cli/core/android_utils_test.py +236 -0
- litert_cli/core/constants.py +131 -0
- litert_cli/core/deps.py +180 -0
- litert_cli/core/deps_test.py +101 -0
- litert_cli/core/inputs.py +203 -0
- litert_cli/core/inputs_test.py +176 -0
- litert_cli/core/log_filters.py +50 -0
- litert_cli/core/models.py +96 -0
- litert_cli/core/npu_utils.py +382 -0
- litert_cli/core/targets_manager.py +192 -0
- litert_cli/core/utils.py +58 -0
- litert_cli/litert.py +119 -0
- litert_cli/litert_help_test.py +51 -0
- litert_cli/litert_test.py +88 -0
- litert_cli/models/__init__.py +145 -0
- litert_cli/models/asr/__init__.py +15 -0
- litert_cli/models/asr/asr_model.py +108 -0
- litert_cli/models/asr/parakeet_ctc.py +165 -0
- litert_cli/models/asr/runner.py +482 -0
- litert_cli/models/base.py +57 -0
- litert_cli/test_data/dummy_calib_data.py +26 -0
- litert_cli/test_data/dummy_cv_model.py +52 -0
- litert_cli/test_data/dummy_cv_model.tflite +0 -0
- litert_cli/test_data/generate_test_inputs.py +51 -0
- litert_cli/test_data/mobilenet_v3_calib_data.py +25 -0
- litert_cli/test_data/quantize_recipe.json +16 -0
- litert_cli/test_data/resnet18.py +31 -0
- litert_cli-0.1.0.dist-info/METADATA +38 -0
- litert_cli-0.1.0.dist-info/RECORD +67 -0
- litert_cli-0.1.0.dist-info/WHEEL +5 -0
- litert_cli-0.1.0.dist-info/entry_points.txt +2 -0
- litert_cli-0.1.0.dist-info/licenses/LICENSE +202 -0
- litert_cli-0.1.0.dist-info/top_level.txt +3 -0
- tools/build_wheels.py +122 -0
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
# Copyright 2026 The LiteRT CLI Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Command Line Interface for executing LiteRT models.
|
|
17
|
+
|
|
18
|
+
This module provides the `litert run` command, which allows users to
|
|
19
|
+
execute TFLite models on either the local desktop or a connected Android device.
|
|
20
|
+
|
|
21
|
+
Key Features:
|
|
22
|
+
- Desktop Execution: Uses the LiteRT Python API (`CompiledModel`) to run
|
|
23
|
+
inference locally. It automatically inspects the model signature,
|
|
24
|
+
generates appropriate dummy input data, and prints the output tensor
|
|
25
|
+
statistics.
|
|
26
|
+
- Android Execution: Seamlessly pushes the model and the compiled `run_model`
|
|
27
|
+
binary to an attached Android device via `adb`, and executes it remotely.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from __future__ import annotations
|
|
31
|
+
|
|
32
|
+
from collections.abc import Sequence
|
|
33
|
+
import textwrap
|
|
34
|
+
|
|
35
|
+
import click
|
|
36
|
+
from litert_cli.core import constants
|
|
37
|
+
from litert_cli.core import deps
|
|
38
|
+
from litert_cli.core import utils
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@click.command(
    "run",
    help=textwrap.dedent("""\
        Run LiteRT models locally or on device.

        MODEL: Path to the LiteRT model (.tflite) or a Model Reference (e.g., nvidia/parakeet-ctc-0.6b).

        Examples:

        1. Run on desktop (CPU) with dummy inputs:

            $ litert run model.tflite

        2. Run on desktop with GPU acceleration:

            $ litert run model.tflite --gpu

        3. Run with custom inputs (path or literal):

            $ litert run model.tflite --input image.jpg

           OR with named inputs:

            $ litert run model.tflite --input in1=1.0 --input in2=image.jpg

        4. Run on an attached Android device:

            $ litert run model.tflite --android

        5. Run on Android with GPU acceleration:

            $ litert run model.tflite --android --gpu

        6. Benchmark execution with 10 iterations:

            $ litert run model.tflite --iterations 10

        7. Print detailed tensor outputs:

            $ litert run model.tflite --print-tensors --sample-size 10

        8. Run with multiple accelerators (npu -> gpu -> cpu fallback):

            $ litert run model.tflite --npu --gpu --cpu

           OR explicitly:

            $ litert run model.tflite --accelerator npu,gpu,cpu
        """),
)
@deps.require_extra("run")
@click.argument("model", type=str)
@click.option(
    "--input",
    "inputs",
    multiple=True,
    help=(
        "Input data for the model. Can be a literal array (e.g. '[1,2]'), "
        "a path to an image/npy/raw file. "
        "You can specify multiple inputs using format: --input name=value "
        "or just --input value if the model has only one input."
    ),
)
@click.option(
    "--model-params",
    "model_params",
    multiple=True,
    help="Model specific parameters in format key=value.",
)
@click.option(
    "--model-help",
    is_flag=True,
    default=False,
    help="Show help specific to the matched model plugin.",
)
@click.option(
    "--desktop",
    "target",
    flag_value="desktop",
    default=True,
    help="Target desktop platform to run (Default).",
)
@click.option(
    "--android",
    "target",
    flag_value="android",
    help="Target Android platform to run.",
)
@click.option(
    "--accelerator",
    type=str,
    help="Comma-separated list of hardware accelerators (e.g. npu,gpu,cpu).",
)
@click.option(
    "--cpu",
    is_flag=True,
    help="Use CPU accelerator.",
)
@click.option(
    "--gpu",
    is_flag=True,
    help="Use GPU accelerator.",
)
@click.option(
    "--npu",
    is_flag=True,
    help="Use NPU accelerator.",
)
@click.option(
    "--signature-index",
    type=int,
    default=0,
    help="Index of model signature to run. Default is 0.",
)
@click.option(
    "--iterations",
    # At least one run is required: downstream timing reports index the
    # first measured run.
    type=click.IntRange(min=1),
    default=1,
    help="Number of times to execute the model for benchmarking. Default is 1.",
)
@click.option(
    "--print-tensors",
    is_flag=True,
    default=False,
    help="Print output tensor values after execution.",
)
@click.option(
    "--sample-size",
    type=click.IntRange(min=0),
    default=5,
    help="Number of sample elements to print from tensors. Default is 5.",
)
@click.pass_context
def run_cmd(
    unused_ctx: click.Context,
    model: str,
    inputs: Sequence[str],
    model_params: Sequence[str],
    model_help: bool,
    target: str,
    accelerator: str | None,
    cpu: bool,
    gpu: bool,
    npu: bool,
    signature_index: int,
    iterations: int,
    print_tensors: bool,
    sample_size: int,
) -> None:
  r"""Runs LiteRT models locally or on device.

  A model-specific plugin is tried first. If no plugin handles the model,
  execution is delegated to the desktop or Android runner.

  Args:
    unused_ctx: Click context.
    model: Path to the LiteRT model (.tflite) or a model reference that
      `resolve_model_reference` maps to a local file.
    inputs: Tuple of input assignments (e.g., 'name=value' or just 'value').
    model_params: Model specific parameters, each in 'key=value' form.
    model_help: Show help specific to the matched model plugin.
    target: Execution target ('desktop' or 'android').
    accelerator: Comma-separated accelerator priority list (e.g.
      'npu,gpu,cpu'). When given, it overrides the --cpu/--gpu/--npu flags.
    cpu: Use CPU accelerator.
    gpu: Use GPU accelerator.
    npu: Use NPU accelerator.
    signature_index: Index of model signature to run.
    iterations: Number of times to execute the model for benchmarking (>= 1).
    print_tensors: Whether to print output tensor elements.
    sample_size: Number of sample elements to print from tensors (>= 0).

  Raises:
    click.ClickException: If the target is not supported.
  """
  # Resolve the accelerator priority list. An explicit --accelerator wins;
  # otherwise the individual flags are combined in a fixed npu -> gpu -> cpu
  # order, independent of the order they appeared on the command line.
  if accelerator:
    accelerator_list = [
        a.strip().lower() for a in accelerator.split(",") if a.strip()
    ]
    if cpu or gpu or npu:
      click.secho(
          "Warning: --accelerator was given; ignoring --cpu/--gpu/--npu"
          " flags.",
          fg="yellow",
      )
  else:
    accelerator_list = [
        name
        for name, enabled in (("npu", npu), ("gpu", gpu), ("cpu", cpu))
        if enabled
    ]

  if not accelerator_list:
    accelerator_list = ["cpu"]

  accelerator = ",".join(accelerator_list)

  # Quiet if default is true
  if constants.DEFAULT_QUIET:
    utils.enable_quiet_mode()

  # --- Model Reference and Cache Resolution ---
  from litert_cli.core import models as core_models  # pylint: disable=g-import-not-at-top

  resolved_model_path, resolved_hf_id = core_models.resolve_model_reference(
      model
  )

  if resolved_model_path != model:
    click.echo(f"Resolved model '{model}' to '{resolved_model_path}'")

  # --- Plugin Dispatch Mechanism ---
  # Try to delegate to a model-specific plugin first.
  from litert_cli import models  # pylint: disable=g-import-not-at-top

  # Parse model-params into a dictionary. Only the first '=' separates the
  # key from the value; entries without '=' are reported rather than being
  # dropped silently.
  parsed_model_params = {}
  for param in model_params or ():
    key, sep, value = param.partition("=")
    if not sep:
      click.secho(
          f"Warning: ignoring malformed --model-params entry {param!r}"
          " (expected key=value).",
          fg="yellow",
      )
      continue
    parsed_model_params[key] = value

  # Pass the resolved hf_id as model_id to dispatch, and the actual file path
  # in kwargs
  plugin_result = models.dispatch_model_intent(
      intent="run",
      model_id=resolved_hf_id or str(model),
      inputs=inputs,
      model_help=model_help,
      model_params=parsed_model_params,
      target=target,
      accelerator=accelerator,
      model_path=resolved_model_path,  # Pass the actual file path here!
  )

  if plugin_result is not None:
    # If the plugin handled it or showed help, we exit
    return
  # ----------------------------------

  if target == "desktop":
    from litert_cli.commands.run import desktop  # pylint: disable=g-import-not-at-top

    desktop.run_desktop(
        model_path=str(resolved_model_path),
        inputs=inputs,
        accelerator=accelerator,
        signature_index=signature_index,
        iterations=iterations,
        print_tensors=print_tensors,
        sample_size=sample_size,
    )
  elif target == "android":
    from litert_cli.commands.run import android  # pylint: disable=g-import-not-at-top

    android.run_android(
        model_path=str(resolved_model_path),
        inputs=inputs,
        accelerator=accelerator,
        signature_index=signature_index,
        iterations=iterations,
        print_tensors=print_tensors,
        sample_size=sample_size,
    )
  else:
    # Raise rather than just print so the CLI exits with a non-zero status.
    raise click.ClickException(f"Target '{target}' is not yet supported.")
|
|
@@ -0,0 +1,340 @@
|
|
|
1
|
+
# Copyright 2026 The LiteRT CLI Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Desktop execution engine for LiteRT models.
|
|
17
|
+
|
|
18
|
+
Uses CompiledModel to load and run models on desktop (CPU/GPU).
|
|
19
|
+
|
|
20
|
+
Usage Examples:
|
|
21
|
+
1. Run a model on desktop (CPU):
|
|
22
|
+
$ litert run /path/to/model.tflite --desktop
|
|
23
|
+
|
|
24
|
+
2. Run with GPU acceleration:
|
|
25
|
+
$ litert run /path/to/model.tflite --desktop --gpu
|
|
26
|
+
OR
|
|
27
|
+
$ litert run /path/to/model.tflite --desktop --accelerator gpu
|
|
28
|
+
|
|
29
|
+
3. Run with multiple accelerators (gpu -> cpu native fallback):
|
|
30
|
+
$ litert run /path/to/model.tflite --desktop --gpu --cpu
|
|
31
|
+
OR
|
|
32
|
+
$ litert run /path/to/model.tflite --desktop --accelerator gpu,cpu
|
|
33
|
+
|
|
34
|
+
4. Run with custom inputs:
|
|
35
|
+
$ litert run /path/to/model.tflite --desktop --input input_name=value
|
|
36
|
+
|
|
37
|
+
5. Run with multiple iterations (benchmark):
|
|
38
|
+
$ litert run /path/to/model.tflite --desktop --iterations 10
|
|
39
|
+
|
|
40
|
+
6. Print tensor details:
|
|
41
|
+
$ litert run /path/to/model.tflite --desktop --print-tensors
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
from __future__ import annotations
|
|
45
|
+
|
|
46
|
+
from collections.abc import Mapping, Sequence
|
|
47
|
+
import contextlib
|
|
48
|
+
import time
|
|
49
|
+
from typing import Any, TYPE_CHECKING
|
|
50
|
+
|
|
51
|
+
import click
|
|
52
|
+
from litert_cli.core import constants
|
|
53
|
+
from litert_cli.core import inputs as inputs_utils
|
|
54
|
+
from litert_cli.core import utils
|
|
55
|
+
import numpy as np
|
|
56
|
+
|
|
57
|
+
if TYPE_CHECKING:
|
|
58
|
+
# Import heavy dependencies only for type hinting to improve CLI startup
|
|
59
|
+
# performance. These are not imported at runtime.
|
|
60
|
+
from ai_edge_litert.compiled_model import CompiledModel # pylint: disable=g-import-not-at-top
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _parse_inputs_dict(inputs: Sequence[str]) -> dict[str, str]:
  """Split `--input` assignments into a name -> value mapping.

  Args:
    inputs: Raw input strings such as ('name=value', 'value2'). May be empty
      or None.

  Returns:
    A dict from input name to its value string. Only the first '=' separates
    the name from the value. An entry with no '=' is stored under the
    '_default_' key; if several such entries are given, the last one wins.
  """
  mapping = {}
  for assignment in inputs or ():
    name, separator, value = assignment.partition("=")
    if separator:
      mapping[name] = value
    else:
      mapping["_default_"] = assignment
  return mapping
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _prepare_inputs(
    *,
    cm: CompiledModel,
    sig_key: str,
    parsed_inputs: dict[str, str],
) -> dict[str, Any]:
  """Prepare CompiledModel input buffers.

  For every input tensor of the signature, load the user-supplied value, or
  generate random dummy data, and write it into a freshly created
  TensorBuffer.

  A value assigned by name is used for the matching input. Otherwise the
  unnamed ('_default_') value, if any, is used. Otherwise random data is
  generated:
    * integer dtypes get integers in [0, 10);
    * other dtypes get uniform values in [-1, 1) cast to the tensor dtype.

  Named assignments that match no model input are reported with a warning.

  Args:
    cm: The loaded CompiledModel structure to interact with.
    sig_key: Signature key describing the input interface.
    parsed_inputs: Dictionary mapping input names to file path/literal strings,
      as produced by `_parse_inputs_dict`.

  Returns:
    A dictionary mapping tensor names to their populated TensorBuffers.

  Raises:
    click.ClickException: If input loading or parsing fails.
  """
  inputs_dict = {}
  input_details = cm.get_input_tensor_details(sig_key)

  # Without this check a misspelled input name would be silently ignored and
  # that input filled with random data instead.
  unknown_names = sorted(
      n for n in parsed_inputs if n != "_default_" and n not in input_details
  )
  if unknown_names:
    click.secho(
        f"Warning: ignoring unknown input name(s) {unknown_names}; model"
        f" inputs are {sorted(input_details)}.",
        fg="yellow",
    )

  # One generator for all inputs; it is unseeded, so data differs per run.
  rng = np.random.default_rng()

  for name, details in input_details.items():
    shape = details.get("shape", [1])
    tensor_type = details.get("dtype", "?")
    np_dtype = inputs_utils.get_np_dtype(tensor_type)

    # An empty named value falls back to the default assignment (if any).
    input_data_str = parsed_inputs.get(name) or parsed_inputs.get("_default_")

    if input_data_str:
      click.echo(
          f"Loading input {name!r} from {input_data_str!r} (shape:"
          f" {shape}, dtype: {tensor_type})"
      )
      try:
        input_data = inputs_utils.parse_input(input_data_str, shape, np_dtype)
      except ImportError as ie:
        click.secho(ie, fg="red")
        raise click.ClickException("Failed to load input module.") from ie
      except Exception as e:
        click.secho(f"Failed to parse input: {e!r}", fg="red")
        raise click.ClickException(
            f"Failed to parse input {name!r}: {e!r}"
        ) from e
    else:
      click.echo(
          f"Generating random dummy input {name!r} with shape {shape} and"
          f" dtype {tensor_type}"
      )
      if np.issubdtype(np_dtype, np.integer):
        input_data = rng.integers(0, 10, size=shape, dtype=np_dtype)
      else:
        input_data = np.asarray(
            rng.uniform(low=-1.0, high=1.0, size=shape)
        ).astype(np_dtype)

    tb = cm.create_input_buffer_by_name(sig_key, name)
    tb.write(input_data)
    inputs_dict[name] = tb

  return inputs_dict
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _format_tensor_sample(flat_out: np.ndarray, sample_size: int) -> str:
  """Render a 1-D array, eliding the middle when it is long.

  Arrays with at most `2 * sample_size` elements are rendered in full via
  `str`. Longer arrays show the first and last `sample_size` elements around
  an ellipsis. A negative `sample_size` is treated as 0.

  Args:
    flat_out: One-dimensional array to render.
    sample_size: Number of leading and trailing elements to show.

  Returns:
    The rendered string, without leading indentation.
  """
  k = max(int(sample_size), 0)
  n_elem = len(flat_out)
  if n_elem <= 2 * k:
    return str(flat_out)
  head = " ".join(str(x) for x in flat_out[:k])
  # Slice from an explicit start index: `flat_out[-0:]` would return the
  # entire array when k == 0.
  tail = " ".join(str(x) for x in flat_out[n_elem - k:])
  return f"[{head} ... {tail}]"


def _print_outputs(
    outputs_by_name: Mapping[str, Any],
    print_tensors: bool,
    sample_size: int,
    output_details: Mapping[str, Any],
) -> None:
  """Print inference outputs to stdout.

  Each output buffer is read as a numpy array and printed in one of three
  forms:
    * raw (sampled) values when `print_tensors` is set;
    * a top-5 listing for outputs shaped like classification scores, i.e.
      [N] or [1, N] with N > 1;
    * mean/min/max statistics otherwise.

  Failures for one output are reported inline and do not stop the others.

  Args:
    outputs_by_name: Dictionary mapping output names to read-ready
      TensorBuffers.
    print_tensors: Boolean flag to trigger tensor value printing.
    sample_size: Number of leading/trailing elements to print for large arrays.
    output_details: Dictionary mapping output names to their tensor details
      (read for the 'shape' and 'dtype' keys).
  """
  click.echo("Outputs:")
  for out_name, out_tb in outputs_by_name.items():
    try:
      details = output_details.get(out_name, {})

      # Prefer the buffer's own shape. Otherwise fall back to the signature's
      # tensor details, so the whole tensor is read instead of one element.
      shape = getattr(out_tb, "shape", None)
      if shape is None:
        shape = details.get("shape", [])
      shape = [int(d) for d in shape]
      num_elements = int(np.prod(shape)) if shape else 1

      tensor_type = details.get("dtype", "?")
      np_dtype = inputs_utils.get_np_dtype(tensor_type)

      out_np = out_tb.read(num_elements, np_dtype)

      if shape:
        out_np = out_np.reshape(shape)

      if print_tensors:
        click.echo(f" {out_name} (shape: {shape}):")
        click.echo(f" {_format_tensor_sample(out_np.ravel(), sample_size)}")
      elif (len(shape) == 1 and shape[0] > 1) or (
          len(shape) == 2 and shape[0] == 1 and shape[1] > 1
      ):
        # Classification inference heuristics fallback
        scores = out_np.flatten()
        n_top = min(5, len(scores))
        top_indices = np.argsort(scores)[-n_top:][::-1]

        click.echo(f" {out_name} (Top {n_top} Predictions):")
        for i, idx in enumerate(top_indices):
          click.echo(f" {i+1}: index {idx} - score {scores[idx]:.4f}")
      else:
        click.echo(
            f" {out_name}: mean={np.mean(out_np):.4f},"
            f" min={np.min(out_np):.4f}, max={np.max(out_np):.4f}"
        )
    except Exception as e:  # pylint: disable=broad-exception-caught
      click.echo(
          f" {out_name}: [Unable to read data natively without specific"
          f" dtype info] (Error: {e!r})"
      )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def run_desktop(
    *,
    model_path: str,
    inputs: Sequence[str],
    accelerator: str,
    signature_index: int,
    iterations: int,
    print_tensors: bool,
    sample_size: int,
    quiet: bool = False,
) -> None:
  """Runs the model on the desktop target using CompiledModel.

  Args:
    model_path: Local path to the LiteRT model file (.tflite).
    inputs: Tuple of input assignments (e.g., 'name=value' or just 'value').
    accelerator: Comma-separated accelerator names ('cpu', 'gpu', 'npu'). The
      named accelerators are OR-ed into one HardwareAccelerator mask; an empty
      selection falls back to CPU.
    signature_index: Index of the model signature to execute.
    iterations: Number of timed inference runs; must be at least 1.
    print_tensors: Whether to print sampled output tensor values instead of
      the summary view.
    sample_size: Number of leading/trailing elements printed per tensor when
      `print_tensors` is set.
    quiet: Whether to silence stderr output while loading and running.

  Raises:
    click.ClickException: On invalid arguments, loading failure, or inference
      execution errors.
  """
  # Validate before the (slow) runtime import. The timing report below
  # indexes the first run, so zero or negative iterations cannot be reported.
  if iterations < 1:
    raise click.ClickException(
        f"Number of iterations must be at least 1, got {iterations}."
    )

  accel_list = [a.strip().lower() for a in accelerator.split(",") if a.strip()]

  # pylint: disable=g-import-not-at-top,reimported
  from ai_edge_litert.compiled_model import CompiledModel
  from ai_edge_litert.compiled_model import Environment
  from ai_edge_litert.hardware_accelerator import HardwareAccelerator

  hw_accel = HardwareAccelerator(0)
  for accel in accel_list:
    if accel == "cpu":
      hw_accel |= HardwareAccelerator.CPU
    elif accel == "gpu":
      hw_accel |= HardwareAccelerator.GPU
    elif accel == "npu":
      hw_accel |= HardwareAccelerator.NPU
    else:
      raise click.ClickException(f"Unsupported hardware accelerator: {accel!r}")

  if hw_accel == HardwareAccelerator(0):
    hw_accel = HardwareAccelerator.CPU

  click.echo(
      f"Loading model on desktop: {model_path} with native hardware"
      f" accelerators: {hw_accel}"
  )

  ctx = utils.silence_stderr() if quiet else contextlib.nullcontext()
  with ctx:
    try:
      env = None
      if constants.IS_INTERNAL_ENV:
        # In internal environment, we need to fallback to LD_LIBRARY_PATH for
        # loading GPU accelerators in hermetic .par file. Otherwise, use
        # default path.
        env = Environment.create(runtime_path="")
      cm = CompiledModel.from_file(model_path, hw_accel, environment=env)
      signatures = cm.get_signature_list()
      if not signatures:
        raise click.ClickException(
            f"No signatures found in the model: {model_path!r}"
        )

      try:
        sig_info = cm.get_signature_by_index(signature_index)
        sig_key = sig_info["key"]
      except Exception as e:  # pylint: disable=broad-exception-caught
        raise click.ClickException(
            f"Failed to get signature at index {signature_index}: {e!r}"
        ) from e

      click.echo(f"Using signature: {sig_key!r}")

      parsed_inputs = _parse_inputs_dict(inputs)
      inputs_dict = _prepare_inputs(
          cm=cm, sig_key=sig_key, parsed_inputs=parsed_inputs
      )

      click.echo(f"Running inference {iterations} times...")

      sig_idx = cm.get_signature_index(sig_key)
      out_buffers = cm.create_output_buffers(sig_idx)
      out_names = signatures[sig_key]["outputs"]
      outputs_by_name = dict(zip(out_names, out_buffers))

      # Wall-clock latency of each run, in milliseconds.
      run_times = []
      for _ in range(iterations):
        start_time = time.perf_counter()
        cm.run_by_name(sig_key, inputs_dict, outputs_by_name)
        end_time = time.perf_counter()
        run_times.append((end_time - start_time) * 1000)

      if iterations == 1:
        click.echo(f"Inference complete in {run_times[0]:.2f} ms")
      else:
        click.echo(f"Benchmark results ({iterations} iterations):")
        click.echo(f" First run: {run_times[0]:.2f} ms")
        click.echo(f" Average: {np.mean(run_times):.2f} ms")
        click.echo(f" Min: {np.min(run_times):.2f} ms")
        click.echo(f" Max: {np.max(run_times):.2f} ms")

      output_details = cm.get_output_tensor_details(sig_key)
      _print_outputs(
          outputs_by_name, print_tensors, sample_size, output_details
      )

    except click.ClickException:
      # Already a user-facing error (missing signatures, bad signature index,
      # input parse failures); re-raise as-is instead of wrapping its repr
      # in a generic "Inference failed" message.
      raise
    except Exception as e:  # pylint: disable=broad-exception-caught
      raise click.ClickException(
          f"Inference failed for model {model_path!r} with accelerator"
          f" {accelerator!r}: {e!r}"
      ) from e
|