optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
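A central change in 0.7.4 is the configuration rework: optimum/rbln/modeling_config.py is removed (entry 130, +0 -310) and replaced by the new optimum/rbln/configuration_utils.py (entry 3, +816 -0), with per-model configuration modules (configuration_llama.py, configuration_bert.py, and so on) added alongside the existing modeling files. For orientation before the full diff of configuration_utils.py below, here is a minimal sketch of the resulting config-driven API, adapted from the RBLNModelConfig docstring shown in that diff; the model id and the concrete sizes are illustrative only.

```python
from optimum.rbln import RBLNResNetForImageClassification, RBLNResNetForImageClassificationConfig

# Option 1: rbln_-prefixed keyword arguments are collected into the model's config class.
model = RBLNResNetForImageClassification.from_pretrained(
    "microsoft/resnet-50",
    export=True,          # compile for the RBLN NPU
    rbln_image_size=224,
    rbln_batch_size=16,
)

# Option 2: build the config object explicitly and pass it via rbln_config.
config = RBLNResNetForImageClassificationConfig(image_size=224, batch_size=16, create_runtimes=False)
model = RBLNResNetForImageClassification.from_pretrained(
    "microsoft/resnet-50", export=True, rbln_config=config
)
```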
optimum/rbln/configuration_utils.py (new file)
@@ -0,0 +1,816 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib
+ import inspect
+ import json
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+ import rebel
+ import torch
+
+ from .__version__ import __version__
+ from .utils.logging import get_logger
+ from .utils.runtime_utils import ContextRblnConfig
+
+
+ logger = get_logger(__name__)
+
+
+ DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
+ DEFAULT_MOD_NAME = "default"
+ TypeInputInfo = List[Tuple[str, Tuple[int], str]]
+
+
+ @dataclass
+ class RBLNCompileConfig:
+     """
+     Configuration for RBLN compilation.
+
+     Attributes:
+         compiled_model_name (str): Name of the compiled model.
+         mod_name (str): Name of the RBLN module.
+         input_info (Union[List[TypeInputInfo], TypeInputInfo]): Information about input tensors.
+         fusion (Optional[bool]): Whether to use fusion optimization.
+         npu (Optional[str]): NPU configuration.
+         tensor_parallel_size (Optional[int]): Size for tensor parallelism.
+     """
+
+     compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
+     mod_name: str = DEFAULT_MOD_NAME
+     input_info: Union[List[TypeInputInfo], TypeInputInfo] = None
+     fusion: Optional[bool] = None
+     npu: Optional[str] = None
+     tensor_parallel_size: Optional[int] = None
+
+     @staticmethod
+     def normalize_dtype(dtype):
+         """
+         Convert framework-specific dtype to string representation.
+         i.e. torch.float32 -> "float32"
+
+         Args:
+             dtype: The input dtype (can be string, torch dtype, or numpy dtype).
+
+         Returns:
+             str: The normalized string representation of the dtype.
+         """
+         if isinstance(dtype, str):
+             return dtype
+         else:
+             dtype: str = repr(dtype).split(".")[-1]
+             if dtype.endswith("'>"):  # numpy
+                 dtype = dtype[:-2]
+             return dtype
+
+     @property
+     def is_multiple_input_info(self) -> bool:
+         def is_valid_input_info(input_info):
+             if not isinstance(input_info, list):
+                 return False
+             return all(
+                 isinstance(item, (tuple, list))
+                 and len(item) == 3
+                 and isinstance(item[0], str)  # name
+                 and isinstance(item[1], (tuple, list))  # shape
+                 and all(isinstance(x, int) for x in item[1])
+                 and isinstance(item[2], str)  # dtype
+                 for item in input_info
+             )
+
+         if isinstance(self.input_info, list):
+             return all(is_valid_input_info(info) for info in self.input_info)
+         return False
+
+     def __post_init__(self):
+         def normalize_input_info(input_info):
+             return [(i[0], i[1], RBLNCompileConfig.normalize_dtype(i[2]) or "float32") for i in input_info]
+
+         if self.is_multiple_input_info:
+             self.input_info = [normalize_input_info(info) for info in self.input_info]
+         else:
+             self.input_info = normalize_input_info(self.input_info)
+
+     def update(self, kwargs: Dict[str, Any]):
+         self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
+         self.mod_name = kwargs.get("mod_name", self.mod_name)
+         self.input_info = kwargs.get("input_info", self.input_info)
+         self.fusion = kwargs.get("fusion", self.fusion)
+         self.npu = kwargs.get("npu", self.npu)
+         self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
+         return self
+
+     def get_dummy_inputs(
+         self, fill=0, static_tensors: Dict[str, torch.Tensor] = {}, meta_tensor_names: List[str] = []
+     ):
+         dummy = []
+         for name, shape, dtype in self.input_info:
+             if name in static_tensors:
+                 tensor = static_tensors[name]
+                 if shape != list(tensor.shape):
+                     raise RuntimeError(f"Different shape for dummy inputs. ({shape} != {list(tensor.shape)})")
+                 if getattr(torch, dtype) != tensor.dtype:
+                     raise RuntimeError(f"Different dtype for dummy inputs ({dtype} != {tensor.dtype})")
+                 dummy.append(tensor)
+             else:
+                 if name in meta_tensor_names:
+                     device = "meta"
+                 else:
+                     device = "cpu"
+
+                 dummy.append(
+                     torch.fill(torch.empty(*shape, dtype=getattr(torch, dtype), device=torch.device(device)), fill)
+                     if len(shape) > 0
+                     else torch.tensor(fill, dtype=getattr(torch, dtype), device=torch.device(device))
+                 )
+         return tuple(dummy)
+
+     def asdict(self):
+         return asdict(self)
+
+
+ RUNTIME_KEYWORDS = ["create_runtimes", "optimize_host_memory", "device", "device_map", "activate_profiler"]
+
+
+ def load_config(path: str) -> Tuple[Type["RBLNModelConfig"], Dict[str, Any]]:
+     path = Path(path)
+     if path.is_dir():
+         path = path / "rbln_config.json"
+
+     with open(path, "r") as jsonf:
+         config_file = json.load(jsonf)
+
+     if "_meta" in config_file:
+         is_legacy_rbln_config = True
+
+         if is_legacy_rbln_config:
+             raise RuntimeError(
+                 f"`{path}` is an old version. Please recompile the model to get the latest config file."
+             )
+
+     cls_name = config_file["cls_name"]
+     cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
+     return cls, config_file
+
+
+ class RBLNAutoConfig:
+     def __new__(cls, **kwargs):
+         cls_name = kwargs.get("cls_name")
+         if cls_name is None:
+             raise ValueError("`cls_name` is required.")
+         cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
+         return cls(**kwargs)
+
+     @staticmethod
+     def load_from_dict(config_dict: Dict[str, Any]) -> "RBLNModelConfig":
+         cls_name = config_dict.get("cls_name")
+         if cls_name is None:
+             raise ValueError("`cls_name` is required.")
+         cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
+         return cls(**config_dict)
+
+     @staticmethod
+     def load(
+         path: str,
+         passed_rbln_config: Optional["RBLNModelConfig"] = None,
+         kwargs: Optional[Dict[str, Any]] = {},
+         return_unused_kwargs: bool = False,
+     ) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
+         """
+         Load RBLNModelConfig from a path.
+         Class name is automatically inferred from the `rbln_config.json` file.
+
+         Args:
+             path (str): Path to the RBLNModelConfig.
+             passed_rbln_config (Optional["RBLNModelConfig"]): RBLNModelConfig to pass its runtime options.
+
+         Returns:
+             RBLNModelConfig: The loaded RBLNModelConfig.
+         """
+         cls, config_file = load_config(path)
+
+         rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
+         rbln_runtime_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in RUNTIME_KEYWORDS}
+         rbln_submodule_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in cls.submodules}
+
+         rbln_kwargs = {
+             key[5:]: kwargs.pop(key)
+             for key in rbln_keys
+             if key[5:] not in RUNTIME_KEYWORDS and key[5:] not in cls.submodules
+         }
+
+         if len(rbln_kwargs) > 0:
+             raise ValueError(f"Cannot set the following arguments: {list(rbln_kwargs.keys())}")
+
+         # Process submodule's rbln_config
+         for submodule in cls.submodules:
+             if submodule not in config_file:
+                 raise ValueError(f"Submodule {submodule} not found in rbln_config.json.")
+             submodule_config = config_file[submodule]
+             submodule_config.update(rbln_submodule_kwargs.pop(submodule, {}))
+             config_file[submodule] = RBLNAutoConfig.load_from_dict(submodule_config)
+
+         if passed_rbln_config is not None:
+             config_file.update(passed_rbln_config._runtime_options)
+             # TODO(jongho): Reject if the passed_rbln_config has different attributes from the config_file
+
+         config_file.update(rbln_runtime_kwargs)
+
+         if return_unused_kwargs:
+             return cls(**config_file), kwargs
+         else:
+             return cls(**config_file)
+
+
+ class RBLNModelConfig:
+     """Base configuration class for RBLN models that handles compilation settings, runtime options, and submodules.
+
+     This class provides functionality for:
+     1. Managing compilation configurations for RBLN devices
+     2. Configuring runtime behavior such as device placement
+     3. Handling nested configuration objects for complex model architectures
+     4. Serializing and deserializing configurations
+
+     Examples:
+         Using with RBLNModel.from_pretrained():
+         ```python
+         from optimum.rbln import RBLNResNetForImageClassification
+
+         # Method 1: Using rbln_ prefixed arguments (recommended for simple cases)
+         model = RBLNResNetForImageClassification.from_pretrained(
+             "model_id",
+             export=True,  # Compile the model
+             rbln_image_size=224,
+             rbln_batch_size=16,
+             rbln_create_runtimes=True,
+             rbln_device=0
+         )
+
+         # Method 2: Using a config dictionary
+         rbln_config_dict = {
+             "image_size": 224,
+             "batch_size": 16,
+             "create_runtimes": True
+         }
+         model = RBLNResNetForImageClassification.from_pretrained(
+             "model_id",
+             export=True,
+             rbln_config=rbln_config_dict
+         )
+
+         # Method 3: Using a RBLNModelConfig instance
+         from optimum.rbln import RBLNResNetForImageClassificationConfig
+
+         config = RBLNResNetForImageClassificationConfig(
+             image_size=224,
+             batch_size=16,
+             create_runtimes=True
+         )
+
+         model = RBLNResNetForImageClassification.from_pretrained(
+             "model_id",
+             export=True,
+             rbln_config=config
+         )
+
+         # Method 4: Combining a config object with override parameters
+         # (rbln_ prefixed parameters take precedence over rbln_config values)
+         model = RBLNResNetForImageClassification.from_pretrained(
+             "model_id",
+             export=True,
+             rbln_config=config,
+             rbln_image_size=320,  # This overrides the value in config
+             rbln_device=1  # This sets a new value
+         )
+         ```
+
+
+         Save and load configuration:
+         ```python
+         # Save to disk
+         config.save("/path/to/model")
+
+         # Load configuration from disk
+         loaded_config = RBLNModelConfig.load("/path/to/model")
+
+         # Using AutoConfig
+         loaded_config = RBLNAutoConfig.load("/path/to/model")
+         ```
+
+
+         Converting between configuration formats:
+         ```python
+         # Converting a dictionary to a config instance
+         config_dict = {
+             "image_size": 224,
+             "batch_size": 8,
+             "create_runtimes": True
+         }
+         config = RBLNResNetForImageClassificationConfig(**config_dict)
+         ```
+
+         Configuration for language models:
+         ```python
+         from optimum.rbln import RBLNLlamaForCausalLMConfig, RBLNCompileConfig
+
+         # Configure a LLaMA for RBLN
+         config = RBLNLlamaForCausalLMConfig(
+             max_seq_len=4096,
+             device=[0, 1, 2, 3],
+             tensor_parallel_size=4  # For multi-NPU parallel inference
+         )
+         ```
+
+         Working with models that have submodules:
+         ```python
+         from optimum.rbln import RBLNLlavaNextForConditionalGeneration
+
+         # Configuring a model with submodules
+         # LlavaNext has a vision_tower and a language_model submodule
+         model = RBLNLlavaNextForConditionalGeneration.from_pretrained(
+             "llava-hf/llava-v1.6-mistral-7b-hf",
+             export=True,
+             rbln_config={
+                 # Main model's (projector, which is not a submodule) configuration
+                 "create_runtimes": True,
+                 "device": 0,
+
+                 # Submodule configurations as nested dictionaries
+                 "vision_tower": {
+                     "image_size": 336,
+                 },
+                 "language_model": {
+                     "tensor_parallel_size": 4,  # Distribute across 4 NPUs
+                     "max_seq_len": 8192,
+                     "use_inputs_embeds": True,
+                     "batch_size": 1,
+                 },
+             },
+         )
+         ```
+
+         Advanced multi-device deployment with tensor parallelism:
+         ```python
+         from optimum.rbln import RBLNLlamaForCausalLMConfig
+
+         # Setup a complex multi-device configuration for large language models
+         llm_config = RBLNLlamaForCausalLMConfig(
+             # Split model across 8 NPUs
+             tensor_parallel_size=8,
+
+             # Runtime options
+             device=[8, 9, 10, 11, 12, 13, 14, 15],
+             create_runtimes=True,
+             activate_profiler=True,  # Enable profiling for performance analysis
+
+             # Model-specific parameters for the LLM
+             max_seq_len=131072,
+             batch_size=4,
+             attn_impl="flash_attn",
+         )
+         ```
+
+         Compilation without runtime creation (create_runtimes=False):
+         ```python
+         from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig
+
+         # Compile a model on a machine without NPU or for later use
+         config = RBLNLlamaForCausalLMConfig(
+             create_runtimes=False,  # Compile only, don't create runtime
+             npu="RBLN-CA25",  # Specify target NPU for compilation
+             max_seq_len=4096,
+             tensor_parallel_size=4,
+             batch_size=1
+         )
+
+         # Export the model - will compile but not create runtimes
+         model = RBLNLlamaForCausalLM.from_pretrained(
+             "meta-llama/Llama-2-7b-hf",
+             export=True,
+             rbln_config=config
+         )
+
+         # Save the compiled model for later use on NPU
+         model.save_pretrained("./compiled_llama_model")
+
+         # Later, on a machine with the target NPU
+         inference_model = RBLNLlamaForCausalLM.from_pretrained(
+             "./compiled_llama_model",
+             rbln_create_runtimes=True,  # Now create runtimes (Optional)
+         )
+         ```
+
+         Two-stage workflow with separate compilation and runtime:
+         ```python
+         from optimum.rbln import RBLNResNetForImageClassification
+
+         # Stage 1: Model engineer compiles model (can be on any machine)
+         def compile_model():
+             model = RBLNResNetForImageClassification.from_pretrained(
+                 "microsoft/resnet-50",
+                 export=True,
+                 rbln_create_runtimes=False,
+                 rbln_npu="RBLN-CA25",
+                 rbln_image_size=224
+             )
+             model.save_pretrained("./compiled_model")
+             print("Model compiled and saved, ready for deployment")
+
+         # Stage 2: Deployment engineer loads model on NPU
+         def deploy_model():
+             model = RBLNResNetForImageClassification.from_pretrained(
+                 "./compiled_model",
+                 rbln_create_runtimes=True,
+             )
+             print("Model loaded and ready for inference")
+             return model
+         ```
+     """
+
+     non_save_attributes = [
+         "_frozen",
+         "_runtime_options",
+         "npu",
+         "tensor_parallel_size",
+         "create_runtimes",
+         "optimize_host_memory",
+         "device",
+         "device_map",
+         "activate_profiler",
+     ]
+     submodules: List[str] = []
+     subclass_non_save_attributes = []
+
+     def init_submodule_config(
+         self,
+         submodule_config_cls: Type["RBLNModelConfig"],
+         submodule_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
+         **kwargs,
+     ) -> "RBLNModelConfig":
+         """
+         Initialize a submodule config from a dict or a RBLNModelConfig.
+
+         kwargs is specified from the predecessor config.
+         """
+         if submodule_config is None:
+             submodule_config = {}
+
+         if isinstance(submodule_config, dict):
+             from_predecessor = self._runtime_options.copy()
+             from_predecessor.update(kwargs)
+             init_kwargs = from_predecessor
+             init_kwargs.update(submodule_config)
+             submodule_config = submodule_config_cls(**init_kwargs)
+
+         if not isinstance(submodule_config, submodule_config_cls):
+             raise TypeError(f"Invalid submodule config type: {type(submodule_config)}")
+
+         return submodule_config
+
+     def __setattr__(self, key, value):
+         if (
+             key != "_attributes_map"
+             and key not in self.non_save_attributes
+             and key not in self.subclass_non_save_attributes
+         ):
+             self._attributes_map[key] = value
+
+         if hasattr(self, "_frozen") and self._frozen:
+             if not hasattr(self, key) or getattr(self, key) != value:
+                 raise RuntimeError(
+                     f"`{self.__class__.__name__}` is frozen. Cannot update or set attribute after freezing."
+                 )
+
+         # If the submodule is a dict, Instantiate the submodule config class
+         if key in self.submodules and isinstance(value, dict) and (cls_name := value.get("cls_name")):
+             rbln_config_cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
+             value = rbln_config_cls(**value)
+
+         # Forbid setting keyword-only arguments
+         # keyword-only arguments should be translated to other attributes, not set directly
+         _keyword_only_args = set()
+         init_signature = inspect.signature(self.__class__.__init__)
+         for param_name, param in init_signature.parameters.items():
+             if param.kind == inspect.Parameter.KEYWORD_ONLY:
+                 _keyword_only_args.add(param_name)
+
+         if key in _keyword_only_args:
+             raise AttributeError(
+                 f"Cannot set attribute '{key}'. This is an internal error. Please report it to the developers."
+             )
+
+         super().__setattr__(key, value)
+
+     def __init__(
+         self,
+         cls_name: Optional[str] = None,
+         create_runtimes: Optional[bool] = None,
+         optimize_host_memory: Optional[bool] = None,
+         device: Optional[Union[int, List[int]]] = None,
+         device_map: Optional[Dict[str, Union[int, List[int]]]] = None,
+         activate_profiler: Optional[bool] = None,
+         npu: Optional[str] = None,
+         tensor_parallel_size: Optional[int] = None,
+         optimum_rbln_version: Optional[str] = None,
+         _compile_cfgs: List[RBLNCompileConfig] = [],
+         **kwargs,
+     ):
+         """
+         Initialize a RBLN model configuration with runtime options and compile configurations.
+
+         Args:
+             cls_name (Optional[str]): The class name of the configuration. Defaults to the current class name.
+             create_runtimes (Optional[bool]): Whether to create RBLN runtimes. Defaults to True if an NPU is available.
+             optimize_host_memory (Optional[bool]): Whether to optimize host memory usage. Defaults to True.
+             device (Optional[Union[int, List[int]]]): The device(s) to load the model onto. Can be a single device ID or a list.
+             device_map (Optional[Dict[str, Union[int, List[int]]]]): Mapping from compiled model names to device IDs.
+             activate_profiler (Optional[bool]): Whether to activate the profiler for performance analysis.
+             npu (Optional[str]): The NPU device name to use for compilation.
+             tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
+             optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
+             _compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
+             **kwargs: Additional keyword arguments.
+
+         Raises:
+             ValueError: If unexpected keyword arguments are provided.
+         """
+         self._attributes_map = {}
+         self._frozen = False
+
+         self.cls_name = cls_name
+         if self.cls_name is None:
+             self.cls_name = self.__class__.__name__
+
+         self._runtime_options = {}
+         self._runtime_options["create_runtimes"] = create_runtimes
+         self._runtime_options["optimize_host_memory"] = optimize_host_memory
+         self._runtime_options["device"] = device
+         self._runtime_options["device_map"] = device_map
+         self._runtime_options["activate_profiler"] = activate_profiler
+
+         # Automatically pass npu, tensor_parallel_size to compile_cfgs
+         self.npu = npu
+         self.tensor_parallel_size = tensor_parallel_size
+
+         self.optimum_rbln_version = optimum_rbln_version
+         if self.optimum_rbln_version is None:
+             self.optimum_rbln_version = __version__
+
+         self._compile_cfgs: List[RBLNCompileConfig] = _compile_cfgs
+
+         if not isinstance(self._compile_cfgs, list):
+             raise ValueError("`compile_cfgs` must be a list of `RBLNCompileConfig`.")
+         if len(self._compile_cfgs) > 0 and not isinstance(self._compile_cfgs[0], RBLNCompileConfig):
+             self.set_compile_cfgs([RBLNCompileConfig(**cfg) for cfg in self._compile_cfgs])
+
+         if len(kwargs) > 0:
+             raise ValueError(f"Unexpected arguments: {kwargs.keys()}")
+
+     @property
+     def rbln_model_cls_name(self) -> str:
+         return self.__class__.__name__[:-6]
+
+     @property
+     def rbln_model_cls(self) -> Type:
+         rbln_model_cls = getattr(importlib.import_module("optimum.rbln"), self.rbln_model_cls_name, None)
+         if rbln_model_cls is None:
+             raise ValueError(
+                 f"RBLN model class {self.rbln_model_cls_name} not found. This is an internal error. "
+                 "Please report it to the developers."
+             )
+         return rbln_model_cls
+
+     def _prepare_for_serialization(self):
+         """
+         Prepare the attributes map for serialization by converting nested RBLNModelConfig
+         objects to their serializable form.
+         """
+         serializable_map = {}
+         for key, value in self._attributes_map.items():
+             if isinstance(value, RBLNModelConfig):
+                 # Convert nested RBLNModelConfig to its serializable form
+                 serializable_map[key] = value._prepare_for_serialization()
+             elif key == "_compile_cfgs":
+                 serializable_map[key] = [cfg.asdict() for cfg in value]
+             else:
+                 serializable_map[key] = value
+         return serializable_map
+
+     def __repr__(self):
+         repr_dict = self._prepare_for_serialization()
+         return json.dumps(repr_dict, indent=2)
+
+     @property
+     def compile_cfgs(self):
+         return self._compile_cfgs
+
+     @compile_cfgs.setter
+     def compile_cfgs(self, compile_cfgs: List[RBLNCompileConfig]):
+         raise RuntimeError("`compile_cfgs` cannot be set directly. Please use `set_compile_cfgs` instead.")
+
+     def set_compile_cfgs(self, compile_cfgs: List[RBLNCompileConfig]):
+         if not isinstance(compile_cfgs, list):
+             raise ValueError("`compile_cfgs` must be a list of `RBLNCompileConfig`.")
+         if len(compile_cfgs) == 0:
+             raise ValueError("`compile_cfgs` must contain at least one `RBLNCompileConfig`.")
+         if not isinstance(compile_cfgs[0], RBLNCompileConfig):
+             raise ValueError("`compile_cfgs` must contain only `RBLNCompileConfig`.")
+
+         self._compile_cfgs = compile_cfgs
+         for compile_cfg in self._compile_cfgs:
+             compile_cfg.npu = self.npu
+             compile_cfg.tensor_parallel_size = self.tensor_parallel_size
+
+     def freeze(self):
+         if self._frozen:
+             raise RuntimeError(f"`{self.__class__.__name__}` is already frozen.")
+
+         if (
+             not isinstance(self._compile_cfgs, list)
+             or len(self._compile_cfgs) == 0
+             or not all(isinstance(cfg, RBLNCompileConfig) for cfg in self._compile_cfgs)
+         ):
+             raise RuntimeError("`compile_cfgs` must be set before freezing.")
+
+         for submodule_name in self.submodules:
+             submodule_config = getattr(self, submodule_name, None)
+             if not isinstance(submodule_config, RBLNModelConfig):
+                 raise ValueError(f"`{submodule_name}` must be an instance of `RBLNModelConfig` before freezing.")
+
+             if not submodule_config.is_frozen():
+                 raise ValueError(f"`{submodule_name}` config must be frozen before freezing super config.")
+
+         self._frozen = True
+
+     def is_frozen(self):
+         return self._frozen
+
+     def save(self, path: str):
+         if not self._frozen:
+             raise RuntimeError("`RBLNModelConfig` is not frozen. Please call `set_compile_cfgs` first.")
+
+         # save as json file without runtime attributes
+         path = Path(path)
+         if path.is_dir():
+             path = path / "rbln_config.json"
+
+         with open(path, "w") as jsonf:
+             serializable_data = self._prepare_for_serialization()
+             json.dump(serializable_data, jsonf, indent=2)
+
+     @classmethod
+     def load(cls, path: str, **kwargs) -> "RBLNModelConfig":
+         """
+         Load a RBLNModelConfig from a path.
+
+         Args:
+             path (str): Path to the RBLNModelConfig file or directory containing the config file.
+             **kwargs: Additional keyword arguments to override configuration values.
+                 Keys starting with 'rbln_' will have the prefix removed and be used
+                 to update the configuration.
+
+         Returns:
+             RBLNModelConfig: The loaded configuration instance.
+
+         Note:
+             This method loads the configuration from the specified path and applies any
+             provided overrides. If the loaded configuration class doesn't match the expected
+             class, a warning will be logged.
+         """
+         cls_reserved, config_file = load_config(path)
+
+         if cls_reserved != cls:
+             logger.warning(f"Expected {cls.__name__}, but got {cls_reserved.__name__}.")
+
+         rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
+         rbln_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys}
+         config_file.update(rbln_kwargs)
+
+         return cls(**config_file)
+
+     @classmethod
+     def initialize_from_kwargs(
+         cls: Type["RBLNModelConfig"],
+         rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
+         **kwargs,
+     ) -> Tuple["RBLNModelConfig", Dict[str, Any]]:
+         """
+         Initialize RBLNModelConfig from kwargs.
+         """
+         kwargs_keys = list(kwargs.keys())
+         rbln_kwargs = {key[5:]: kwargs.pop(key) for key in kwargs_keys if key.startswith("rbln_")}
+
+         if isinstance(rbln_config, dict):
+             rbln_config.update(rbln_kwargs)
+             rbln_config = cls(**rbln_config)
+
+         elif rbln_config is None:
+             rbln_config = cls(**rbln_kwargs)
+
+         elif isinstance(rbln_config, RBLNModelConfig):
+             for key, value in rbln_kwargs.items():
+                 setattr(rbln_config, key, value)
+
+         return rbln_config, kwargs
+
+     def get_default_values_for_original_cls(self, func_name: str, keys: List[str]) -> Dict[str, Any]:
+         """
+         Get default values for original class attributes from RBLNModelConfig.
+
+         Args:
+             func_name (str): The name of the function to get the default values for.
+             keys (List[str]): The keys of the attributes to get.
+
+         Returns:
+             Dict[str, Any]: The default values for the attributes.
+         """
+         model_cls = self.rbln_model_cls.get_hf_class()
+         func = getattr(model_cls, func_name)
+         func_signature = inspect.signature(func)
+         default_values = {}
+         for key in keys:
+             if key in func_signature.parameters:
+                 default_values[key] = func_signature.parameters[key].default
+             else:
+                 raise ValueError(f"Default value for `{key}` is not set for the model class.")
+         return default_values
+
+     @property
+     def create_runtimes(self):
+         context = ContextRblnConfig.get_current_context()["create_runtimes"]
+         if context is not None:
+             return context
+         elif self._runtime_options["create_runtimes"] is None:
+             return rebel.npu_is_available()
+         return self._runtime_options["create_runtimes"]
+
+     @create_runtimes.setter
+     def create_runtimes(self, create_runtimes: bool):
+         self._runtime_options["create_runtimes"] = create_runtimes
+
+     @property
+     def optimize_host_memory(self):
+         context = ContextRblnConfig.get_current_context()["optimize_host_memory"]
+         if context is not None:
+             return context
+         elif self._runtime_options["optimize_host_memory"] is None:
+             return True
+         return self._runtime_options["optimize_host_memory"]
+
+     @optimize_host_memory.setter
+     def optimize_host_memory(self, optimize_host_memory: bool):
+         self._runtime_options["optimize_host_memory"] = optimize_host_memory
+
+     @property
+     def device(self):
+         context = ContextRblnConfig.get_current_context()["device"]
+         if context is not None:
+             return context
+         return self._runtime_options["device"]
+
+     @device.setter
+     def device(self, device: Union[int, List[int]]):
+         self._runtime_options["device"] = device
+
+     @property
+     def device_map(self):
+         context = ContextRblnConfig.get_current_context()["device_map"]
+         if context:
+             return context
+         elif self._runtime_options["device_map"] is None:
+             rbln_device_map = {}
+             device_val = self.device
+             for cfg in self.compile_cfgs:
+                 rbln_device_map[cfg.compiled_model_name] = device_val
+             return rbln_device_map
+         return self._runtime_options["device_map"]
+
+     @device_map.setter
+     def device_map(self, device_map: Dict[str, Union[int, List[int]]]):
+         self._runtime_options["device_map"] = device_map
+
+     @property
+     def activate_profiler(self):
+         context = ContextRblnConfig.get_current_context()["activate_profiler"]
+         if context is not None:
+             return context
+         return self._runtime_options["activate_profiler"]
+
+     @activate_profiler.setter
+     def activate_profiler(self, activate_profiler: bool):
+         self._runtime_options["activate_profiler"] = activate_profiler
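As a closing illustration of the lower-level pieces added above, the sketch below exercises RBLNCompileConfig directly: input_info entries are (name, shape, dtype) triples, dtypes are normalized to strings in __post_init__, and get_dummy_inputs() materializes matching placeholder tensors. This is a usage sketch based only on the code in this diff; the tensor names and shapes are made up for the example.

```python
import torch

from optimum.rbln import RBLNCompileConfig  # defined in the new optimum/rbln/configuration_utils.py

# A single compile config: (name, shape, dtype) triples describe the traced inputs.
cfg = RBLNCompileConfig(
    compiled_model_name="text_encoder",
    input_info=[
        ("input_ids", [1, 512], torch.int64),   # torch dtypes are normalized to "int64"
        ("attention_mask", [1, 512], "int64"),  # string dtypes pass through unchanged
    ],
)

print(cfg.input_info)  # [('input_ids', [1, 512], 'int64'), ('attention_mask', [1, 512], 'int64')]

# Placeholder tensors with the declared shapes and dtypes, e.g. for tracing before compilation.
dummy_inputs = cfg.get_dummy_inputs(fill=0)
print([tuple(t.shape) for t in dummy_inputs])  # [(1, 512), (1, 512)]
```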