litert-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/litert_cli.ipynb +313 -0
- examples/models/presets/default.py +19 -0
- examples/run_cli_demo.sh +38 -0
- examples/run_cli_npu.sh +89 -0
- examples/run_commands.sh +67 -0
- examples/run_models.sh +63 -0
- examples/run_smoke_tests.sh +58 -0
- examples/utils.ps1 +163 -0
- examples/utils.sh +184 -0
- litert_cli/__init__.py +15 -0
- litert_cli/commands/benchmark/__init__.py +16 -0
- litert_cli/commands/benchmark/android.py +212 -0
- litert_cli/commands/benchmark/cli.py +294 -0
- litert_cli/commands/benchmark/desktop.py +228 -0
- litert_cli/commands/benchmark/gcp.py +336 -0
- litert_cli/commands/clean.py +73 -0
- litert_cli/commands/compile.py +211 -0
- litert_cli/commands/convert/__init__.py +20 -0
- litert_cli/commands/convert/cli.py +255 -0
- litert_cli/commands/convert/generic.py +211 -0
- litert_cli/commands/convert/huggingface.py +175 -0
- litert_cli/commands/delete.py +56 -0
- litert_cli/commands/download.py +274 -0
- litert_cli/commands/import.py +124 -0
- litert_cli/commands/list.py +132 -0
- litert_cli/commands/lm.py +74 -0
- litert_cli/commands/quantize.py +193 -0
- litert_cli/commands/run/__init__.py +16 -0
- litert_cli/commands/run/android.py +394 -0
- litert_cli/commands/run/cli.py +297 -0
- litert_cli/commands/run/desktop.py +340 -0
- litert_cli/commands/visualize.py +234 -0
- litert_cli/core/android_utils.py +304 -0
- litert_cli/core/android_utils_test.py +236 -0
- litert_cli/core/constants.py +131 -0
- litert_cli/core/deps.py +180 -0
- litert_cli/core/deps_test.py +101 -0
- litert_cli/core/inputs.py +203 -0
- litert_cli/core/inputs_test.py +176 -0
- litert_cli/core/log_filters.py +50 -0
- litert_cli/core/models.py +96 -0
- litert_cli/core/npu_utils.py +382 -0
- litert_cli/core/targets_manager.py +192 -0
- litert_cli/core/utils.py +58 -0
- litert_cli/litert.py +119 -0
- litert_cli/litert_help_test.py +51 -0
- litert_cli/litert_test.py +88 -0
- litert_cli/models/__init__.py +145 -0
- litert_cli/models/asr/__init__.py +15 -0
- litert_cli/models/asr/asr_model.py +108 -0
- litert_cli/models/asr/parakeet_ctc.py +165 -0
- litert_cli/models/asr/runner.py +482 -0
- litert_cli/models/base.py +57 -0
- litert_cli/test_data/dummy_calib_data.py +26 -0
- litert_cli/test_data/dummy_cv_model.py +52 -0
- litert_cli/test_data/dummy_cv_model.tflite +0 -0
- litert_cli/test_data/generate_test_inputs.py +51 -0
- litert_cli/test_data/mobilenet_v3_calib_data.py +25 -0
- litert_cli/test_data/quantize_recipe.json +16 -0
- litert_cli/test_data/resnet18.py +31 -0
- litert_cli-0.1.0.dist-info/METADATA +38 -0
- litert_cli-0.1.0.dist-info/RECORD +67 -0
- litert_cli-0.1.0.dist-info/WHEEL +5 -0
- litert_cli-0.1.0.dist-info/entry_points.txt +2 -0
- litert_cli-0.1.0.dist-info/licenses/LICENSE +202 -0
- litert_cli-0.1.0.dist-info/top_level.txt +3 -0
- tools/build_wheels.py +122 -0
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
# Copyright 2026 The LiteRT CLI Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""CLI module for the `litert convert` command.
|
|
17
|
+
|
|
18
|
+
This command handles converting PyTorch and Hugging Face models into LiteRT
|
|
19
|
+
models using various conversion paths like automated HF export, generic script
|
|
20
|
+
injection, and native generative API re-authoring.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import pathlib
|
|
26
|
+
import textwrap
|
|
27
|
+
|
|
28
|
+
import click
|
|
29
|
+
from litert_cli.core import deps
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@click.command(
    "convert",
    help=textwrap.dedent("""\
        Convert a PyTorch model into a LiteRT model.

        MODEL_OR_SCRIPT: Hugging Face model ID or path to a PyTorch script.

        Note: In HF mode, only AutoModelForCausalLM models and the Gemma 3,
        Gemma 3n and Gemma 4 vision-language models are supported.

        Examples:

          Automated HF Conversion:

            $ litert convert Qwen/Qwen1.5-0.5B-Chat --output /tmp/qwen

          HF Conversion with Weight-Only INT4 Quantization:

            $ litert convert Qwen/Qwen1.5-0.5B-Chat --quantize-recipe weight_only_wi4_afp32 --output /tmp/qwen_w4

          HF Conversion with Custom Prefill & Cache Lengths:

            $ litert convert Qwen/Qwen1.5-0.5B-Chat --prefill-lengths "128,512" --cache-length 2048 --output /tmp/qwen_custom

          Generic Script Injection with Quantization:

            $ litert convert my_model.py --quantize-recipe dynamic_wi8_afp32 --output /tmp/mymodel

          Generic Script with Model Args and AOT Target Compilation:

            $ litert convert my_model.py --model-args "batch_size=4" --target sm8450 --output /tmp/mymodel_npu
        """),
)
@deps.require_extra("convert")
@click.argument("model_or_script", type=str, required=True)
@click.option(
    "--output",
    type=click.Path(
        file_okay=False,
        dir_okay=True,
        resolve_path=True,
        path_type=pathlib.Path,
    ),
    required=False,
    default=None,
    help="Directory to save the converted TFLite model.",
)
@click.option(
    "--model-func",
    type=str,
    default="get_model",
    help=(
        "Name of the function in the script that returns a torch.nn.Module"
        " and optionally a quantization config. Default: 'get_model'."
    ),
)
@click.option(
    "--input-func",
    type=str,
    default="get_args",
    help=(
        "Name of the function in the script that returns the sample inputs:"
        " a tuple of positional args or a dict of keyword args."
        " Default: 'get_args'."
    ),
)
@click.option(
    "--target",
    type=str,
    multiple=True,
    help=(
        "One or more NPU target codenames (e.g., sm8450) to apply AOT"
        " compilation."
    ),
)
@click.option(
    "--export-aipack",
    type=click.Path(
        file_okay=False,
        dir_okay=True,
        resolve_path=True,
        path_type=pathlib.Path,
    ),
    required=False,
    default=None,
    help=(
        "If specified, exports an AI Pack directory for PODAI alongside the"
        " compiled model. Only used together with --target."
    ),
)
@click.option(
    "--quantize",
    "--quantize-recipe",
    type=str,
    default=None,
    help=(
        "Quantization recipe to apply (e.g., dynamic_wi8_afp32,"
        " weight_only_wi4_afp32). Alias: --quantize-recipe. For full list of"
        " generative recipes, see 'ai_edge_quantizer.recipe'."
    ),
)
@click.option(
    "--model-args",
    type=str,
    default=None,
    help=(
        "Comma-separated key=value arguments to pass to custom model/input"
        " functions."
    ),
)
@click.option(
    "--prefill-lengths",
    type=str,
    default="256",
    help=(
        "Comma-separated list of prefill lengths for HuggingFace models."
        " Default: '256'."
    ),
)
@click.option(
    "--cache-length",
    type=int,
    default=4096,
    help="KV cache length for HuggingFace models. Default: 4096.",
)
@click.option(
    "--bundle-litert-lm/--no-bundle-litert-lm",
    is_flag=True,
    default=True,
    help=(
        "Bundle exported artifacts into a .litert_lm package (HuggingFace mode"
        " only). Default: True."
    ),
)
def convert_cmd(
    model_or_script: str,
    output: pathlib.Path | None,
    model_func: str,
    input_func: str,
    target: tuple[str, ...],
    export_aipack: pathlib.Path | None,
    quantize: str | None,
    model_args: str | None,
    prefill_lengths: str,
    cache_length: int,
    bundle_litert_lm: bool,
) -> None:
  """Converts a PyTorch model into a LiteRT model.

  Dispatch order: if model plugins are enabled and a plugin handles the
  request (returns a non-None result), nothing else runs. Otherwise a
  `MODEL_OR_SCRIPT` ending in `.py` goes to the generic script converter and
  anything else is treated as a Hugging Face model ID.

  Args:
    model_or_script: Hugging Face model ID or path to a PyTorch script.
    output: Output directory for the converted model. Defaults to
      `<cwd>/<script stem>` for scripts or `<cwd>/<last path component>` for
      Hugging Face IDs.
    model_func: Function to retrieve the model in 'script' mode.
    input_func: Function to retrieve sample inputs in 'script' mode.
    target: NPU targets to compile for.
    export_aipack: Output directory to export the AI Pack for PODAI.
    quantize: Quantization recipe to apply (see ai_edge_quantizer.recipe for
      full list).
    model_args: Arguments to pass to custom model/input functions.
    prefill_lengths: List of prefill lengths for HuggingFace models.
    cache_length: KV cache length for HuggingFace models.
    bundle_litert_lm: Whether to bundle artifacts into a .litert_lm package.
  """
  from litert_cli.core import constants, utils
  import warnings

  if constants.DEFAULT_QUIET:
    utils.enable_quiet_mode()

  # Suppress noisy warnings from torch, torchao, etc. Note that these filters
  # are process-wide.
  warnings.filterwarnings("ignore", category=FutureWarning)
  warnings.filterwarnings("ignore", category=SyntaxWarning)
  warnings.filterwarnings("ignore", category=UserWarning)

  if output is None:
    if model_or_script.endswith(".py"):
      base_name = pathlib.Path(model_or_script).stem
    else:
      # e.g. "Qwen/Qwen1.5-0.5B-Chat" -> "Qwen1.5-0.5B-Chat".
      base_name = pathlib.Path(model_or_script).name
    output = pathlib.Path.cwd() / base_name

  if constants.ENABLE_MODEL_PLUGINS:
    from litert_cli.models import dispatch_model_intent  # pylint: disable=g-import-not-at-top

    plugin_result = dispatch_model_intent(
        "convert",
        model_or_script,
        output=output,
        target=target,
        quantize=quantize,
        export_aipack=export_aipack,
        model_args=model_args,
        prefill_lengths=prefill_lengths,
        cache_length=cache_length,
        bundle_litert_lm=bundle_litert_lm,
    )
    # A non-None result means a plugin fully handled the conversion.
    if plugin_result is not None:
      return

  if model_or_script.endswith(".py"):
    from litert_cli.commands.convert import generic  # pylint: disable=g-import-not-at-top

    generic.convert_generic_script(
        model_or_script,
        model_func,
        input_func,
        str(output),
        target,
        export_aipack,
        quantize,
        model_args,
    )
  else:
    from litert_cli.commands.convert import huggingface  # pylint: disable=g-import-not-at-top

    huggingface.convert_huggingface(
        model_or_script,
        str(output),
        target,
        export_aipack,
        quantize,
        prefill_lengths,
        cache_length,
        bundle_litert_lm,
    )
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
# Copyright 2026 The LiteRT CLI Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Generic PyTorch script conversion logic for LiteRT CLI."""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import importlib.util
|
|
21
|
+
import pathlib
|
|
22
|
+
import shutil
|
|
23
|
+
import sys
|
|
24
|
+
from typing import Any
|
|
25
|
+
import uuid
|
|
26
|
+
|
|
27
|
+
import click
|
|
28
|
+
from litert_cli.core import npu_utils
|
|
29
|
+
|
|
30
|
+
from ai_edge_litert.aot.ai_pack import export_lib as ai_pack_export
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def convert_generic_script(
|
|
34
|
+
script: str,
|
|
35
|
+
model_func: str,
|
|
36
|
+
input_func: str,
|
|
37
|
+
output: str,
|
|
38
|
+
target: tuple[str, ...],
|
|
39
|
+
export_aipack: pathlib.Path | None,
|
|
40
|
+
quantize: str | None = None,
|
|
41
|
+
model_args: str | None = None,
|
|
42
|
+
) -> None:
|
|
43
|
+
"""Converts using generic PyTorch scripts and `litert_torch.convert`.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
script: Path to the PyTorch script (.py).
|
|
47
|
+
model_func: Name of function returning the `torch.nn.Module`.
|
|
48
|
+
input_func: Name of function returning sample inputs.
|
|
49
|
+
output: Directory to save the converted model.
|
|
50
|
+
target: NPU targets to apply AOT compilation.
|
|
51
|
+
export_aipack: Output directory to export the AI Pack for PODAI.
|
|
52
|
+
quantize: Quantization recipe to apply.
|
|
53
|
+
model_args: Arguments to pass to custom model/input functions.
|
|
54
|
+
|
|
55
|
+
Raises:
|
|
56
|
+
ImportError: If the script loading fails.
|
|
57
|
+
AttributeError: If functions are missing in the user script.
|
|
58
|
+
ValueError: If inputs result shape is not supported.
|
|
59
|
+
"""
|
|
60
|
+
# pylint: disable=g-import-not-at-top
|
|
61
|
+
import litert_torch
|
|
62
|
+
|
|
63
|
+
script_path = pathlib.Path(script).resolve()
|
|
64
|
+
click.echo(f"Loading custom script from: {script_path}")
|
|
65
|
+
|
|
66
|
+
# Dynamically load the user's python file with sandbox isolation
|
|
67
|
+
module_name = f"user_model_script_{uuid.uuid4().hex}"
|
|
68
|
+
spec = importlib.util.spec_from_file_location(module_name, script_path)
|
|
69
|
+
if spec is None or spec.loader is None:
|
|
70
|
+
raise ImportError(f"Could not load script {script_path}")
|
|
71
|
+
|
|
72
|
+
user_module = importlib.util.module_from_spec(spec)
|
|
73
|
+
sys.modules[module_name] = user_module
|
|
74
|
+
|
|
75
|
+
# Add the script's directory to sys.path so it can resolve its own local imports
|
|
76
|
+
sys.path.insert(0, str(script_path.parent))
|
|
77
|
+
|
|
78
|
+
try:
|
|
79
|
+
spec.loader.exec_module(user_module)
|
|
80
|
+
|
|
81
|
+
# Look up the model factory
|
|
82
|
+
if not hasattr(user_module, model_func):
|
|
83
|
+
raise AttributeError(
|
|
84
|
+
f"Function '{model_func}' not found in {script_path}"
|
|
85
|
+
)
|
|
86
|
+
# Look up the args factory
|
|
87
|
+
if not hasattr(user_module, input_func):
|
|
88
|
+
raise AttributeError(
|
|
89
|
+
f"Function '{input_func}' not found in {script_path}"
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
# Parse model_args
|
|
93
|
+
parsed_args = {}
|
|
94
|
+
if model_args:
|
|
95
|
+
for item in model_args.split(","):
|
|
96
|
+
if "=" in item:
|
|
97
|
+
k, v = item.split("=", 1)
|
|
98
|
+
if v.isdigit():
|
|
99
|
+
parsed_args[k] = int(v)
|
|
100
|
+
elif v.replace(".", "", 1).isdigit():
|
|
101
|
+
parsed_args[k] = float(v)
|
|
102
|
+
elif v.lower() == "true":
|
|
103
|
+
parsed_args[k] = True
|
|
104
|
+
elif v.lower() == "false":
|
|
105
|
+
parsed_args[k] = False
|
|
106
|
+
else:
|
|
107
|
+
parsed_args[k] = v
|
|
108
|
+
|
|
109
|
+
click.echo(
|
|
110
|
+
f"Instantiating model via '{model_func}' with args: {parsed_args}..."
|
|
111
|
+
)
|
|
112
|
+
model_result = getattr(user_module, model_func)(**parsed_args)
|
|
113
|
+
|
|
114
|
+
# The user might return just the nn.Module or a tuple of (nn.Module, QuantConfig)
|
|
115
|
+
if isinstance(model_result, tuple):
|
|
116
|
+
model, quant_config = model_result
|
|
117
|
+
else:
|
|
118
|
+
model = model_result
|
|
119
|
+
quant_config = None
|
|
120
|
+
|
|
121
|
+
if quant_config is None and quantize:
|
|
122
|
+
click.echo(f"Building dynamic QuantConfig for recipe '{quantize}'...")
|
|
123
|
+
from litert_torch.quantize import quant_config as quant_config_lib
|
|
124
|
+
|
|
125
|
+
if "pt2e" in quantize.lower():
|
|
126
|
+
from litert_torch.quantize import pt2e_quantizer as pt2eq
|
|
127
|
+
|
|
128
|
+
is_dynamic = "dynamic" in quantize.lower()
|
|
129
|
+
is_per_channel = "per_channel" in quantize.lower()
|
|
130
|
+
pt2e_cfg = pt2eq.get_symmetric_quantization_config(
|
|
131
|
+
is_per_channel=is_per_channel, is_dynamic=is_dynamic
|
|
132
|
+
)
|
|
133
|
+
q = pt2eq.PT2EQuantizer().set_global(pt2e_cfg)
|
|
134
|
+
quant_config = quant_config_lib.QuantConfig(pt2e_quantizer=q)
|
|
135
|
+
else:
|
|
136
|
+
from litert_torch.generative.quantize import quant_recipe
|
|
137
|
+
from litert_torch.generative.quantize import quant_recipe_utils
|
|
138
|
+
|
|
139
|
+
if "weight_only" in quantize.lower():
|
|
140
|
+
recipe_obj = quant_recipe_utils.create_layer_quant_weight_only()
|
|
141
|
+
elif "fp16" in quantize.lower():
|
|
142
|
+
recipe_obj = quant_recipe_utils.create_layer_quant_fp16()
|
|
143
|
+
else:
|
|
144
|
+
recipe_obj = quant_recipe_utils.create_layer_quant_dynamic()
|
|
145
|
+
|
|
146
|
+
quant_config = quant_config_lib.QuantConfig(
|
|
147
|
+
generative_recipe=quant_recipe.GenerativeQuantRecipe(
|
|
148
|
+
default=recipe_obj
|
|
149
|
+
)
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
click.echo(
|
|
153
|
+
f"Generating sample inputs via '{input_func}' with args:"
|
|
154
|
+
f" {parsed_args}..."
|
|
155
|
+
)
|
|
156
|
+
inputs_result = getattr(user_module, input_func)(**parsed_args)
|
|
157
|
+
|
|
158
|
+
sample_args: tuple[Any, ...]
|
|
159
|
+
sample_kwargs: dict[str, Any]
|
|
160
|
+
if isinstance(inputs_result, tuple):
|
|
161
|
+
sample_args = inputs_result
|
|
162
|
+
sample_kwargs = {}
|
|
163
|
+
elif isinstance(inputs_result, dict):
|
|
164
|
+
sample_args = ()
|
|
165
|
+
sample_kwargs = inputs_result
|
|
166
|
+
else:
|
|
167
|
+
raise ValueError(
|
|
168
|
+
f"'{input_func}' must return a tuple (args) or a dict (kwargs)."
|
|
169
|
+
)
|
|
170
|
+
|
|
171
|
+
click.echo("Executing liteRT conversion tracer...")
|
|
172
|
+
|
|
173
|
+
builder = litert_torch
|
|
174
|
+
if target:
|
|
175
|
+
for t in target:
|
|
176
|
+
target_obj = npu_utils.get_target(t)
|
|
177
|
+
builder = builder.experimental_add_compilation_backend(target_obj)
|
|
178
|
+
|
|
179
|
+
edge_model = builder.convert(
|
|
180
|
+
module=model,
|
|
181
|
+
sample_args=sample_args,
|
|
182
|
+
sample_kwargs=sample_kwargs,
|
|
183
|
+
quant_config=quant_config,
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
out_path = pathlib.Path(output)
|
|
187
|
+
out_path.mkdir(parents=True, exist_ok=True)
|
|
188
|
+
model_name = script_path.stem
|
|
189
|
+
|
|
190
|
+
if target and export_aipack:
|
|
191
|
+
export_dir = pathlib.Path(export_aipack)
|
|
192
|
+
click.echo(f"Exporting AI Pack to: {export_dir}")
|
|
193
|
+
shutil.rmtree(export_dir, ignore_errors=True)
|
|
194
|
+
export_dir.mkdir(parents=True, exist_ok=True)
|
|
195
|
+
ai_pack_export.export(edge_model, str(export_dir), model_name, "model")
|
|
196
|
+
else:
|
|
197
|
+
if target:
|
|
198
|
+
click.echo(f"Exporting compiled model to {out_path} as {model_name}...")
|
|
199
|
+
edge_model.export(str(out_path), model_name=model_name)
|
|
200
|
+
else:
|
|
201
|
+
final_path = out_path / f"{model_name}.tflite"
|
|
202
|
+
click.echo(f"Exporting converted model to {final_path}...")
|
|
203
|
+
edge_model.export(str(final_path))
|
|
204
|
+
|
|
205
|
+
click.echo("Done!")
|
|
206
|
+
|
|
207
|
+
finally:
|
|
208
|
+
# Cleanup injected sys.path modification and sandbox module
|
|
209
|
+
sys.modules.pop(module_name, None)
|
|
210
|
+
if sys.path[0] == str(script_path.parent):
|
|
211
|
+
sys.path.pop(0)
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
# Copyright 2026 The LiteRT CLI Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Hugging Face automated export logic for LiteRT conversion.
|
|
17
|
+
|
|
18
|
+
This module provides the implementation for the 'hf' mode of the
|
|
19
|
+
`litert convert` command, orchestrating the download and conversion of
|
|
20
|
+
models directly from the Hugging Face Hub.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import pathlib
|
|
26
|
+
import shutil
|
|
27
|
+
|
|
28
|
+
import click
|
|
29
|
+
from litert_cli.core import npu_utils
|
|
30
|
+
|
|
31
|
+
from ai_edge_litert.aot import aot_compile as aot_lib
|
|
32
|
+
from ai_edge_litert.aot.ai_pack import export_lib as ai_pack_export
|
|
33
|
+
|
|
34
|
+
def convert_huggingface(
|
|
35
|
+
model: str,
|
|
36
|
+
output: str,
|
|
37
|
+
target: tuple[str, ...],
|
|
38
|
+
export_aipack: pathlib.Path | None,
|
|
39
|
+
quantize: str | None = None,
|
|
40
|
+
prefill_lengths: str = "256",
|
|
41
|
+
cache_length: int = 4096,
|
|
42
|
+
bundle_litert_lm: bool = True,
|
|
43
|
+
) -> None:
|
|
44
|
+
"""Converts models using HuggingFace Automated Export (export_hf).
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
model: The Hugging Face model ID (e.g., Qwen/Qwen1.5-0.5B-Chat).
|
|
48
|
+
output: The directory to save the converted model.
|
|
49
|
+
target: NPU targets to apply AOT compilation.
|
|
50
|
+
export_aipack: Output directory to export the AI Pack for PODAI.
|
|
51
|
+
quantize: Quantization recipe to apply.
|
|
52
|
+
prefill_lengths: Comma-separated list of prefill lengths.
|
|
53
|
+
cache_length: KV cache length.
|
|
54
|
+
bundle_litert_lm: Whether to bundle artifacts into a .litert_lm package.
|
|
55
|
+
"""
|
|
56
|
+
# Lazy load the export module to avoid importing torch and other heavy
|
|
57
|
+
# dependencies when the litert CLI is merely invoked for --help.
|
|
58
|
+
# pylint: disable=g-import-not-at-top
|
|
59
|
+
from litert_torch.generative.export_hf import export as hf_export
|
|
60
|
+
import transformers
|
|
61
|
+
|
|
62
|
+
click.echo(f"Starting conversion for model '{model}''")
|
|
63
|
+
|
|
64
|
+
try:
|
|
65
|
+
is_causal_lm = False
|
|
66
|
+
is_gemma3 = False
|
|
67
|
+
is_gemma3n = False
|
|
68
|
+
is_gemma4 = False
|
|
69
|
+
is_gemma_vlm = False
|
|
70
|
+
|
|
71
|
+
# Verify AutoModelForCausalLM architecture
|
|
72
|
+
try:
|
|
73
|
+
config = transformers.AutoConfig.from_pretrained(
|
|
74
|
+
model, trust_remote_code=True
|
|
75
|
+
)
|
|
76
|
+
architectures = getattr(config, "architectures", [])
|
|
77
|
+
is_causal_lm = any("CausalLM" in arch for arch in architectures)
|
|
78
|
+
is_gemma3 = any(
|
|
79
|
+
"Gemma3ForConditionalGeneration" in arch for arch in architectures
|
|
80
|
+
)
|
|
81
|
+
is_gemma3n = any(
|
|
82
|
+
"Gemma3nForConditionalGeneration" in arch for arch in architectures
|
|
83
|
+
)
|
|
84
|
+
is_gemma4 = any(
|
|
85
|
+
"Gemma4ForConditionalGeneration" in arch for arch in architectures
|
|
86
|
+
)
|
|
87
|
+
is_gemma_vlm = is_gemma3 or is_gemma3n or is_gemma4
|
|
88
|
+
|
|
89
|
+
if not (is_causal_lm or is_gemma_vlm):
|
|
90
|
+
raise ValueError(
|
|
91
|
+
"Currently only AutoModelForCausalLM is supported (or Gemma VLM"
|
|
92
|
+
f" architectures: Gemma3, Gemma3n, Gemma4). Model '{model}' has"
|
|
93
|
+
f" architectures {architectures}."
|
|
94
|
+
)
|
|
95
|
+
except Exception as e:
|
|
96
|
+
if isinstance(e, ValueError):
|
|
97
|
+
raise
|
|
98
|
+
click.echo(f"Warning during config verification: {e}", err=True)
|
|
99
|
+
|
|
100
|
+
# Parse prefill_lengths
|
|
101
|
+
parsed_prefill = [int(x.strip()) for x in prefill_lengths.split(",")]
|
|
102
|
+
|
|
103
|
+
# Call the auto-export function from litert_torch.
|
|
104
|
+
# It automatically saves to the output.
|
|
105
|
+
task = "text_generation"
|
|
106
|
+
export_kwargs = {}
|
|
107
|
+
use_jinja_template = is_gemma4
|
|
108
|
+
if is_gemma_vlm:
|
|
109
|
+
task = "image_text_to_text"
|
|
110
|
+
export_kwargs["export_vision_encoder"] = True
|
|
111
|
+
export_kwargs["externalize_embedder"] = True
|
|
112
|
+
if is_gemma3 or is_gemma3n:
|
|
113
|
+
export_kwargs["vision_encoder_quantization_recipe"] = (
|
|
114
|
+
"weight_only_wi8_afp32"
|
|
115
|
+
)
|
|
116
|
+
if is_gemma4:
|
|
117
|
+
export_kwargs["jinja_chat_template_override"] = (
|
|
118
|
+
"litert-community/gemma-4-E2B-it-litert-lm"
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
hf_export.export(
|
|
122
|
+
model=model,
|
|
123
|
+
output_dir=output,
|
|
124
|
+
task=task,
|
|
125
|
+
quantization_recipe=quantize,
|
|
126
|
+
prefill_lengths=parsed_prefill,
|
|
127
|
+
cache_length=cache_length,
|
|
128
|
+
bundle_litert_lm=bundle_litert_lm,
|
|
129
|
+
trust_remote_code=False,
|
|
130
|
+
use_jinja_template=use_jinja_template,
|
|
131
|
+
**export_kwargs,
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
if target:
|
|
135
|
+
output_dir = pathlib.Path(output)
|
|
136
|
+
# Find the generated tflite
|
|
137
|
+
tflite_files = list(output_dir.glob("*.tflite"))
|
|
138
|
+
if not tflite_files:
|
|
139
|
+
raise FileNotFoundError(
|
|
140
|
+
f"No .tflite files found in HF export output: {output_dir}"
|
|
141
|
+
)
|
|
142
|
+
target_tflite = tflite_files[0]
|
|
143
|
+
base_name = target_tflite.stem
|
|
144
|
+
|
|
145
|
+
click.echo(
|
|
146
|
+
f"Compiling converted model {target_tflite} for targets:"
|
|
147
|
+
f" {', '.join(target)}"
|
|
148
|
+
)
|
|
149
|
+
aot_targets = [npu_utils.get_target(t) for t in target]
|
|
150
|
+
|
|
151
|
+
compiled_models = aot_lib.aot_compile(
|
|
152
|
+
str(target_tflite),
|
|
153
|
+
target=aot_targets,
|
|
154
|
+
keep_going=False,
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
if export_aipack:
|
|
158
|
+
export_dir = pathlib.Path(export_aipack)
|
|
159
|
+
click.echo(f"Exporting AI Pack to: {export_dir}")
|
|
160
|
+
shutil.rmtree(export_dir, ignore_errors=True)
|
|
161
|
+
export_dir.mkdir(parents=True, exist_ok=True)
|
|
162
|
+
ai_pack_export.export(
|
|
163
|
+
compiled_models, str(export_dir), base_name, "model"
|
|
164
|
+
)
|
|
165
|
+
else:
|
|
166
|
+
# Overwrite the original tflite output with the compiled models
|
|
167
|
+
click.echo(f"Exporting compiled model over original in {output_dir}")
|
|
168
|
+
compiled_models.export(str(output_dir), model_name=base_name)
|
|
169
|
+
|
|
170
|
+
click.echo(f"Successfully converted and saved model to {output}")
|
|
171
|
+
|
|
172
|
+
except Exception as e:
|
|
173
|
+
click.echo(f"Error during conversion: {e}")
|
|
174
|
+
# Re-raise to let click handle the error exit
|
|
175
|
+
raise
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# Copyright 2026 The LiteRT CLI Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Command line interface for deleting managed models from LiteRT cache."""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import pathlib
|
|
21
|
+
import shutil
|
|
22
|
+
|
|
23
|
+
import click
|
|
24
|
+
|
|
25
|
+
from ..core.constants import LITERT_MODELS_CACHE_DIR
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@click.command(
    "delete", help="Delete a managed model from the centralized cache."
)
@click.argument("ref")
@click.option("--yes", "-y", is_flag=True, help="Do not ask for confirmation.")
def delete_cmd(ref: str, yes: bool) -> None:
  """Deletes a managed model from the centralized cache.

  The reference is flattened the way cache entries are named ("/" becomes
  "__", e.g. "org/model" -> "org__model"). It must name a single directory
  directly inside the cache. References such as ".", ".." or ones containing
  other path separators are rejected, so this command can never delete the
  cache root or anything outside it.

  Args:
    ref: The model reference to delete.
    yes: Whether to skip confirmation prompt.

  Raises:
    click.ClickException: If the reference is invalid, is not in the cache,
      or the deletion fails.
  """
  ref_flat = ref.replace("/", "__")
  # "." and ".." would resolve to the cache root or its parent. A name that
  # pathlib splits further (e.g. containing "\\" on Windows) is not a single
  # cache entry either.
  if ref_flat in ("", ".", "..") or pathlib.PurePath(ref_flat).name != ref_flat:
    raise click.ClickException(f"Invalid model reference '{ref}'.")

  cache_path = pathlib.Path(LITERT_MODELS_CACHE_DIR) / ref_flat

  # is_dir() is False for non-existent paths as well as for regular files.
  if not cache_path.is_dir():
    raise click.ClickException(f"Model reference '{ref}' not found in cache.")

  if not yes:
    if not click.confirm(f"Are you sure you want to delete model '{ref}'?"):
      click.echo("Aborted.")
      return

  try:
    shutil.rmtree(cache_path)
    click.secho(f"Successfully deleted model '{ref}' from cache.", fg="green")
  except OSError as e:
    raise click.ClickException(f"Failed to delete model: {e}") from e
|