optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. optimum/rbln/__init__.py +156 -36
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +85 -75
  31. optimum/rbln/transformers/__init__.py +79 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +96 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  73. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  74. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  75. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  76. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  77. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  78. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  79. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  80. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  81. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  82. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  83. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  84. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  85. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  86. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  87. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  88. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  89. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  90. optimum/rbln/utils/submodule.py +26 -43
  91. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
  92. optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
  93. optimum/rbln/modeling_config.py +0 -310
  94. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  95. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,772 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib
16
+ import inspect
17
+ import json
18
+ from dataclasses import asdict, dataclass
19
+ from pathlib import Path
20
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union
21
+
22
+ import rebel
23
+ import torch
24
+
25
+ from .__version__ import __version__
26
+ from .utils.logging import get_logger
27
+ from .utils.runtime_utils import ContextRblnConfig
28
+
29
+
30
+ logger = get_logger(__name__)
31
+
32
+
33
DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
DEFAULT_MOD_NAME = "default"
# One entry per input tensor: (name, shape, dtype string).
TypeInputInfo = List[Tuple[str, Tuple[int], str]]


@dataclass
class RBLNCompileConfig:
    """
    Configuration for RBLN compilation.

    Attributes:
        compiled_model_name (str): Name of the compiled model.
        mod_name (str): Name of the RBLN module.
        input_info (Union[List[TypeInputInfo], TypeInputInfo]): Information about input tensors.
            Either a single list of ``(name, shape, dtype)`` entries or a list of such lists.
            Dtypes are normalized to strings on construction and on `update`.
        fusion (Optional[bool]): Whether to use fusion optimization.
        npu (Optional[str]): NPU configuration.
        tensor_parallel_size (Optional[int]): Size for tensor parallelism.
    """

    compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
    mod_name: str = DEFAULT_MOD_NAME
    input_info: Union[List[TypeInputInfo], TypeInputInfo] = None
    fusion: Optional[bool] = None
    npu: Optional[str] = None
    tensor_parallel_size: Optional[int] = None

    @staticmethod
    def normalize_dtype(dtype):
        """
        Convert framework-specific dtype to string representation.
        i.e. torch.float32 -> "float32"

        Args:
            dtype: The input dtype (can be string, torch dtype, or numpy dtype).

        Returns:
            str: The normalized string representation of the dtype.
        """
        if isinstance(dtype, str):
            return dtype

        # Keep only the last dotted component of the repr, e.g. "torch.float32" -> "float32".
        name = repr(dtype).split(".")[-1]
        if name.endswith("'>"):  # numpy scalar types repr as "<class 'numpy.float32'>"
            name = name[:-2]
        return name

    @property
    def is_multiple_input_info(self) -> bool:
        """Whether `input_info` is a list of input-info lists rather than a single one."""

        def is_valid_input_info(input_info):
            if not isinstance(input_info, list):
                return False
            return all(
                isinstance(item, (tuple, list))
                and len(item) == 3
                and isinstance(item[0], str)  # name
                and isinstance(item[1], (tuple, list))  # shape
                and all(isinstance(x, int) for x in item[1])
                and isinstance(item[2], str)  # dtype
                for item in input_info
            )

        if isinstance(self.input_info, list):
            return all(is_valid_input_info(info) for info in self.input_info)
        return False

    def _normalize_input_info(self):
        """Rewrite `input_info` so that every entry is a tuple with a string dtype."""
        if self.input_info is None:
            # `None` is the declared default; there is nothing to normalize yet.
            return

        def normalize(input_info):
            return [(i[0], i[1], RBLNCompileConfig.normalize_dtype(i[2]) or "float32") for i in input_info]

        if self.is_multiple_input_info:
            self.input_info = [normalize(info) for info in self.input_info]
        else:
            self.input_info = normalize(self.input_info)

    def __post_init__(self):
        self._normalize_input_info()

    def update(self, kwargs: Dict[str, Any]):
        """
        Overwrite fields with the values present in `kwargs` and return `self`.

        Keys that are absent from `kwargs` leave the corresponding field unchanged.
        A new `input_info` goes through the same normalization as in the constructor.
        """
        self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
        self.mod_name = kwargs.get("mod_name", self.mod_name)
        self.fusion = kwargs.get("fusion", self.fusion)
        self.npu = kwargs.get("npu", self.npu)
        self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
        if "input_info" in kwargs:
            self.input_info = kwargs["input_info"]
            self._normalize_input_info()
        return self

    def get_dummy_inputs(
        self,
        fill=0,
        static_tensors: Optional[Dict[str, torch.Tensor]] = None,
        meta_tensor_names: Optional[List[str]] = None,
    ):
        """
        Build one tensor per `input_info` entry.

        Args:
            fill: Value used to fill the generated tensors.
            static_tensors: Tensors to use as-is for the inputs with matching names.
                Their shape and dtype must agree with `input_info`.
            meta_tensor_names: Names of inputs to create on the "meta" device
                instead of the CPU.

        Returns:
            tuple: The tensors, in `input_info` order.

        Raises:
            RuntimeError: If a static tensor's shape or dtype differs from `input_info`.
        """
        static_tensors = {} if static_tensors is None else static_tensors
        meta_tensor_names = [] if meta_tensor_names is None else meta_tensor_names

        dummy = []
        for name, shape, dtype in self.input_info:
            torch_dtype = getattr(torch, dtype)

            if name in static_tensors:
                tensor = static_tensors[name]
                # Compare as lists so that tuple and list shapes are treated alike.
                if list(shape) != list(tensor.shape):
                    raise RuntimeError(f"Different shape for dummy inputs. ({shape} != {list(tensor.shape)})")
                if torch_dtype != tensor.dtype:
                    raise RuntimeError(f"Different dtype for dummy inputs ({dtype} != {tensor.dtype})")
                dummy.append(tensor)
                continue

            device = torch.device("meta" if name in meta_tensor_names else "cpu")
            if len(shape) > 0:
                dummy.append(torch.full(tuple(shape), fill, dtype=torch_dtype, device=device))
            else:
                dummy.append(torch.tensor(fill, dtype=torch_dtype, device=device))
        return tuple(dummy)

    def asdict(self):
        """Return the fields of this config as a plain dictionary."""
        return asdict(self)
143
+
144
+
145
# Option names that `RBLNAutoConfig.load` accepts as `rbln_`-prefixed runtime overrides.
RUNTIME_KEYWORDS = [
    "create_runtimes",
    "optimize_host_memory",
    "device",
    "device_map",
    "activate_profiler",
]
146
+
147
+
148
def load_config(path: str) -> Tuple[Type["RBLNModelConfig"], Dict[str, Any]]:
    """
    Read an RBLN config JSON file and resolve the config class it names.

    Args:
        path (str): Path to the JSON file, or to a directory containing `rbln_config.json`.

    Returns:
        Tuple[Type[RBLNModelConfig], Dict[str, Any]]: The class looked up in `optimum.rbln`
        by the file's `cls_name` entry, and the parsed file contents.

    Raises:
        RuntimeError: If the file has a top-level `_meta` key (legacy config format).
        KeyError: If the file has no `cls_name` entry.
    """
    path = Path(path)
    if path.is_dir():
        path = path / "rbln_config.json"

    with open(path, "r", encoding="utf-8") as jsonf:
        config_file = json.load(jsonf)

    # A top-level "_meta" key marks the legacy format, which cannot be loaded anymore.
    if "_meta" in config_file:
        raise RuntimeError(
            f"`{path}` is an old version. Please recompile the model to get the latest config file."
        )

    cls_name = config_file["cls_name"]
    cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
    return cls, config_file
167
+
168
+
169
class RBLNAutoConfig:
    """Factory that builds the config class named by `cls_name` instead of an instance of itself."""

    def __new__(cls, **kwargs):
        """
        Instantiate the `optimum.rbln` config class named by `kwargs["cls_name"]`.

        Raises:
            ValueError: If `cls_name` is missing or None.
        """
        cls_name = kwargs.get("cls_name")
        if cls_name is None:
            raise ValueError("`cls_name` is required.")
        config_cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
        return config_cls(**kwargs)

    @staticmethod
    def load(
        path: str,
        passed_rbln_config: Optional["RBLNModelConfig"] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        return_unused_kwargs: bool = False,
    ) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
        """
        Load RBLNModelConfig from a path.
        Class name is automatically inferred from the `rbln_config.json` file.

        Args:
            path (str): Path to the RBLNModelConfig.
            passed_rbln_config (Optional["RBLNModelConfig"]): RBLNModelConfig to pass its runtime options.
            kwargs (Optional[Dict[str, Any]]): Extra keyword arguments. `rbln_`-prefixed runtime
                options are popped from this dict and applied to the loaded config. A fresh
                dict is used when omitted.
            return_unused_kwargs (bool): If True, also return what is left of `kwargs`.

        Returns:
            RBLNModelConfig: The loaded RBLNModelConfig (and the remaining kwargs if requested).

        Raises:
            ValueError: If `kwargs` holds an `rbln_`-prefixed key that is neither a runtime
                option nor one of the config class's submodules.
        """
        if kwargs is None:
            kwargs = {}

        cls, config_file = load_config(path)

        prefix = "rbln_"
        rbln_runtime_kwargs = {}
        rejected = {}
        # Snapshot the keys first because entries are popped while classifying them.
        for key in [k for k in kwargs if k.startswith(prefix)]:
            name = key[len(prefix):]
            if name in RUNTIME_KEYWORDS:
                rbln_runtime_kwargs[name] = kwargs.pop(key)
            elif name not in cls.submodules:
                rejected[name] = kwargs.pop(key)
            # Keys naming a submodule are left in `kwargs` untouched.

        if len(rejected) > 0:
            raise ValueError(f"Cannot set the following arguments: {list(rejected.keys())}")

        if passed_rbln_config is not None:
            config_file.update(passed_rbln_config._runtime_options)
            # TODO(jongho): Reject if the passed_rbln_config has different attributes from the config_file

        # Explicit rbln_* arguments take precedence over the passed config's runtime options.
        config_file.update(rbln_runtime_kwargs)

        if return_unused_kwargs:
            return cls(**config_file), kwargs
        return cls(**config_file)
219
+
220
+
221
class RBLNModelConfig:
    """Base configuration class for RBLN models that handles compilation settings, runtime options, and submodules.

    This class provides functionality for:
    1. Managing compilation configurations for RBLN devices
    2. Configuring runtime behavior such as device placement
    3. Handling nested configuration objects for complex model architectures
    4. Serializing and deserializing configurations

    Examples:
        Using with RBLNModel.from_pretrained():
        ```python
        from optimum.rbln import RBLNResNetForImageClassification

        # Method 1: Using rbln_ prefixed arguments (recommended for simple cases)
        model = RBLNResNetForImageClassification.from_pretrained(
            "model_id",
            export=True,  # Compile the model
            rbln_image_size=224,
            rbln_batch_size=16,
            rbln_create_runtimes=True,
            rbln_device=0
        )

        # Method 2: Using a config dictionary
        rbln_config_dict = {
            "image_size": 224,
            "batch_size": 16,
            "create_runtimes": True
        }
        model = RBLNResNetForImageClassification.from_pretrained(
            "model_id",
            export=True,
            rbln_config=rbln_config_dict
        )

        # Method 3: Using a RBLNModelConfig instance
        from optimum.rbln import RBLNResNetForImageClassificationConfig

        config = RBLNResNetForImageClassificationConfig(
            image_size=224,
            batch_size=16,
            create_runtimes=True
        )

        model = RBLNResNetForImageClassification.from_pretrained(
            "model_id",
            export=True,
            rbln_config=config
        )

        # Method 4: Combining a config object with override parameters
        # (rbln_ prefixed parameters take precedence over rbln_config values)
        model = RBLNResNetForImageClassification.from_pretrained(
            "model_id",
            export=True,
            rbln_config=config,
            rbln_image_size=320,  # This overrides the value in config
            rbln_device=1  # This sets a new value
        )
        ```


        Save and load configuration:
        ```python
        # Save to disk
        config.save("/path/to/model")

        # Load configuration from disk
        loaded_config = RBLNModelConfig.load("/path/to/model")

        # Using AutoConfig
        loaded_config = RBLNAutoConfig.load("/path/to/model")
        ```


        Converting between configuration formats:
        ```python
        # Converting a dictionary to a config instance
        config_dict = {
            "image_size": 224,
            "batch_size": 8,
            "create_runtimes": True
        }
        config = RBLNResNetForImageClassificationConfig(**config_dict)
        ```

        Configuration for language models:
        ```python
        from optimum.rbln import RBLNLlamaForCausalLMConfig, RBLNCompileConfig

        # Configure a LLaMA for RBLN
        config = RBLNLlamaForCausalLMConfig(
            max_seq_len=4096,
            device=[0, 1, 2, 3],
            tensor_parallel_size=4  # For multi-NPU parallel inference
        )
        ```

        Working with models that have submodules:
        ```python
        from optimum.rbln import RBLNLlavaNextForConditionalGeneration

        # Configuring a model with submodules
        # LlavaNext has a vision_tower and a language_model submodule
        model = RBLNLlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            export=True,
            rbln_config={
                # Main model's (projector, which is not a submodule) configuration
                "create_runtimes": True,
                "device": 0,

                # Submodule configurations as nested dictionaries
                "vision_tower": {
                    "image_size": 336,
                },
                "language_model": {
                    "tensor_parallel_size": 4,  # Distribute across 4 NPUs
                    "max_seq_len": 8192,
                    "use_inputs_embeds": True,
                    "batch_size": 1,
                },
            },
        )
        ```

        Advanced multi-device deployment with tensor parallelism:
        ```python
        from optimum.rbln import RBLNLlamaForCausalLMConfig

        # Setup a complex multi-device configuration for large language models
        llm_config = RBLNLlamaForCausalLMConfig(
            # Split model across 8 NPUs
            tensor_parallel_size=8,

            # Runtime options
            device=[8, 9, 10, 11, 12, 13, 14, 15],
            create_runtimes=True,
            activate_profiler=True,  # Enable profiling for performance analysis

            # Model-specific parameters for the LLM
            max_seq_len=131072,
            batch_size=4,
            attn_impl="flash_attn",
        )
        ```

        Compilation without runtime creation (create_runtimes=False):
        ```python
        from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig

        # Compile a model on a machine without NPU or for later use
        config = RBLNLlamaForCausalLMConfig(
            create_runtimes=False,  # Compile only, don't create runtime
            npu="RBLN-CA25",  # Specify target NPU for compilation
            max_seq_len=4096,
            tensor_parallel_size=4,
            batch_size=1
        )

        # Export the model - will compile but not create runtimes
        model = RBLNLlamaForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf",
            export=True,
            rbln_config=config
        )

        # Save the compiled model for later use on NPU
        model.save_pretrained("./compiled_llama_model")

        # Later, on a machine with the target NPU
        inference_model = RBLNLlamaForCausalLM.from_pretrained(
            "./compiled_llama_model",
            rbln_create_runtimes=True,  # Now create runtimes (Optional)
        )
        ```

        Two-stage workflow with separate compilation and runtime:
        ```python
        from optimum.rbln import RBLNResNetForImageClassification

        # Stage 1: Model engineer compiles model (can be on any machine)
        def compile_model():
            model = RBLNResNetForImageClassification.from_pretrained(
                "microsoft/resnet-50",
                export=True,
                rbln_create_runtimes=False,
                rbln_npu="RBLN-CA25",
                rbln_image_size=224
            )
            model.save_pretrained("./compiled_model")
            print("Model compiled and saved, ready for deployment")

        # Stage 2: Deployment engineer loads model on NPU
        def deploy_model():
            model = RBLNResNetForImageClassification.from_pretrained(
                "./compiled_model",
                rbln_create_runtimes=True,
            )
            print("Model loaded and ready for inference")
            return model
        ```
    """

    # Attribute names that `__setattr__` does not record in `_attributes_map`,
    # and which therefore never reach the serialized config.
    non_save_attributes = [
        "_frozen",
        "_runtime_options",
        "npu",
        "tensor_parallel_size",
        "create_runtimes",
        "optimize_host_memory",
        "device",
        "device_map",
        "activate_profiler",
    ]
    # Attribute names that hold nested RBLNModelConfig objects; subclasses override this.
    submodules: List[str] = []

    def init_submodule_config(
        self,
        submodule_config_cls: Type["RBLNModelConfig"],
        submodule_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
        **kwargs,
    ) -> "RBLNModelConfig":
        """
        Initialize a submodule config from a dict or a RBLNModelConfig.

        kwargs is specified from the predecessor config.

        When a dict (or None) is given, the submodule is built from this config's runtime
        options, overridden by `kwargs`, overridden in turn by the dict itself.

        Raises:
            TypeError: If the resulting object is not an instance of `submodule_config_cls`.
        """
        if submodule_config is None:
            submodule_config = {}

        if isinstance(submodule_config, dict):
            # Later sources win: runtime options < kwargs < explicit submodule dict.
            init_kwargs = {**self._runtime_options, **kwargs, **submodule_config}
            submodule_config = submodule_config_cls(**init_kwargs)

        if not isinstance(submodule_config, submodule_config_cls):
            raise TypeError(f"Invalid submodule config type: {type(submodule_config)}")

        return submodule_config

    def __setattr__(self, key, value):
        """
        Set an attribute, enforcing the frozen state and recording saveable attributes.

        Raises:
            RuntimeError: If the config is frozen and the assignment would add or change a value.
            AttributeError: If `key` is a keyword-only parameter of the class's `__init__`.
        """
        if getattr(self, "_frozen", False):
            if not hasattr(self, key) or getattr(self, key) != value:
                raise RuntimeError(
                    f"`{self.__class__.__name__}` is frozen. Cannot update or set attribute after freezing."
                )

        # The attribute map keeps the value exactly as it was passed in.
        recorded_value = value

        # If the submodule is a dict, Instantiate the submodule config class
        if key in self.submodules and isinstance(value, dict) and (cls_name := value.get("cls_name")):
            rbln_config_cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
            value = rbln_config_cls(**value)

        # Forbid setting keyword-only arguments
        # keyword-only arguments should be translated to other attributes, not set directly
        init_signature = inspect.signature(self.__class__.__init__)
        keyword_only_args = {
            param_name
            for param_name, param in init_signature.parameters.items()
            if param.kind == inspect.Parameter.KEYWORD_ONLY
        }
        if key in keyword_only_args:
            raise AttributeError(
                f"Cannot set attribute '{key}'. This is an internal error. Please report it to the developers."
            )

        # Record only after every check has passed, so a rejected assignment
        # cannot leak into the serialized state.
        if key != "_attributes_map" and key not in self.non_save_attributes:
            self._attributes_map[key] = recorded_value

        super().__setattr__(key, value)

    def __init__(
        self,
        cls_name: Optional[str] = None,
        create_runtimes: Optional[bool] = None,
        optimize_host_memory: Optional[bool] = None,
        device: Optional[Union[int, List[int]]] = None,
        device_map: Optional[Dict[str, Union[int, List[int]]]] = None,
        activate_profiler: Optional[bool] = None,
        npu: Optional[str] = None,
        tensor_parallel_size: Optional[int] = None,
        optimum_rbln_version: Optional[str] = None,
        _compile_cfgs: Optional[List["RBLNCompileConfig"]] = None,
        **kwargs,
    ):
        """
        Initialize a RBLN model configuration with runtime options and compile configurations.

        Args:
            cls_name (Optional[str]): The class name of the configuration. Defaults to the current class name.
            create_runtimes (Optional[bool]): Whether to create RBLN runtimes. Defaults to True if an NPU is available.
            optimize_host_memory (Optional[bool]): Whether to optimize host memory usage. Defaults to True.
            device (Optional[Union[int, List[int]]]): The device(s) to load the model onto. Can be a single device ID or a list.
            device_map (Optional[Dict[str, Union[int, List[int]]]]): Mapping from compiled model names to device IDs.
            activate_profiler (Optional[bool]): Whether to activate the profiler for performance analysis.
            npu (Optional[str]): The NPU device name to use for compilation.
            tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
            optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
            _compile_cfgs (Optional[List[RBLNCompileConfig]]): List of compilation configurations for the model,
                either as `RBLNCompileConfig` objects or as dicts of their fields. Defaults to a new empty list.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If unexpected keyword arguments are provided, or `_compile_cfgs` is not a list.
        """
        self._attributes_map = {}
        self._frozen = False

        self.cls_name = cls_name
        if self.cls_name is None:
            self.cls_name = self.__class__.__name__

        self._runtime_options = {}
        self._runtime_options["create_runtimes"] = create_runtimes
        self._runtime_options["optimize_host_memory"] = optimize_host_memory
        self._runtime_options["device"] = device
        self._runtime_options["device_map"] = device_map
        self._runtime_options["activate_profiler"] = activate_profiler

        # Automatically pass npu, tensor_parallel_size to compile_cfgs
        self.npu = npu
        self.tensor_parallel_size = tensor_parallel_size

        self.optimum_rbln_version = optimum_rbln_version
        if self.optimum_rbln_version is None:
            self.optimum_rbln_version = __version__

        # Each instance gets its own list when none is supplied.
        self._compile_cfgs = [] if _compile_cfgs is None else _compile_cfgs

        if not isinstance(self._compile_cfgs, list):
            raise ValueError("`compile_cfgs` must be a list of `RBLNCompileConfig`.")
        if len(self._compile_cfgs) > 0 and not isinstance(self._compile_cfgs[0], RBLNCompileConfig):
            # Entries given as dicts (e.g. read back from JSON) are rebuilt into objects.
            self.set_compile_cfgs([RBLNCompileConfig(**cfg) for cfg in self._compile_cfgs])

        if len(kwargs) > 0:
            raise ValueError(f"Unexpected arguments: {kwargs.keys()}")

    @property
    def rbln_model_cls_name(self) -> str:
        """Name of this config class with its last six characters (the `Config` suffix) removed."""
        return self.__class__.__name__[:-6]

    @property
    def rbln_model_cls(self) -> Type:
        """
        The `optimum.rbln` class named `rbln_model_cls_name`.

        Raises:
            ValueError: If `optimum.rbln` has no attribute with that name.
        """
        rbln_model_cls = getattr(importlib.import_module("optimum.rbln"), self.rbln_model_cls_name, None)
        if rbln_model_cls is None:
            raise ValueError(
                f"RBLN model class {self.rbln_model_cls_name} not found. This is an internal error. "
                "Please report it to the developers."
            )
        return rbln_model_cls

    def _prepare_for_serialization(self):
        """
        Prepare the attributes map for serialization by converting nested RBLNModelConfig
        objects to their serializable form.
        """
        serializable_map = {}
        for key, value in self._attributes_map.items():
            if isinstance(value, RBLNModelConfig):
                # Convert nested RBLNModelConfig to its serializable form
                serializable_map[key] = value._prepare_for_serialization()
            elif key == "_compile_cfgs":
                serializable_map[key] = [cfg.asdict() for cfg in value]
            else:
                serializable_map[key] = value
        return serializable_map

    def __repr__(self):
        """Return the serializable attributes as indented JSON."""
        repr_dict = self._prepare_for_serialization()
        return json.dumps(repr_dict, indent=2)

    @property
    def compile_cfgs(self):
        """The list of compile configurations; replace it with `set_compile_cfgs`."""
        return self._compile_cfgs

    @compile_cfgs.setter
    def compile_cfgs(self, compile_cfgs: List["RBLNCompileConfig"]):
        raise RuntimeError("`compile_cfgs` cannot be set directly. Please use `set_compile_cfgs` instead.")

    def set_compile_cfgs(self, compile_cfgs: List["RBLNCompileConfig"]):
        """
        Replace the compile configurations and copy `npu` / `tensor_parallel_size` onto each.

        Raises:
            ValueError: If `compile_cfgs` is not a non-empty list whose first entry is an
                `RBLNCompileConfig`.
        """
        if not isinstance(compile_cfgs, list):
            raise ValueError("`compile_cfgs` must be a list of `RBLNCompileConfig`.")
        if len(compile_cfgs) == 0:
            raise ValueError("`compile_cfgs` must contain at least one `RBLNCompileConfig`.")
        if not isinstance(compile_cfgs[0], RBLNCompileConfig):
            raise ValueError("`compile_cfgs` must contain only `RBLNCompileConfig`.")

        self._compile_cfgs = compile_cfgs
        for compile_cfg in self._compile_cfgs:
            compile_cfg.npu = self.npu
            compile_cfg.tensor_parallel_size = self.tensor_parallel_size

    def freeze(self):
        """
        Mark the config as frozen so that attributes can no longer be added or changed.

        Raises:
            RuntimeError: If already frozen, or if compile configurations have not been set.
            ValueError: If a declared submodule is not an RBLNModelConfig or is not frozen yet.
        """
        if self._frozen:
            raise RuntimeError(f"`{self.__class__.__name__}` is already frozen.")

        if (
            not isinstance(self._compile_cfgs, list)
            or len(self._compile_cfgs) == 0
            or not all(isinstance(cfg, RBLNCompileConfig) for cfg in self._compile_cfgs)
        ):
            raise RuntimeError("`compile_cfgs` must be set before freezing.")

        for submodule_name in self.submodules:
            submodule_config = getattr(self, submodule_name, None)
            if not isinstance(submodule_config, RBLNModelConfig):
                raise ValueError(f"`{submodule_name}` must be an instance of `RBLNModelConfig` before freezing.")

            if not submodule_config.is_frozen():
                raise ValueError(f"`{submodule_name}` config must be frozen before freezing super config.")

        self._frozen = True

    def is_frozen(self):
        """Return whether `freeze` has been called on this config."""
        return self._frozen

    def save(self, path: str):
        """
        Write the serializable attributes as JSON.

        Args:
            path (str): Target file, or an existing directory in which `rbln_config.json` is written.

        Raises:
            RuntimeError: If the config is not frozen.
        """
        if not self._frozen:
            raise RuntimeError("`RBLNModelConfig` is not frozen. Please call `set_compile_cfgs` first.")

        # save as json file without runtime attributes
        path = Path(path)
        if path.is_dir():
            path = path / "rbln_config.json"

        with open(path, "w") as jsonf:
            serializable_data = self._prepare_for_serialization()
            json.dump(serializable_data, jsonf, indent=2)

    @classmethod
    def load(cls, path: str, **kwargs) -> "RBLNModelConfig":
        """
        Load a RBLNModelConfig from a path.

        Args:
            path (str): Path to the RBLNModelConfig file or directory containing the config file.
            **kwargs: Additional keyword arguments to override configuration values.
                Keys starting with 'rbln_' will have the prefix removed and be used
                to update the configuration.

        Returns:
            RBLNModelConfig: The loaded configuration instance.

        Note:
            This method loads the configuration from the specified path and applies any
            provided overrides. If the loaded configuration class doesn't match the expected
            class, a warning will be logged.
        """
        cls_reserved, config_file = load_config(path)

        if cls_reserved != cls:
            logger.warning(f"Expected {cls.__name__}, but got {cls_reserved.__name__}.")

        rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
        rbln_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys}
        config_file.update(rbln_kwargs)

        return cls(**config_file)

    @classmethod
    def initialize_from_kwargs(
        cls: Type["RBLNModelConfig"],
        rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
        **kwargs,
    ) -> Tuple["RBLNModelConfig", Dict[str, Any]]:
        """
        Initialize RBLNModelConfig from kwargs.

        `rbln_`-prefixed keyword arguments are stripped of the prefix and applied on top of
        `rbln_config`. A dict passed as `rbln_config` is not modified.

        Returns:
            Tuple[RBLNModelConfig, Dict[str, Any]]: The config and the keyword arguments
            that did not start with `rbln_`.
        """
        kwargs_keys = list(kwargs.keys())
        rbln_kwargs = {key[5:]: kwargs.pop(key) for key in kwargs_keys if key.startswith("rbln_")}

        if isinstance(rbln_config, dict):
            # Merge into a new dict so that the caller's dictionary stays untouched.
            rbln_config = cls(**{**rbln_config, **rbln_kwargs})

        elif rbln_config is None:
            rbln_config = cls(**rbln_kwargs)

        elif isinstance(rbln_config, RBLNModelConfig):
            for key, value in rbln_kwargs.items():
                setattr(rbln_config, key, value)

        return rbln_config, kwargs

    # Runtime options: each getter prefers the value from the current ContextRblnConfig
    # context and otherwise falls back to the value stored in `_runtime_options`.

    @property
    def create_runtimes(self):
        context = ContextRblnConfig.get_current_context()["create_runtimes"]
        if context is not None:
            return context
        elif self._runtime_options["create_runtimes"] is None:
            # Unset: follow NPU availability.
            return rebel.npu_is_available()
        return self._runtime_options["create_runtimes"]

    @create_runtimes.setter
    def create_runtimes(self, create_runtimes: bool):
        self._runtime_options["create_runtimes"] = create_runtimes

    @property
    def optimize_host_memory(self):
        context = ContextRblnConfig.get_current_context()["optimize_host_memory"]
        if context is not None:
            return context
        elif self._runtime_options["optimize_host_memory"] is None:
            return True
        return self._runtime_options["optimize_host_memory"]

    @optimize_host_memory.setter
    def optimize_host_memory(self, optimize_host_memory: bool):
        self._runtime_options["optimize_host_memory"] = optimize_host_memory

    @property
    def device(self):
        context = ContextRblnConfig.get_current_context()["device"]
        if context is not None:
            return context
        return self._runtime_options["device"]

    @device.setter
    def device(self, device: Union[int, List[int]]):
        self._runtime_options["device"] = device

    @property
    def device_map(self):
        context = ContextRblnConfig.get_current_context()["device_map"]
        if context:
            return context
        elif self._runtime_options["device_map"] is None:
            # Unset: map every compiled model name to `self.device`.
            rbln_device_map = {}
            device_val = self.device
            for cfg in self.compile_cfgs:
                rbln_device_map[cfg.compiled_model_name] = device_val
            return rbln_device_map
        return self._runtime_options["device_map"]

    @device_map.setter
    def device_map(self, device_map: Dict[str, Union[int, List[int]]]):
        self._runtime_options["device_map"] = device_map

    @property
    def activate_profiler(self):
        context = ContextRblnConfig.get_current_context()["activate_profiler"]
        if context is not None:
            return context
        return self._runtime_options["activate_profiler"]

    @activate_profiler.setter
    def activate_profiler(self, activate_profiler: bool):
        self._runtime_options["activate_profiler"] = activate_profiler