optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. optimum/rbln/__init__.py +108 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +156 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +48 -21
  52. optimum/rbln/modeling_base.py +99 -22
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +92 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +91 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  157. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  158. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  159. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  160. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  161. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  162. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  163. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
  164. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  165. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  166. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  167. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  168. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  169. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  170. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  171. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  172. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  173. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  174. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  175. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
  176. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  177. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  178. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  179. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  180. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  181. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  182. optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
  183. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  184. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  185. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  186. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  187. optimum/rbln/utils/deprecation.py +213 -0
  188. optimum/rbln/utils/hub.py +14 -3
  189. optimum/rbln/utils/runtime_utils.py +60 -18
  190. optimum/rbln/utils/submodule.py +31 -9
  191. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  192. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  193. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  194. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  195. optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
  196. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/cli.py ADDED
@@ -0,0 +1,660 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Rebellions Inc. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at:
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import inspect
18
+ import json
19
+ import sys
20
+ from pathlib import Path
21
+ from typing import Optional
22
+
23
+ from huggingface_hub import hf_hub_download
24
+
25
+ from .__version__ import __version__
26
+ from .configuration_utils import RBLNModelConfig
27
+ from .utils.model_utils import get_rbln_model_cls
28
+ from .utils.runtime_utils import ContextRblnConfig
29
+
30
+
31
def set_nested_dict(dictionary, key_path, value):
    """
    Set a value in a nested dictionary using dot notation.

    Intermediate dictionaries are created on demand and existing sibling keys
    are preserved, e.g. ``set_nested_dict(d, "unet.batch_size", 2)`` on ``{}``
    yields ``{"unet": {"batch_size": 2}}``.

    Args:
        dictionary: The dictionary to modify (mutated in place)
        key_path: Dot-separated key path (e.g., "unet.batch_size")
        value: The value to set

    Raises:
        TypeError: If an intermediate key already holds a non-dict value
            (e.g. ``--unet 1`` combined with ``--unet.batch_size 2``).
    """
    keys = key_path.split(".")
    current = dictionary

    # Navigate to (creating as needed) the parent of the final key
    for depth, key in enumerate(keys[:-1]):
        child = current.setdefault(key, {})
        if not isinstance(child, dict):
            # Without this check a scalar parent fails with an opaque error, and a
            # string parent would be probed with a substring `in` test.
            conflict = ".".join(keys[: depth + 1])
            raise TypeError(
                f"Cannot set '{key_path}': '{conflict}' is already set to a non-dict value {child!r}"
            )
        current = child

    # Set the final value
    current[keys[-1]] = value
51
+
52
+
53
def parse_value(value_str):
    """
    Convert a CLI string into the most natural Python value.

    Resolution order:
      1. Valid JSON (objects, arrays, numbers, quoted strings, true/false/null).
      2. Case-insensitive "true"/"false" -> bool.
      3. Comma-separated text -> list, each stripped item parsed with
         parse_single_value.
      4. Otherwise a single scalar via parse_single_value.

    Args:
        value_str: String value to parse

    Returns:
        Parsed value (bool, int, float, list, dict, or str)
    """
    try:
        return json.loads(value_str)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        pass

    lowered = value_str.lower()
    if lowered in ("true", "false"):
        return lowered == "true"

    if "," not in value_str:
        return parse_single_value(value_str)

    return [parse_single_value(piece.strip()) for piece in value_str.split(",")]
81
+
82
+
83
def parse_single_value(value_str):
    """
    Parse a single string value to appropriate Python type.

    Conversion order: case-insensitive "true"/"false" -> bool; an optionally
    negative run of decimal digits -> int; anything ``float()`` accepts -> float;
    otherwise the string is returned unchanged.

    Args:
        value_str: String value to parse (no commas)

    Returns:
        Parsed value (bool, int, float, or str)
    """
    # Handle boolean values
    if value_str.lower() in ["true", "false"]:
        return value_str.lower() == "true"

    # Handle integer values. isdecimal() (not isdigit()) matches exactly the
    # characters int() accepts; isdigit() is also True for e.g. superscripts
    # such as "²", which made int() raise and crash the CLI.
    digits = value_str[1:] if value_str.startswith("-") else value_str
    if digits.isdecimal():
        return int(value_str)

    # Handle float values
    try:
        return float(value_str)
    except ValueError:
        pass

    # Return as string if all else fails
    return value_str
109
+
110
+
111
# ---- Simple ANSI styling helpers for richer CLI output ----
# SGR escape sequences ("\x1b" is the same ESC character as "\033").
ANSI_RESET = "\x1b[0m"
ANSI_DIM = "\x1b[2m"
ANSI_UNDERLINE = "\x1b[4m"

# Standard foreground colors
ANSI_RED = "\x1b[31m"
ANSI_GREEN = "\x1b[32m"
ANSI_YELLOW = "\x1b[33m"
ANSI_BLUE = "\x1b[34m"
ANSI_MAGENTA = "\x1b[35m"
ANSI_CYAN = "\x1b[36m"

# Bright foreground colors
ANSI_BRIGHT_RED = "\x1b[91m"
ANSI_BRIGHT_GREEN = "\x1b[92m"
ANSI_BRIGHT_YELLOW = "\x1b[93m"
ANSI_BRIGHT_BLUE = "\x1b[94m"
ANSI_BRIGHT_MAGENTA = "\x1b[95m"
ANSI_BRIGHT_CYAN = "\x1b[96m"

# Global switch consulted at call time; main() sets it to False for --no-style.
STYLES_ENABLED = True


def _color(text: str, color: str) -> str:
    """Wrap ``text`` in the ANSI ``color`` code and a reset, unless styling is disabled."""
    return f"{color}{text}{ANSI_RESET}" if STYLES_ENABLED else text


def _underline(text: str) -> str:
    """Underline ``text``; returns it unchanged when styling is disabled."""
    return _color(text, ANSI_UNDERLINE)


def _section(title: str, color: str = ANSI_BRIGHT_CYAN, icon: str = "✦") -> str:
    """Build an underlined, colored heading of the form ``"<icon> <title>"``."""
    return _underline(_color(f"{icon} {title}", color))


def _label(text: str) -> str:
    """Highlight an inline key name in bright cyan."""
    return _color(text, ANSI_BRIGHT_CYAN)
151
+
152
+
153
# Quick-start text printed by `--examples` and also used as the argparse epilog in main().
# Kept as a raw string so the trailing shell line-continuation backslashes are preserved
# literally instead of being interpreted as Python line continuations.
# NOTE(review): examples 1 and 2 use hyphenated keys (--batch-size, --max-seq-len,
# --tensor-parallel-size), while example 3 and the --show-rbln-config tips use underscores
# (--unet.batch_size, --batch_size). Confirm that main() normalizes hyphens to underscores
# before these keys reach rbln_config; otherwise the examples should be aligned.
EXAMPLES_TEXT = r"""
Quick start examples
1) Compile a Llama chat model for causal LM
optimum-rbln-cli --output-dir ./compiled_llama \
--model-id meta-llama/Llama-2-7b-chat-hf \
--batch-size 2 --tensor-parallel-size 4

2) Compile with explicit class (Auto sequence classification)
optimum-rbln-cli --output-dir ./compiled_bert \
--class RBLNAutoModelForSequenceClassification \
--model-id bert-base-uncased \
--batch-size 8 --max-seq-len 512

3) Pass nested rbln_config with dot-notation (e.g., for diffusion)
optimum-rbln-cli --output-dir ./compiled_sd \
--model-id runwayml/stable-diffusion-v1-5 \
--unet.batch_size 2 --vae.batch_size 1

Notes
- Any extra --key value pairs not defined above are collected into rbln_config
and forwarded to from_pretrained(..., rbln_config=...).
- Use --list-classes to see available RBLN classes.
- Use --show-rbln-config to see accepted rbln_config keys for the resolved class
(via --class or inferred from --model-id).
- Show this examples list anytime with: optimum-rbln-cli --examples
"""
179
+
180
+
181
def _list_available_rbln_classes():
    """Return a sorted list of (name, kind) for available RBLN classes; kind in {"Model","Pipeline","Auto"}.

    Public names of ``optimum.rbln`` that start with "RBLN" and are classes are classified:
      * subclasses of ``RBLNBaseModel``      -> "Model"
      * subclasses of ``RBLNDiffusionMixin`` -> "Pipeline"
      * other names starting with "RBLNAuto" -> "Auto"
    Config classes (names ending in "Config") and the base classes themselves are skipped.
    Names whose attribute access or subclass check raises are skipped individually.
    If the package cannot be imported or enumerated at all, an empty list is returned.
    The result is ordered by kind, then by name.
    """
    excluded_names = {"RBLNModel", "RBLNBaseModel", "RBLNDiffusionMixin"}
    try:
        import optimum.rbln as rbln  # noqa: WPS433 (import within function)

        base_model_cls = getattr(rbln, "RBLNBaseModel", None)
        diffusion_mixin_cls = getattr(rbln, "RBLNDiffusionMixin", None)

        found = set()
        for name in (n for n in dir(rbln) if n.startswith("RBLN")):
            try:
                candidate = getattr(rbln, name)
                if not inspect.isclass(candidate):
                    continue
                if name.endswith("Config") or name in excluded_names:
                    continue

                is_model = base_model_cls is not None and issubclass(candidate, base_model_cls)
                is_pipeline = diffusion_mixin_cls is not None and issubclass(candidate, diffusion_mixin_cls)
                if is_model:
                    kind = "Model"
                elif is_pipeline:
                    kind = "Pipeline"
                elif name.startswith("RBLNAuto"):
                    kind = "Auto"
                else:
                    continue
                found.add((name, kind))
            except Exception:
                # Lazily exposed attributes may fail to import; skip just that name.
                continue
        return sorted(found, key=lambda entry: (entry[1], entry[0]))
    except Exception:
        return []
223
+
224
+
225
def _format_config_param(param: inspect.Parameter) -> str:
    """Render one rbln_config parameter as a bullet line, appending its default when it has one."""
    if param.default is inspect.Parameter.empty:
        return f" • {param.name}"
    return f" • {param.name} {_color(f'(default: {param.default!r})', ANSI_DIM)}"


def _print_rbln_config_options(class_name: str):
    """Inspect the RBLN config class for a given model/pipeline and print accepted rbln_config keys.

    Resolves ``class_name`` to an RBLN class and asks it for its RBLN config class. It then prints:
    the config class's docstrings, its submodules, and the keyword parameters accepted by the
    config class's ``__init__``. Those parameters are split into class-specific ones and ones
    also accepted by ``RBLNModelConfig.__init__``. General usage tips follow.

    Args:
        class_name: Name of an RBLN model/pipeline class (e.g. "RBLNLlamaForCausalLM").

    Exits the process with status 2 if the class cannot be resolved or does not provide
    an RBLN config class.
    """
    try:
        model_cls = get_rbln_model_cls(class_name)
    except Exception as e:
        print(f"Unknown RBLN class: {class_name}. Error: {e}", file=sys.stderr)
        sys.exit(2)

    # Obtain the associated config class
    try:
        config_cls = model_cls.get_rbln_config_class()
    except Exception:
        print(
            f"The class '{class_name}' does not provide an associated RBLN config class.",
            file=sys.stderr,
        )
        sys.exit(2)

    # Description from both class docstring and __init__ docstring. Docs are optional
    # decoration, so introspection failures must not abort the listing.
    # Note: inspect.getdoc() falls back to a base class's docstring when none is defined.
    try:
        class_doc = inspect.getdoc(config_cls)
    except Exception:
        class_doc = None
    try:
        init_doc = inspect.getdoc(getattr(config_cls, "__init__", None))
    except Exception:
        init_doc = None

    # Parameters of the shared base config; used to split common vs. class-specific keys
    base_params = set()
    try:
        base_sig = inspect.signature(RBLNModelConfig.__init__)
        base_params = {p.name for p in base_sig.parameters.values() if p.name != "self"}
    except Exception:
        pass

    try:
        cfg_sig = inspect.signature(config_cls.__init__)
        cfg_params = [p for p in cfg_sig.parameters.values() if p.name != "self"]
    except Exception:
        cfg_params = []

    # Identify submodule keys if present
    try:
        submodules = list(getattr(config_cls, "submodules", []) or [])
    except Exception:
        submodules = []

    # Categorize parameters; *args/**kwargs and submodule configs are not plain keys
    common_keys = []
    specific_keys = []
    for p in cfg_params:
        if p.kind in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL):
            continue
        if p.name in submodules:
            continue
        if p.name in base_params:
            common_keys.append(p)
        else:
            specific_keys.append(p)

    print(_section(f"RBLN class: {class_name}", ANSI_BRIGHT_BLUE, icon="🧩"))
    print(_underline(_color(f"Config class: {config_cls.__name__}", ANSI_BRIGHT_CYAN)))
    if class_doc:
        print(_underline("\nDescription (class):"))
        for line in class_doc.splitlines():
            print(f" {line}")
    if init_doc:
        print(_underline("\nDescription (__init__):"))
        for line in init_doc.splitlines():
            print(f" {line}")
    if submodules:
        print(_underline("\nSubmodules:"))
        for s in submodules:
            print(f" • {s} {_color('(use nested keys like --' + s + '.batch_size 2)', ANSI_DIM)}")

    # The categorized keys were previously computed but never shown; list them so the
    # command actually reports the rbln_config keys it advertises.
    if specific_keys:
        print(_underline(f"\nOptions specific to {config_cls.__name__} (in rbln_config):"))
        for p in specific_keys:
            print(_format_config_param(p))
    if common_keys:
        print(_underline("\nOptions shared with RBLNModelConfig (in rbln_config):"))
        for p in common_keys:
            print(_format_config_param(p))

    # Curated: common compile-time options that live in RBLNModelConfig (non-runtime)
    print(_underline("\nCommon compile-time options (in rbln_config):"))
    print(" • npu: Target NPU for compilation (e.g., 'RBLN-CA25').")
    print(" • tensor_parallel_size: Number of NPUs to shard the model at compile time.")

    print(_underline("\nTips:"))
    print(" - Pass config keys as CLI flags, e.g., --batch_size 2 --max_seq_len 4096")
    print(" - Compile-time examples: --npu RBLN-CA25 --tensor_parallel_size 4")
    print(" - Use dot-notation for submodules, e.g., --vision_tower.image_size 336 --language_model.batch_size 1")
    print(" - To see examples: optimum-rbln-cli --examples")
314
+
315
+
316
def _load_json_dict(path: Path) -> Optional[dict]:
    """Return the JSON object stored at ``path``, or None if missing, unreadable, invalid, or not an object."""
    try:
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):  # ValueError covers JSONDecodeError and UnicodeDecodeError
        return None
    return data if isinstance(data, dict) else None


def _read_json_from_model_id(
    model_id: str,
    filename: str,
    *,
    hf_token: Optional[str] = None,
    hf_revision: Optional[str] = None,
    hf_cache_dir: Optional[str] = None,
    hf_force_download: bool = False,
    hf_local_files_only: bool = False,
) -> Optional[dict]:
    """Read a JSON file (e.g., config.json or model_index.json) from a local path or the HuggingFace Hub.

    If ``model_id`` is an existing local directory, only that directory is consulted. A missing
    or unparsable file yields None, and the Hub is never contacted. Otherwise ``model_id`` is
    treated as a Hub repo id and the file is fetched with ``hf_hub_download``.

    Args:
        model_id: Local directory path or HuggingFace Hub repo id
        filename: Name of the JSON file to read
        hf_token: Forwarded to ``hf_hub_download`` as ``token``
        hf_revision: Forwarded as ``revision``
        hf_cache_dir: Forwarded as ``cache_dir``
        hf_force_download: Forwarded as ``force_download``
        hf_local_files_only: Forwarded as ``local_files_only``

    Returns:
        Parsed JSON dictionary if found and the file holds a JSON object, else None
    """
    # Local directory is authoritative. Do not fall back to the Hub: a relative path such
    # as "org/model" is also a syntactically valid repo id, so a Hub lookup could return
    # an unrelated repository's file (or hang on the network).
    local_dir = Path(model_id)
    if local_dir.is_dir():
        return _load_json_dict(local_dir / filename)

    # HuggingFace Hub
    try:
        downloaded_path = hf_hub_download(
            repo_id=model_id,
            filename=filename,
            revision=hf_revision,
            token=hf_token,
            cache_dir=hf_cache_dir,
            force_download=hf_force_download,
            local_files_only=hf_local_files_only,
        )
    except Exception:
        # Best effort: any download failure (bad repo id, missing file, auth, offline)
        # is treated as "file not available".
        return None
    return _load_json_dict(Path(downloaded_path))
365
+
366
+
367
def _infer_rbln_class_from_model_id(
    model_id: str,
    *,
    hf_token: Optional[str] = None,
    hf_revision: Optional[str] = None,
    hf_cache_dir: Optional[str] = None,
    hf_force_download: bool = False,
    hf_local_files_only: bool = False,
) -> Optional[str]:
    """Infer RBLN class name from model files by prefixing discovered class with 'RBLN'.

    Order of precedence:
      1) model_index.json['_class_name'] (diffusers pipelines),
         e.g. 'StableDiffusionPipeline' -> 'RBLNStableDiffusionPipeline'
      2) config.json['architectures'][0] (transformers models),
         e.g. 'LlamaForCausalLM' -> 'RBLNLlamaForCausalLM'

    The returned name is only prefixed, not checked against the classes optimum-rbln provides.

    Args:
        model_id: Local directory path or HuggingFace Hub repo id
        hf_token, hf_revision, hf_cache_dir, hf_force_download, hf_local_files_only:
            Forwarded unchanged to ``_read_json_from_model_id``.

    Returns:
        The inferred class name, or None if neither file yields a non-empty class name.
    """
    hub_kwargs = {
        "hf_token": hf_token,
        "hf_revision": hf_revision,
        "hf_cache_dir": hf_cache_dir,
        "hf_force_download": hf_force_download,
        "hf_local_files_only": hf_local_files_only,
    }

    # 1) Diffusers-style pipeline
    model_index = _read_json_from_model_id(model_id, "model_index.json", **hub_kwargs)
    if isinstance(model_index, dict):
        pipeline_cls = model_index.get("_class_name")
        if isinstance(pipeline_cls, str) and pipeline_cls:
            return f"RBLN{pipeline_cls}"

    # 2) Transformers config architectures
    cfg = _read_json_from_model_id(model_id, "config.json", **hub_kwargs)
    if isinstance(cfg, dict):
        architectures = cfg.get("architectures")
        if isinstance(architectures, list) and architectures:
            arch0 = architectures[0]
            if isinstance(arch0, str) and arch0:
                return f"RBLN{arch0}"

    return None
415
+
416
+
417
+ def main():
418
+ """
419
+ Main CLI function for optimum-rbln model compilation.
420
+ """
421
+ # Pre-parse lightweight flags that should work without other required args
422
+ pre_parser = argparse.ArgumentParser(add_help=False)
423
+ pre_parser.add_argument("--list-classes", action="store_true", help="List available RBLN classes and exit")
424
+ pre_parser.add_argument("--examples", action="store_true", help="Show quick start examples and exit")
425
+ pre_parser.add_argument("--version", action="store_true", help="Show version and exit")
426
+ pre_parser.add_argument("--no-style", action="store_true", help="Disable ANSI styling in output")
427
+ pre_args, _ = pre_parser.parse_known_args()
428
+
429
+ if pre_args.version:
430
+ print(f"optimum-rbln-cli {__version__}")
431
+ return
432
+
433
+ # Apply style preference as early as possible
434
+ global STYLES_ENABLED
435
+ if pre_args.no_style:
436
+ STYLES_ENABLED = False
437
+
438
+ if pre_args.list_classes:
439
+ classes = _list_available_rbln_classes()
440
+ if not classes:
441
+ print(_section("No RBLN classes found", ANSI_RED, icon="✖"))
442
+ print("Please ensure the package is installed correctly.")
443
+ else:
444
+ autos = [n for n, k in classes if k == "Auto"]
445
+ models = [n for n, k in classes if k == "Model"]
446
+ pipes = [n for n, k in classes if k == "Pipeline"]
447
+ print(_section("Available RBLN classes (use with --class)", ANSI_BRIGHT_BLUE, icon="📚"))
448
+ if autos:
449
+ print(_underline(_color("\nAuto classes:", ANSI_BRIGHT_YELLOW)))
450
+ for name in autos:
451
+ print(f" • {name}")
452
+ if models:
453
+ print(_underline(_color("\nModels:", ANSI_BRIGHT_GREEN)))
454
+ for name in models:
455
+ print(f" • {name}")
456
+ if pipes:
457
+ print(_underline(_color("\nPipelines:", ANSI_BRIGHT_MAGENTA)))
458
+ for name in pipes:
459
+ print(f" • {name}")
460
+ print(f"\nTotal: {_underline(str(len(classes)))}")
461
+ return
462
+
463
+ if pre_args.examples:
464
+ print(EXAMPLES_TEXT)
465
+ return
466
+
467
+ parser = argparse.ArgumentParser(
468
+ description=(
469
+ "Compile and export HuggingFace models/pipelines for RBLN devices.\n\n"
470
+ "Required: --model-id.\n"
471
+ "Additional --key value pairs are forwarded to rbln_config.\n"
472
+ "Use dot-notation for nested fields (e.g., --unet.batch_size 2)."
473
+ ),
474
+ formatter_class=argparse.RawDescriptionHelpFormatter,
475
+ epilog=EXAMPLES_TEXT,
476
+ )
477
+
478
+ parser.add_argument(
479
+ "--model-id",
480
+ dest="model_id",
481
+ type=str,
482
+ required=True,
483
+ help="Model ID from HuggingFace Hub or local directory path",
484
+ )
485
+
486
+ # Optional output directory argument (defaults to ./rbln_out)
487
+ parser.add_argument(
488
+ "-o",
489
+ "--output-dir",
490
+ dest="output_dir",
491
+ type=str,
492
+ default="./rbln_out",
493
+ help="Directory where the compiled model will be saved (default: ./rbln_out)",
494
+ )
495
+
496
+ # Optional class argument (can be inferred)
497
+ parser.add_argument(
498
+ "--class",
499
+ dest="model_class",
500
+ type=str,
501
+ required=False,
502
+ help=(
503
+ "RBLN model class to use for compilation (e.g., RBLNLlamaForCausalLM, RBLNAutoModelForCausalLM). "
504
+ "If omitted, it will be inferred from model_id by reading model_index.json or config.json."
505
+ ),
506
+ )
507
+
508
+ # Optional flag to show rbln_config for the resolved class (no compilation)
509
+ parser.add_argument(
510
+ "--show-rbln-config",
511
+ dest="show_rbln_config",
512
+ action="store_true",
513
+ help="Show rbln_config keys for the resolved RBLN class (via --class or inferred from --model-id) and exit",
514
+ )
515
+
516
+ # Standard --version that integrates with argparse (works after full parse)
517
+ parser.add_argument(
518
+ "--version",
519
+ action="version",
520
+ version=f"%(prog)s {__version__}",
521
+ help="Show version and exit",
522
+ )
523
+ parser.add_argument("--no-style", action="store_true", help="Disable ANSI styling in output")
524
+
525
+ # HuggingFace Hub access options
526
+ parser.add_argument(
527
+ "--hf-token",
528
+ dest="hf_token",
529
+ type=str,
530
+ default=None,
531
+ help="HuggingFace token to access private repositories",
532
+ )
533
+ parser.add_argument(
534
+ "--hf-revision",
535
+ dest="hf_revision",
536
+ type=str,
537
+ default=None,
538
+ help="Specific model revision to download (branch, tag, or commit)",
539
+ )
540
+ parser.add_argument(
541
+ "--hf-cache-dir",
542
+ dest="hf_cache_dir",
543
+ type=str,
544
+ default=None,
545
+ help="Directory to use as HuggingFace download cache",
546
+ )
547
+ parser.add_argument(
548
+ "--hf-force-download",
549
+ dest="hf_force_download",
550
+ action="store_true",
551
+ help="Force redownload of files from the HuggingFace Hub",
552
+ )
553
+ parser.add_argument(
554
+ "--hf-local-files-only",
555
+ dest="hf_local_files_only",
556
+ action="store_true",
557
+ help="Only use local files and do not attempt to download from the network",
558
+ )
559
+ # All other arguments will be parsed dynamically and passed to from_pretrained
560
+
561
+ # Print help with examples when no args were provided
562
+ if len(sys.argv) == 1:
563
+ parser.print_help()
564
+ sys.exit(2)
565
+
566
+ # Parse known args to allow for additional rbln_* arguments
567
+ args, unknown_args = parser.parse_known_args()
568
+
569
+ try:
570
+ # Resolve or infer model class for compilation
571
+ resolved_class_name: Optional[str] = args.model_class
572
+ if not resolved_class_name:
573
+ resolved_class_name = _infer_rbln_class_from_model_id(
574
+ args.model_id,
575
+ hf_token=args.hf_token,
576
+ hf_revision=args.hf_revision,
577
+ hf_cache_dir=args.hf_cache_dir,
578
+ hf_force_download=args.hf_force_download,
579
+ hf_local_files_only=args.hf_local_files_only,
580
+ )
581
+ if not resolved_class_name:
582
+ print(
583
+ "Could not infer RBLN class from model files. Please specify --class explicitly.",
584
+ file=sys.stderr,
585
+ )
586
+ sys.exit(2)
587
+
588
+ if args.show_rbln_config:
589
+ _print_rbln_config_options(resolved_class_name)
590
+ return
591
+
592
+ # Get the model class using the utility function (with helpful error)
593
+ try:
594
+ model_class = get_rbln_model_cls(resolved_class_name)
595
+ except AttributeError:
596
+ print(
597
+ f"Unknown RBLN class: {resolved_class_name}.\n"
598
+ "Run 'optimum-rbln-cli --list-classes' to see available classes.",
599
+ file=sys.stderr,
600
+ )
601
+ sys.exit(2)
602
+
603
+ # Create output directory
604
+ output_path = Path(args.output_dir)
605
+ output_path.mkdir(parents=True, exist_ok=True)
606
+
607
+ # Prepare rbln_config by parsing all unknown arguments
608
+ rbln_config = {}
609
+
610
+ # Parse all unknown arguments
611
+ i = 0
612
+ while i < len(unknown_args):
613
+ arg = unknown_args[i]
614
+ if arg.startswith("--"):
615
+ arg_name = arg[2:].replace("-", "_")
616
+ if i + 1 < len(unknown_args) and not unknown_args[i + 1].startswith("--"):
617
+ # Has a value
618
+ arg_value = unknown_args[i + 1]
619
+ parsed_value = parse_value(arg_value)
620
+
621
+ # Check if this is a nested config argument (contains dots)
622
+ if "." in arg_name:
623
+ set_nested_dict(rbln_config, arg_name, parsed_value)
624
+ else:
625
+ rbln_config[arg_name] = parsed_value
626
+ i += 2
627
+ else:
628
+ # Boolean flag
629
+ if "." in arg_name:
630
+ set_nested_dict(rbln_config, arg_name, True)
631
+ else:
632
+ rbln_config[arg_name] = True
633
+ i += 1
634
+ else:
635
+ i += 1
636
+
637
+ # Set create_runtimes to False by default for CLI compilation if not specified
638
+ create_runtimes = rbln_config.pop("create_runtimes", False)
639
+
640
+ print(_section("Starting compilation", ANSI_BRIGHT_BLUE, icon="🚀"))
641
+ print(f"{_label('Model:')} {args.model_id}")
642
+ print(f"{_label('Class:')} {resolved_class_name}")
643
+ print(f"{_label('Output:')} {output_path.absolute()}")
644
+ print(f"{_label('rbln_config:')} {json.dumps(rbln_config, indent=2, ensure_ascii=False)}")
645
+
646
+ with ContextRblnConfig(create_runtimes=create_runtimes):
647
+ _ = model_class.from_pretrained(
648
+ args.model_id, export=True, model_save_dir=str(output_path), rbln_config=rbln_config
649
+ )
650
+
651
+ print(_section("Model compilation completed successfully", ANSI_BRIGHT_GREEN, icon="✅"))
652
+ print(f"Saved to: {output_path.absolute()}")
653
+
654
+ except Exception as e:
655
+ print(f"❌ Error during model compilation: {e}", file=sys.stderr)
656
+ sys.exit(1)
657
+
658
+
659
+ if __name__ == "__main__":  # Run the CLI entry point only when this module is executed as a script, not on import.
660
+ main()