optimum-rbln 0.9.2a10__py3-none-any.whl → 0.9.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +4 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/cli.py +660 -0
- optimum/rbln/modeling.py +6 -0
- optimum/rbln/transformers/__init__.py +4 -0
- optimum/rbln/transformers/models/__init__.py +5 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +17 -17
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/utils/runtime_utils.py +25 -10
- {optimum_rbln-0.9.2a10.dist-info → optimum_rbln-0.9.2rc1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.9.2a10.dist-info → optimum_rbln-0.9.2rc1.dist-info}/RECORD +17 -11
- optimum_rbln-0.9.2rc1.dist-info/entry_points.txt +2 -0
- {optimum_rbln-0.9.2a10.dist-info → optimum_rbln-0.9.2rc1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.2a10.dist-info → optimum_rbln-0.9.2rc1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
|
@@ -74,6 +74,8 @@ _import_structure = {
|
|
|
74
74
|
"RBLNCLIPVisionModelWithProjectionConfig",
|
|
75
75
|
"RBLNColPaliForRetrieval",
|
|
76
76
|
"RBLNColPaliForRetrievalConfig",
|
|
77
|
+
"RBLNColQwen2ForRetrieval",
|
|
78
|
+
"RBLNColQwen2ForRetrievalConfig",
|
|
77
79
|
"RBLNDecoderOnlyModelConfig",
|
|
78
80
|
"RBLNDecoderOnlyModel",
|
|
79
81
|
"RBLNDecoderOnlyModelForCausalLM",
|
|
@@ -366,6 +368,8 @@ if TYPE_CHECKING:
|
|
|
366
368
|
RBLNCLIPVisionModelWithProjectionConfig,
|
|
367
369
|
RBLNColPaliForRetrieval,
|
|
368
370
|
RBLNColPaliForRetrievalConfig,
|
|
371
|
+
RBLNColQwen2ForRetrieval,
|
|
372
|
+
RBLNColQwen2ForRetrievalConfig,
|
|
369
373
|
RBLNDecoderOnlyModel,
|
|
370
374
|
RBLNDecoderOnlyModelConfig,
|
|
371
375
|
RBLNDecoderOnlyModelForCausalLM,
|
optimum/rbln/__version__.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.9.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 9, 2, '
|
|
31
|
+
__version__ = version = '0.9.2rc1'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 9, 2, 'rc1')
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
optimum/rbln/cli.py
ADDED
|
@@ -0,0 +1,660 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
3
|
+
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at:
|
|
7
|
+
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import inspect
|
|
18
|
+
import json
|
|
19
|
+
import sys
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Optional
|
|
22
|
+
|
|
23
|
+
from huggingface_hub import hf_hub_download
|
|
24
|
+
|
|
25
|
+
from .__version__ import __version__
|
|
26
|
+
from .configuration_utils import RBLNModelConfig
|
|
27
|
+
from .utils.model_utils import get_rbln_model_cls
|
|
28
|
+
from .utils.runtime_utils import ContextRblnConfig
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def set_nested_dict(dictionary, key_path, value):
    """
    Set a value in a nested dictionary using dot notation.

    Intermediate dictionaries are created on demand, so
    ``set_nested_dict(d, "unet.batch_size", 2)`` on ``{}`` yields
    ``{"unet": {"batch_size": 2}}``. A key path without dots sets a top-level key.

    Args:
        dictionary: The dictionary to modify in place
        key_path: Dot-separated key path (e.g., "unet.batch_size")
        value: The value to set

    Raises:
        TypeError: If an intermediate key already holds a non-dict value
            (e.g. ``--unet 1`` followed by ``--unet.batch_size 2`` on the CLI).
    """
    keys = key_path.split(".")
    current = dictionary

    # Navigate to (creating as needed) the parent of the final key
    for depth, key in enumerate(keys[:-1], start=1):
        child = current.setdefault(key, {})
        if not isinstance(child, dict):
            conflicting = ".".join(keys[:depth])
            raise TypeError(
                f"Cannot set '{key_path}': '{conflicting}' is already set to a non-dict value ({child!r})."
            )
        current = child

    # Set the final value
    current[keys[-1]] = value
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def parse_value(value_str):
    """
    Convert a raw CLI string into the most specific Python value it represents.

    Resolution order:
      1. Valid JSON (objects, arrays, numbers, ``true``/``false``, ``null``, quoted strings).
      2. Case-insensitive ``true``/``false`` -> bool.
      3. Text containing commas -> list, each whitespace-trimmed item converted by
         ``parse_single_value``.
      4. Anything else -> ``parse_single_value``.

    Args:
        value_str: String value to parse

    Returns:
        Parsed value (bool, int, float, list, dict, None, or str)
    """
    try:
        return json.loads(value_str)
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass: not JSON, keep going.
        pass

    lowered = value_str.lower()
    if lowered in ("true", "false"):
        return lowered == "true"

    if "," not in value_str:
        return parse_single_value(value_str)

    # Split on commas and convert every trimmed item independently.
    return [parse_single_value(item.strip()) for item in value_str.split(",")]
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def parse_single_value(value_str):
    """
    Parse a single string value to appropriate Python type.

    Booleans are matched case-insensitively; integers are an optional ``+``/``-``
    sign followed by decimal digits; anything ``float()`` accepts becomes a float;
    everything else is returned unchanged.

    Args:
        value_str: String value to parse (no commas)

    Returns:
        Parsed value (bool, int, float, or str)
    """
    # Handle boolean values
    lowered = value_str.lower()
    if lowered in ("true", "false"):
        return lowered == "true"

    # Handle integer values. ``isdecimal`` (unlike ``isdigit``) only accepts
    # characters ``int()`` can actually parse, so inputs such as "²" fall through
    # instead of raising ValueError. A leading "+" is accepted like "-".
    unsigned = value_str[1:] if value_str[:1] in ("+", "-") else value_str
    if unsigned.isdecimal():
        return int(value_str)

    # Handle float values
    try:
        return float(value_str)
    except ValueError:
        pass

    # Return as string if all else fails
    return value_str
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
# ---- Simple ANSI styling helpers for richer CLI output ----
|
|
112
|
+
ANSI_RESET = "\033[0m"
|
|
113
|
+
ANSI_DIM = "\033[2m"
|
|
114
|
+
ANSI_UNDERLINE = "\033[4m"
|
|
115
|
+
ANSI_RED = "\033[31m"
|
|
116
|
+
ANSI_GREEN = "\033[32m"
|
|
117
|
+
ANSI_YELLOW = "\033[33m"
|
|
118
|
+
ANSI_BLUE = "\033[34m"
|
|
119
|
+
ANSI_MAGENTA = "\033[35m"
|
|
120
|
+
ANSI_CYAN = "\033[36m"
|
|
121
|
+
ANSI_BRIGHT_RED = "\033[91m"
|
|
122
|
+
ANSI_BRIGHT_GREEN = "\033[92m"
|
|
123
|
+
ANSI_BRIGHT_YELLOW = "\033[93m"
|
|
124
|
+
ANSI_BRIGHT_BLUE = "\033[94m"
|
|
125
|
+
ANSI_BRIGHT_MAGENTA = "\033[95m"
|
|
126
|
+
ANSI_BRIGHT_CYAN = "\033[96m"
|
|
127
|
+
|
|
128
|
+
STYLES_ENABLED = True
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _color(text: str, color: str) -> str:
|
|
132
|
+
if not STYLES_ENABLED:
|
|
133
|
+
return text
|
|
134
|
+
return f"{color}{text}{ANSI_RESET}"
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _underline(text: str) -> str:
|
|
138
|
+
if not STYLES_ENABLED:
|
|
139
|
+
return text
|
|
140
|
+
return f"{ANSI_UNDERLINE}{text}{ANSI_RESET}"
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _section(title: str, color: str = ANSI_BRIGHT_CYAN, icon: str = "✦") -> str:
|
|
144
|
+
line = f"{icon} {title}"
|
|
145
|
+
return _underline(_color(line, color))
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _label(text: str) -> str:
|
|
149
|
+
# Inline label for key names
|
|
150
|
+
return _color(text, ANSI_BRIGHT_CYAN)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
EXAMPLES_TEXT = r"""
|
|
154
|
+
Quick start examples
|
|
155
|
+
1) Compile a Llama chat model for causal LM
|
|
156
|
+
optimum-rbln-cli --output-dir ./compiled_llama \
|
|
157
|
+
--model-id meta-llama/Llama-2-7b-chat-hf \
|
|
158
|
+
--batch-size 2 --tensor-parallel-size 4
|
|
159
|
+
|
|
160
|
+
2) Compile with explicit class (Auto sequence classification)
|
|
161
|
+
optimum-rbln-cli --output-dir ./compiled_bert \
|
|
162
|
+
--class RBLNAutoModelForSequenceClassification \
|
|
163
|
+
--model-id bert-base-uncased \
|
|
164
|
+
--batch-size 8 --max-seq-len 512
|
|
165
|
+
|
|
166
|
+
3) Pass nested rbln_config with dot-notation (e.g., for diffusion)
|
|
167
|
+
optimum-rbln-cli --output-dir ./compiled_sd \
|
|
168
|
+
--model-id runwayml/stable-diffusion-v1-5 \
|
|
169
|
+
--unet.batch_size 2 --vae.batch_size 1
|
|
170
|
+
|
|
171
|
+
Notes
|
|
172
|
+
- Any extra --key value pairs not defined above are collected into rbln_config
|
|
173
|
+
and forwarded to from_pretrained(..., rbln_config=...).
|
|
174
|
+
- Use --list-classes to see available RBLN classes.
|
|
175
|
+
- Use --show-rbln-config to see accepted rbln_config keys for the resolved class
|
|
176
|
+
(via --class or inferred from --model-id).
|
|
177
|
+
- Show this examples list anytime with: optimum-rbln-cli --examples
|
|
178
|
+
"""
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _list_available_rbln_classes():
|
|
182
|
+
"""Return a sorted list of (name, kind) for available RBLN classes; kind in {"Model","Pipeline","Auto"}."""
|
|
183
|
+
try:
|
|
184
|
+
# Import lazily exposed module and enumerate public names
|
|
185
|
+
import optimum.rbln as rbln # noqa: WPS433 (third-party import within function)
|
|
186
|
+
|
|
187
|
+
# Import bases for filtering
|
|
188
|
+
RBLNBaseModel = getattr(rbln, "RBLNBaseModel", None)
|
|
189
|
+
RBLNDiffusionMixin = getattr(rbln, "RBLNDiffusionMixin", None)
|
|
190
|
+
|
|
191
|
+
class_names = []
|
|
192
|
+
for name in dir(rbln):
|
|
193
|
+
if not name.startswith("RBLN"):
|
|
194
|
+
continue
|
|
195
|
+
try:
|
|
196
|
+
obj = getattr(rbln, name)
|
|
197
|
+
if not inspect.isclass(obj):
|
|
198
|
+
continue
|
|
199
|
+
# Exclude config classes and obvious non-user-facing bases
|
|
200
|
+
if name.endswith("Config") or name in {"RBLNModel", "RBLNBaseModel", "RBLNDiffusionMixin"}:
|
|
201
|
+
continue
|
|
202
|
+
|
|
203
|
+
# Keep only concrete models/pipelines/auto
|
|
204
|
+
is_model = RBLNBaseModel is not None and isinstance(obj, type) and issubclass(obj, RBLNBaseModel)
|
|
205
|
+
is_pipeline = (
|
|
206
|
+
RBLNDiffusionMixin is not None and isinstance(obj, type) and issubclass(obj, RBLNDiffusionMixin)
|
|
207
|
+
)
|
|
208
|
+
is_auto = name.startswith("RBLNAuto")
|
|
209
|
+
if is_model:
|
|
210
|
+
class_names.append((name, "Model"))
|
|
211
|
+
elif is_pipeline:
|
|
212
|
+
class_names.append((name, "Pipeline"))
|
|
213
|
+
elif is_auto:
|
|
214
|
+
class_names.append((name, "Auto"))
|
|
215
|
+
except Exception:
|
|
216
|
+
# Skip anything that errors on attribute access
|
|
217
|
+
continue
|
|
218
|
+
# Deduplicate and sort by kind then name
|
|
219
|
+
unique = {(n, k) for (n, k) in class_names}
|
|
220
|
+
return sorted(unique, key=lambda x: (x[1], x[0]))
|
|
221
|
+
except Exception:
|
|
222
|
+
return []
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _describe_config_param(param: inspect.Parameter) -> str:
    """Format one config ``__init__`` parameter as its name plus its default value, if it has one."""
    if param.default is inspect.Parameter.empty:
        return param.name
    default_text = _color(f"(default: {param.default!r})", ANSI_DIM)
    return f"{param.name} {default_text}"


def _print_rbln_config_options(class_name: str):
    """Inspect the RBLN config class for a given model/pipeline and print accepted rbln_config keys.

    The model class is resolved with ``get_rbln_model_cls`` and its config class via
    ``get_rbln_config_class()``. Accepted keys are discovered from the config class's
    ``__init__`` signature and split into keys specific to that config and keys that
    also appear in ``RBLNModelConfig.__init__``. Names listed in the config class's
    ``submodules`` attribute are shown separately as nested sections.

    Exits the process with status 2 when the class or its config class cannot be resolved.
    """
    try:
        model_cls = get_rbln_model_cls(class_name)
    except Exception as e:
        print(f"Unknown RBLN class: {class_name}. Error: {e}", file=sys.stderr)
        sys.exit(2)

    # Obtain the associated config class
    try:
        config_cls = model_cls.get_rbln_config_class()
    except Exception:
        print(
            f"The class '{class_name}' does not provide an associated RBLN config class.",
            file=sys.stderr,
        )
        sys.exit(2)

    # Description from both class docstring and __init__ docstring (best effort)
    try:
        class_doc = inspect.getdoc(config_cls)
    except Exception:
        class_doc = None
    try:
        init_doc = inspect.getdoc(getattr(config_cls, "__init__", None))
    except Exception:
        init_doc = None

    # Keys accepted by the base RBLNModelConfig constructor
    try:
        base_sig = inspect.signature(RBLNModelConfig.__init__)
        base_params = {name for name in base_sig.parameters if name != "self"}
    except Exception:
        base_params = set()

    # Keys accepted by this specific config class
    try:
        cfg_sig = inspect.signature(config_cls.__init__)
        cfg_params = [p for p in cfg_sig.parameters.values() if p.name != "self"]
    except Exception:
        cfg_params = []

    # Identify submodule keys if present
    try:
        submodules = list(getattr(config_cls, "submodules", []) or [])
    except Exception:
        submodules = []

    # Categorize parameters. *args/**kwargs, submodule entries, and leading-underscore
    # names (treated as internal by naming convention) are not listed as keys.
    common_keys = []
    specific_keys = []
    for p in cfg_params:
        if p.kind in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL):
            continue
        if p.name in submodules or p.name.startswith("_"):
            continue
        if p.name in base_params:
            common_keys.append(p)
        else:
            specific_keys.append(p)

    print(_section(f"RBLN class: {class_name}", ANSI_BRIGHT_BLUE, icon="🧩"))
    print(_underline(_color(f"Config class: {config_cls.__name__}", ANSI_BRIGHT_CYAN)))
    if class_doc:
        print(_underline("\nDescription (class):"))
        for line in class_doc.splitlines():
            print(f" {line}")
    if init_doc:
        print(_underline("\nDescription (__init__):"))
        for line in init_doc.splitlines():
            print(f" {line}")
    if submodules:
        print(_underline("\nSubmodules:"))
        for s in submodules:
            print(f" • {s} {_color('(use nested keys like --' + s + '.batch_size 2)', ANSI_DIM)}")

    # Introspected keys (previously computed but never shown)
    if specific_keys:
        print(_underline(f"\nOptions specific to {config_cls.__name__}:"))
        for p in specific_keys:
            print(f" • {_describe_config_param(p)}")
    if common_keys:
        print(_underline("\nOptions shared by all RBLN configs:"))
        for p in common_keys:
            print(f" • {_describe_config_param(p)}")

    # Curated: common compile-time options that live in RBLNModelConfig (non-runtime)
    print(_underline("\nCommon compile-time options (in rbln_config):"))
    print(" • npu: Target NPU for compilation (e.g., 'RBLN-CA25').")
    print(" • tensor_parallel_size: Number of NPUs to shard the model at compile time.")

    print(_underline("\nTips:"))
    print(" - Pass config keys as CLI flags, e.g., --batch_size 2 --max_seq_len 4096")
    print(" - Compile-time examples: --npu RBLN-CA25 --tensor_parallel_size 4")
    print(" - Use dot-notation for submodules, e.g., --vision_tower.image_size 336 --language_model.batch_size 1")
    print(" - To see examples: optimum-rbln-cli --examples")
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def _read_json_from_model_id(
|
|
317
|
+
model_id: str,
|
|
318
|
+
filename: str,
|
|
319
|
+
*,
|
|
320
|
+
hf_token: Optional[str] = None,
|
|
321
|
+
hf_revision: Optional[str] = None,
|
|
322
|
+
hf_cache_dir: Optional[str] = None,
|
|
323
|
+
hf_force_download: bool = False,
|
|
324
|
+
hf_local_files_only: bool = False,
|
|
325
|
+
) -> Optional[dict]:
|
|
326
|
+
"""Read a JSON file (e.g., config.json or model_index.json) from a local path or the HuggingFace Hub.
|
|
327
|
+
|
|
328
|
+
Args:
|
|
329
|
+
model_id: Local directory path or HuggingFace Hub repo id
|
|
330
|
+
filename: Name of the JSON file to read
|
|
331
|
+
|
|
332
|
+
Returns:
|
|
333
|
+
Parsed JSON dictionary if found, else None
|
|
334
|
+
"""
|
|
335
|
+
# Local directory
|
|
336
|
+
local_dir = Path(model_id)
|
|
337
|
+
if local_dir.exists() and local_dir.is_dir():
|
|
338
|
+
local_file = local_dir / filename
|
|
339
|
+
if local_file.exists():
|
|
340
|
+
try:
|
|
341
|
+
with local_file.open("r", encoding="utf-8") as f:
|
|
342
|
+
return json.load(f)
|
|
343
|
+
except Exception:
|
|
344
|
+
return None
|
|
345
|
+
|
|
346
|
+
# HuggingFace Hub
|
|
347
|
+
try:
|
|
348
|
+
downloaded_path = hf_hub_download(
|
|
349
|
+
repo_id=model_id,
|
|
350
|
+
filename=filename,
|
|
351
|
+
revision=hf_revision,
|
|
352
|
+
token=hf_token,
|
|
353
|
+
cache_dir=hf_cache_dir,
|
|
354
|
+
force_download=hf_force_download,
|
|
355
|
+
local_files_only=hf_local_files_only,
|
|
356
|
+
)
|
|
357
|
+
p = Path(downloaded_path)
|
|
358
|
+
if p.exists():
|
|
359
|
+
with p.open("r", encoding="utf-8") as f:
|
|
360
|
+
return json.load(f)
|
|
361
|
+
except Exception:
|
|
362
|
+
return None
|
|
363
|
+
|
|
364
|
+
return None
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def _infer_rbln_class_from_model_id(
    model_id: str,
    *,
    hf_token: Optional[str] = None,
    hf_revision: Optional[str] = None,
    hf_cache_dir: Optional[str] = None,
    hf_force_download: bool = False,
    hf_local_files_only: bool = False,
) -> Optional[str]:
    """Infer RBLN class name from model files by prefixing the discovered class name with 'RBLN'.

    Order of precedence:
        1) model_index.json['_class_name'] (diffusers pipelines)
           e.g., 'StableDiffusionPipeline' -> 'RBLNStableDiffusionPipeline'
        2) config.json['architectures'][0] (transformers models)
           e.g., 'LlamaForCausalLM' -> 'RBLNLlamaForCausalLM'

    The inferred name is not checked against the classes ``optimum.rbln`` exports;
    callers resolve it afterwards. Returns None when neither file provides a non-empty
    class-name string.
    """
    # Hub options shared by both lookups
    hub_options = {
        "hf_token": hf_token,
        "hf_revision": hf_revision,
        "hf_cache_dir": hf_cache_dir,
        "hf_force_download": hf_force_download,
        "hf_local_files_only": hf_local_files_only,
    }

    # 1) Diffusers-style pipeline
    model_index = _read_json_from_model_id(model_id, "model_index.json", **hub_options)
    if isinstance(model_index, dict):
        pipeline_cls = model_index.get("_class_name")
        if isinstance(pipeline_cls, str) and pipeline_cls:
            return f"RBLN{pipeline_cls}"

    # 2) Transformers config architectures (only the first entry is used)
    cfg = _read_json_from_model_id(model_id, "config.json", **hub_options)
    if isinstance(cfg, dict):
        architectures = cfg.get("architectures")
        if isinstance(architectures, list) and architectures:
            arch0 = architectures[0]
            if isinstance(arch0, str) and arch0:
                return f"RBLN{arch0}"

    return None
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def _parse_rbln_config_args(unknown_args):
    """Turn the CLI tokens argparse did not recognize into an ``rbln_config`` dict.

    Supported forms (hyphens in key names are normalized to underscores):
      ``--key value``  -> ``{"key": parse_value("value")}``
      ``--key=value``  -> same as above
      ``--flag``       -> ``{"flag": True}`` when no value token follows
      ``--a.b value``  -> ``{"a": {"b": ...}}`` (dot-notation nesting)

    Tokens that are not ``--``-prefixed options, and options with an empty name
    (e.g. a bare ``--``), are reported on stderr and ignored.
    """
    rbln_config = {}
    i = 0
    while i < len(unknown_args):
        token = unknown_args[i]
        if not token.startswith("--"):
            print(f"Warning: ignoring unrecognized argument '{token}'.", file=sys.stderr)
            i += 1
            continue

        key, sep, inline_value = token[2:].partition("=")
        arg_name = key.replace("-", "_")
        if not arg_name:
            print(f"Warning: ignoring unrecognized argument '{token}'.", file=sys.stderr)
            i += 1
            continue

        if sep:
            # --key=value: the value lives in the same token
            value = parse_value(inline_value)
            i += 1
        elif i + 1 < len(unknown_args) and not unknown_args[i + 1].startswith("--"):
            # --key value
            value = parse_value(unknown_args[i + 1])
            i += 2
        else:
            # Boolean flag
            value = True
            i += 1

        # set_nested_dict also handles plain (dot-free) keys as top-level entries
        set_nested_dict(rbln_config, arg_name, value)

    return rbln_config


def main():
    """
    Main CLI function for optimum-rbln model compilation.

    Flow:
      1. A lightweight pre-parser handles flags that must work without --model-id
         (--version, --no-style, --list-classes, --examples).
      2. The full parser reads the documented options; every remaining --key token is
         forwarded to rbln_config (see _parse_rbln_config_args).
      3. The RBLN class comes from --class or is inferred from the model files.
      4. With --show-rbln-config the accepted config keys are printed; otherwise the
         model is compiled with from_pretrained(..., export=True) into --output-dir.

    Exits with status 2 on usage errors and 1 when an unexpected error occurs.
    """
    # Pre-parse lightweight flags that should work without other required args
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--list-classes", action="store_true", help="List available RBLN classes and exit")
    pre_parser.add_argument("--examples", action="store_true", help="Show quick start examples and exit")
    pre_parser.add_argument("--version", action="store_true", help="Show version and exit")
    pre_parser.add_argument("--no-style", action="store_true", help="Disable ANSI styling in output")
    pre_args, _ = pre_parser.parse_known_args()

    if pre_args.version:
        print(f"optimum-rbln-cli {__version__}")
        return

    # Apply style preference as early as possible
    global STYLES_ENABLED
    if pre_args.no_style:
        STYLES_ENABLED = False

    if pre_args.list_classes:
        classes = _list_available_rbln_classes()
        if not classes:
            print(_section("No RBLN classes found", ANSI_RED, icon="✖"))
            print("Please ensure the package is installed correctly.")
        else:
            autos = [n for n, k in classes if k == "Auto"]
            models = [n for n, k in classes if k == "Model"]
            pipes = [n for n, k in classes if k == "Pipeline"]
            print(_section("Available RBLN classes (use with --class)", ANSI_BRIGHT_BLUE, icon="📚"))
            if autos:
                print(_underline(_color("\nAuto classes:", ANSI_BRIGHT_YELLOW)))
                for name in autos:
                    print(f" • {name}")
            if models:
                print(_underline(_color("\nModels:", ANSI_BRIGHT_GREEN)))
                for name in models:
                    print(f" • {name}")
            if pipes:
                print(_underline(_color("\nPipelines:", ANSI_BRIGHT_MAGENTA)))
                for name in pipes:
                    print(f" • {name}")
            print(f"\nTotal: {_underline(str(len(classes)))}")
        return

    if pre_args.examples:
        print(EXAMPLES_TEXT)
        return

    parser = argparse.ArgumentParser(
        description=(
            "Compile and export HuggingFace models/pipelines for RBLN devices.\n\n"
            "Required: --model-id.\n"
            "Additional --key value pairs are forwarded to rbln_config.\n"
            "Use dot-notation for nested fields (e.g., --unet.batch_size 2)."
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=EXAMPLES_TEXT,
    )

    parser.add_argument(
        "--model-id",
        dest="model_id",
        type=str,
        required=True,
        help="Model ID from HuggingFace Hub or local directory path",
    )

    # Optional output directory argument (defaults to ./rbln_out)
    parser.add_argument(
        "-o",
        "--output-dir",
        dest="output_dir",
        type=str,
        default="./rbln_out",
        help="Directory where the compiled model will be saved (default: ./rbln_out)",
    )

    # Optional class argument (can be inferred)
    parser.add_argument(
        "--class",
        dest="model_class",
        type=str,
        required=False,
        help=(
            "RBLN model class to use for compilation (e.g., RBLNLlamaForCausalLM, RBLNAutoModelForCausalLM). "
            "If omitted, it will be inferred from model_id by reading model_index.json or config.json."
        ),
    )

    # Optional flag to show rbln_config for the resolved class (no compilation)
    parser.add_argument(
        "--show-rbln-config",
        dest="show_rbln_config",
        action="store_true",
        help="Show rbln_config keys for the resolved RBLN class (via --class or inferred from --model-id) and exit",
    )

    # Standard --version that integrates with argparse (works after full parse)
    parser.add_argument(
        "--version",
        action="version",
        version=f"%(prog)s {__version__}",
        help="Show version and exit",
    )
    parser.add_argument("--no-style", action="store_true", help="Disable ANSI styling in output")

    # HuggingFace Hub access options
    parser.add_argument(
        "--hf-token",
        dest="hf_token",
        type=str,
        default=None,
        help="HuggingFace token to access private repositories",
    )
    parser.add_argument(
        "--hf-revision",
        dest="hf_revision",
        type=str,
        default=None,
        help="Specific model revision to download (branch, tag, or commit)",
    )
    parser.add_argument(
        "--hf-cache-dir",
        dest="hf_cache_dir",
        type=str,
        default=None,
        help="Directory to use as HuggingFace download cache",
    )
    parser.add_argument(
        "--hf-force-download",
        dest="hf_force_download",
        action="store_true",
        help="Force redownload of files from the HuggingFace Hub",
    )
    parser.add_argument(
        "--hf-local-files-only",
        dest="hf_local_files_only",
        action="store_true",
        help="Only use local files and do not attempt to download from the network",
    )
    # All other arguments will be parsed dynamically and passed to from_pretrained

    # Print help with examples when no args were provided
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(2)

    # Parse known args to allow for additional rbln_config arguments
    args, unknown_args = parser.parse_known_args()

    try:
        # Resolve or infer model class for compilation
        resolved_class_name: Optional[str] = args.model_class
        if not resolved_class_name:
            resolved_class_name = _infer_rbln_class_from_model_id(
                args.model_id,
                hf_token=args.hf_token,
                hf_revision=args.hf_revision,
                hf_cache_dir=args.hf_cache_dir,
                hf_force_download=args.hf_force_download,
                hf_local_files_only=args.hf_local_files_only,
            )
            if not resolved_class_name:
                print(
                    "Could not infer RBLN class from model files. Please specify --class explicitly.",
                    file=sys.stderr,
                )
                sys.exit(2)

        if args.show_rbln_config:
            _print_rbln_config_options(resolved_class_name)
            return

        # Get the model class using the utility function (with helpful error)
        try:
            model_class = get_rbln_model_cls(resolved_class_name)
        except AttributeError:
            print(
                f"Unknown RBLN class: {resolved_class_name}.\n"
                "Run 'optimum-rbln-cli --list-classes' to see available classes.",
                file=sys.stderr,
            )
            sys.exit(2)

        # Parse forwarded arguments before touching the filesystem, so a malformed
        # argument does not leave an empty output directory behind.
        rbln_config = _parse_rbln_config_args(unknown_args)

        # Set create_runtimes to False by default for CLI compilation if not specified
        create_runtimes = rbln_config.pop("create_runtimes", False)

        # Create output directory
        output_path = Path(args.output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        print(_section("Starting compilation", ANSI_BRIGHT_BLUE, icon="🚀"))
        print(f"{_label('Model:')} {args.model_id}")
        print(f"{_label('Class:')} {resolved_class_name}")
        print(f"{_label('Output:')} {output_path.absolute()}")
        print(f"{_label('rbln_config:')} {json.dumps(rbln_config, indent=2, ensure_ascii=False)}")

        with ContextRblnConfig(create_runtimes=create_runtimes):
            _ = model_class.from_pretrained(
                args.model_id, export=True, model_save_dir=str(output_path), rbln_config=rbln_config
            )

        print(_section("Model compilation completed successfully", ANSI_BRIGHT_GREEN, icon="✅"))
        print(f"Saved to: {output_path.absolute()}")

    except Exception as e:
        # sys.exit() raises SystemExit (a BaseException), so the usage-error exits
        # above pass through this handler untouched.
        print(f"❌ Error during model compilation: {e}", file=sys.stderr)
        sys.exit(1)
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
# Allow running this module directly as a script; delegates to main().
if __name__ == "__main__":
    main()
|