optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. optimum/rbln/__init__.py +44 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +230 -67
  4. optimum/rbln/diffusers/models/controlnet.py +2 -2
  5. optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
  6. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
  7. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
  8. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  13. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
  14. optimum/rbln/modeling_base.py +11 -10
  15. optimum/rbln/ops/__init__.py +1 -0
  16. optimum/rbln/ops/attn.py +10 -0
  17. optimum/rbln/ops/flash_attn.py +8 -0
  18. optimum/rbln/ops/moe.py +180 -0
  19. optimum/rbln/ops/sliding_window_attn.py +9 -0
  20. optimum/rbln/transformers/__init__.py +44 -0
  21. optimum/rbln/transformers/modeling_attention_utils.py +124 -222
  22. optimum/rbln/transformers/modeling_outputs.py +25 -0
  23. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  24. optimum/rbln/transformers/models/__init__.py +38 -0
  25. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  27. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
  28. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
  29. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  30. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  31. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  32. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
  33. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  34. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  35. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
  36. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  37. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
  38. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  39. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
  40. optimum/rbln/transformers/models/detr/__init__.py +23 -0
  41. optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
  42. optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
  43. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  44. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  45. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  46. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  47. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  48. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  49. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
  50. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  51. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  53. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  54. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
  55. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  56. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
  57. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  58. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  59. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  60. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  61. optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
  62. optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
  63. optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
  64. optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
  65. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  66. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  67. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  68. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  69. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  70. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  71. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  76. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  77. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  78. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  79. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  80. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  81. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
  82. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  83. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  85. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  86. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  87. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  88. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  89. optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
  90. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  91. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  92. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  94. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  95. optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
  96. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  97. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  98. optimum/rbln/utils/deprecation.py +78 -1
  99. optimum/rbln/utils/hub.py +93 -2
  100. optimum/rbln/utils/import_utils.py +16 -1
  101. optimum/rbln/utils/runtime_utils.py +12 -8
  102. optimum/rbln/utils/submodule.py +24 -0
  103. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
  104. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
  105. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  106. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
  107. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
  108. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -86,11 +86,17 @@ _import_structure = {
86
86
  "RBLNDPTForDepthEstimationConfig",
87
87
  "RBLNDepthAnythingForDepthEstimationConfig",
88
88
  "RBLNDepthAnythingForDepthEstimation",
89
+ "RBLNDetrForObjectDetection",
90
+ "RBLNDetrForObjectDetectionConfig",
89
91
  "RBLNExaoneForCausalLM",
90
92
  "RBLNExaoneForCausalLMConfig",
91
93
  "RBLNGemmaModel",
92
94
  "RBLNGemmaModelConfig",
93
95
  "RBLNGemmaForCausalLM",
96
+ "RBLNGemma2ForCausalLM",
97
+ "RBLNGemma2ForCausalLMConfig",
98
+ "RBLNGemma2Model",
99
+ "RBLNGemma2ModelConfig",
94
100
  "RBLNGemmaForCausalLMConfig",
95
101
  "RBLNGemma3ForCausalLM",
96
102
  "RBLNGemma3ForCausalLMConfig",
@@ -100,6 +106,8 @@ _import_structure = {
100
106
  "RBLNGPT2ModelConfig",
101
107
  "RBLNGPT2LMHeadModel",
102
108
  "RBLNGPT2LMHeadModelConfig",
109
+ "RBLNGptOssForCausalLM",
110
+ "RBLNGptOssForCausalLMConfig",
103
111
  "RBLNGroundingDinoDecoder",
104
112
  "RBLNGroundingDinoDecoderConfig",
105
113
  "RBLNGroundingDinoForObjectDetection",
@@ -114,6 +122,8 @@ _import_structure = {
114
122
  "RBLNLlamaForCausalLMConfig",
115
123
  "RBLNLlamaModel",
116
124
  "RBLNLlamaModelConfig",
125
+ "RBLNMixtralForCausalLM",
126
+ "RBLNMixtralForCausalLMConfig",
117
127
  "RBLNOPTForCausalLM",
118
128
  "RBLNOPTForCausalLMConfig",
119
129
  "RBLNLlavaForConditionalGeneration",
@@ -140,14 +150,24 @@ _import_structure = {
140
150
  "RBLNPixtralVisionModelConfig",
141
151
  "RBLNPhiModel",
142
152
  "RBLNPhiModelConfig",
153
+ "RBLNPaliGemmaForConditionalGeneration",
154
+ "RBLNPaliGemmaForConditionalGenerationConfig",
155
+ "RBLNPaliGemmaModel",
156
+ "RBLNPaliGemmaModelConfig",
143
157
  "RBLNQwen2ForCausalLM",
144
158
  "RBLNQwen2ForCausalLMConfig",
145
159
  "RBLNQwen2_5_VisionTransformerPretrainedModel",
146
160
  "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
147
161
  "RBLNQwen2_5_VLForConditionalGeneration",
148
162
  "RBLNQwen2_5_VLForConditionalGenerationConfig",
163
+ "RBLNQwen3MoeForCausalLM",
164
+ "RBLNQwen3MoeForCausalLMConfig",
165
+ "RBLNQwen2_5_VLModel",
166
+ "RBLNQwen2_5_VLModelConfig",
149
167
  "RBLNQwen2Model",
150
168
  "RBLNQwen2ModelConfig",
169
+ "RBLNQwen2MoeForCausalLM",
170
+ "RBLNQwen2MoeForCausalLMConfig",
151
171
  "RBLNQwen3ForCausalLM",
152
172
  "RBLNQwen3ForCausalLMConfig",
153
173
  "RBLNQwen3Model",
@@ -156,6 +176,8 @@ _import_structure = {
156
176
  "RBLNQwen2VisionTransformerPretrainedModelConfig",
157
177
  "RBLNQwen2VLForConditionalGeneration",
158
178
  "RBLNQwen2VLForConditionalGenerationConfig",
179
+ "RBLNQwen2VLModel",
180
+ "RBLNQwen2VLModelConfig",
159
181
  "RBLNResNetForImageClassification",
160
182
  "RBLNResNetForImageClassificationConfig",
161
183
  "RBLNRobertaForMaskedLM",
@@ -388,12 +410,18 @@ if TYPE_CHECKING:
388
410
  RBLNDecoderOnlyModelForCausalLMConfig,
389
411
  RBLNDepthAnythingForDepthEstimation,
390
412
  RBLNDepthAnythingForDepthEstimationConfig,
413
+ RBLNDetrForObjectDetection,
414
+ RBLNDetrForObjectDetectionConfig,
391
415
  RBLNDistilBertForQuestionAnswering,
392
416
  RBLNDistilBertForQuestionAnsweringConfig,
393
417
  RBLNDPTForDepthEstimation,
394
418
  RBLNDPTForDepthEstimationConfig,
395
419
  RBLNExaoneForCausalLM,
396
420
  RBLNExaoneForCausalLMConfig,
421
+ RBLNGemma2ForCausalLM,
422
+ RBLNGemma2ForCausalLMConfig,
423
+ RBLNGemma2Model,
424
+ RBLNGemma2ModelConfig,
397
425
  RBLNGemma3ForCausalLM,
398
426
  RBLNGemma3ForCausalLMConfig,
399
427
  RBLNGemma3ForConditionalGeneration,
@@ -406,6 +434,8 @@ if TYPE_CHECKING:
406
434
  RBLNGPT2LMHeadModelConfig,
407
435
  RBLNGPT2Model,
408
436
  RBLNGPT2ModelConfig,
437
+ RBLNGptOssForCausalLM,
438
+ RBLNGptOssForCausalLMConfig,
409
439
  RBLNGroundingDinoDecoder,
410
440
  RBLNGroundingDinoDecoderConfig,
411
441
  RBLNGroundingDinoEncoder,
@@ -432,10 +462,16 @@ if TYPE_CHECKING:
432
462
  RBLNMistralForCausalLMConfig,
433
463
  RBLNMistralModel,
434
464
  RBLNMistralModelConfig,
465
+ RBLNMixtralForCausalLM,
466
+ RBLNMixtralForCausalLMConfig,
435
467
  RBLNOPTForCausalLM,
436
468
  RBLNOPTForCausalLMConfig,
437
469
  RBLNOPTModel,
438
470
  RBLNOPTModelConfig,
471
+ RBLNPaliGemmaForConditionalGeneration,
472
+ RBLNPaliGemmaForConditionalGenerationConfig,
473
+ RBLNPaliGemmaModel,
474
+ RBLNPaliGemmaModelConfig,
439
475
  RBLNPegasusForConditionalGeneration,
440
476
  RBLNPegasusForConditionalGenerationConfig,
441
477
  RBLNPegasusModel,
@@ -450,18 +486,26 @@ if TYPE_CHECKING:
450
486
  RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
451
487
  RBLNQwen2_5_VLForConditionalGeneration,
452
488
  RBLNQwen2_5_VLForConditionalGenerationConfig,
489
+ RBLNQwen2_5_VLModel,
490
+ RBLNQwen2_5_VLModelConfig,
453
491
  RBLNQwen2ForCausalLM,
454
492
  RBLNQwen2ForCausalLMConfig,
455
493
  RBLNQwen2Model,
456
494
  RBLNQwen2ModelConfig,
495
+ RBLNQwen2MoeForCausalLM,
496
+ RBLNQwen2MoeForCausalLMConfig,
457
497
  RBLNQwen2VisionTransformerPretrainedModel,
458
498
  RBLNQwen2VisionTransformerPretrainedModelConfig,
459
499
  RBLNQwen2VLForConditionalGeneration,
460
500
  RBLNQwen2VLForConditionalGenerationConfig,
501
+ RBLNQwen2VLModel,
502
+ RBLNQwen2VLModelConfig,
461
503
  RBLNQwen3ForCausalLM,
462
504
  RBLNQwen3ForCausalLMConfig,
463
505
  RBLNQwen3Model,
464
506
  RBLNQwen3ModelConfig,
507
+ RBLNQwen3MoeForCausalLM,
508
+ RBLNQwen3MoeForCausalLMConfig,
465
509
  RBLNResNetForImageClassification,
466
510
  RBLNResNetForImageClassificationConfig,
467
511
  RBLNRobertaForMaskedLM,
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.9.4a2'
32
- __version_tuple__ = version_tuple = (0, 9, 4, 'a2')
31
+ __version__ = version = '0.10.0.post1'
32
+ __version_tuple__ = version_tuple = (0, 10, 0, 'post1')
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -24,7 +24,7 @@ import torch
24
24
  from packaging.version import Version
25
25
 
26
26
  from .__version__ import __version__
27
- from .utils.deprecation import warn_deprecated_npu
27
+ from .utils.deprecation import deprecate_kwarg, deprecate_method, warn_deprecated_npu
28
28
  from .utils.logging import get_logger
29
29
  from .utils.runtime_utils import ContextRblnConfig
30
30
 
@@ -36,6 +36,30 @@ DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
36
36
  TypeInputInfo = List[Tuple[str, Tuple[int], str]]
37
37
 
38
38
 
39
+ def nested_update(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
40
+ """
41
+ Recursively merge override dict into base dict.
42
+ For nested dicts, values are merged recursively instead of being replaced.
43
+ For non-dict values, override takes precedence.
44
+ Args:
45
+ base: The base dictionary to merge into (modified in-place).
46
+ override: The dictionary with values to merge.
47
+ Returns:
48
+ The merged base dictionary.
49
+ Example:
50
+ >>> base = {"a": 1, "nested": {"x": 10, "y": 20}}
51
+ >>> override = {"b": 2, "nested": {"y": 30, "z": 40}}
52
+ >>> nested_update(base, override)
53
+ {"a": 1, "b": 2, "nested": {"x": 10, "y": 30, "z": 40}}
54
+ """
55
+ for key, value in override.items():
56
+ if key in base and isinstance(base[key], dict) and isinstance(value, dict):
57
+ nested_update(base[key], value)
58
+ else:
59
+ base[key] = value
60
+ return base
61
+
62
+
39
63
  @runtime_checkable
40
64
  class RBLNSerializableConfigProtocol(Protocol):
41
65
  def _prepare_for_serialization(self) -> Dict[str, Any]: ...
@@ -92,7 +116,7 @@ class RBLNCompileConfig:
92
116
  and isinstance(item[0], str) # name
93
117
  and isinstance(item[1], (tuple, list)) # shape
94
118
  and all(isinstance(x, int) for x in item[1])
95
- and isinstance(item[2], str) # dtype
119
+ and (isinstance(item[2], str) or isinstance(item[2], torch.dtype)) # dtype
96
120
  for item in input_info
97
121
  )
98
122
 
@@ -216,8 +240,7 @@ class RBLNAutoConfig:
216
240
  For example, the parsed contents of `rbln_config.json`.
217
241
 
218
242
  Returns:
219
- RBLNModelConfig: A configuration instance. The specific subclass is
220
- selected by `config_dict["cls_name"]`.
243
+ RBLNModelConfig: A configuration instance. The specific subclass is selected by `config_dict["cls_name"]`.
221
244
 
222
245
  Raises:
223
246
  ValueError: If `cls_name` is missing.
@@ -256,12 +279,13 @@ class RBLNAutoConfig:
256
279
 
257
280
  CONFIG_MAPPING[config.__name__] = config
258
281
 
259
- @staticmethod
260
- def load(
282
+ @classmethod
283
+ def from_pretrained(
284
+ cls,
261
285
  path: str,
262
- passed_rbln_config: Optional["RBLNModelConfig"] = None,
263
- kwargs: Optional[Dict[str, Any]] = None,
286
+ rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
264
287
  return_unused_kwargs: bool = False,
288
+ **kwargs: Optional[Dict[str, Any]],
265
289
  ) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
266
290
  """
267
291
  Load RBLNModelConfig from a path.
@@ -269,53 +293,58 @@ class RBLNAutoConfig:
269
293
 
270
294
  Args:
271
295
  path (str): Path to the RBLNModelConfig.
272
- passed_rbln_config (Optional["RBLNModelConfig"]): RBLNModelConfig to pass its runtime options.
296
+ rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
297
+ return_unused_kwargs (bool): Whether to return unused kwargs.
298
+ kwargs: Additional keyword arguments to override configuration values.
273
299
 
274
300
  Returns:
275
301
  RBLNModelConfig: The loaded RBLNModelConfig.
276
- """
277
- if kwargs is None:
278
- kwargs = {}
279
- cls, config_file = load_config(path)
280
-
281
- rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
282
- rbln_runtime_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in RUNTIME_KEYWORDS}
283
- rbln_submodule_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in cls.submodules}
284
302
 
285
- rbln_kwargs = {
286
- key[5:]: kwargs.pop(key)
287
- for key in rbln_keys
288
- if key[5:] not in RUNTIME_KEYWORDS and key[5:] not in cls.submodules
289
- }
290
-
291
- # Process submodule's rbln_config
292
- for submodule in cls.submodules:
293
- if submodule not in config_file:
294
- raise ValueError(f"Submodule {submodule} not found in rbln_config.json.")
295
- submodule_config = config_file[submodule]
296
- submodule_config.update(rbln_submodule_kwargs.pop(submodule, {}))
297
- config_file[submodule] = RBLNAutoConfig.load_from_dict(submodule_config)
303
+ Examples:
304
+ ```python
305
+ config = RBLNAutoConfig.from_pretrained("/path/to/model")
306
+ ```
307
+ """
308
+ target_cls, _ = load_config(path)
309
+ return target_cls.from_pretrained(
310
+ path, rbln_config=rbln_config, return_unused_kwargs=return_unused_kwargs, **kwargs
311
+ )
298
312
 
299
- if passed_rbln_config is not None:
300
- config_file.update(passed_rbln_config._runtime_options)
301
- # TODO(jongho): Reject if the passed_rbln_config has different attributes from the config_file
313
+ @classmethod
314
+ @deprecate_method(version="0.11.0", new_method="from_pretrained")
315
+ def load(
316
+ cls,
317
+ path: str,
318
+ rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
319
+ return_unused_kwargs: bool = False,
320
+ **kwargs: Optional[Dict[str, Any]],
321
+ ) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
322
+ """
323
+ Load RBLNModelConfig from a path.
324
+ Class name is automatically inferred from the `rbln_config.json` file.
302
325
 
303
- config_file.update(rbln_runtime_kwargs)
326
+ Deprecated:
327
+ This method is deprecated and will be removed in version 0.11.0.
328
+ Use `from_pretrained` instead.
304
329
 
305
- rbln_config = cls(**config_file)
330
+ Args:
331
+ path (str): Path to the RBLNModelConfig file or directory.
332
+ rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
333
+ return_unused_kwargs (bool): Whether to return unused kwargs.
334
+ kwargs: Additional keyword arguments to override configuration values.
306
335
 
307
- if len(rbln_kwargs) > 0:
308
- for key, value in rbln_kwargs.items():
309
- if getattr(rbln_config, key) != value:
310
- raise ValueError(
311
- f"Cannot set the following arguments: {list(rbln_kwargs.keys())} "
312
- f"Since the value is already set to {getattr(rbln_config, key)}"
313
- )
336
+ Returns:
337
+ RBLNModelConfig: The loaded RBLNModelConfig.
314
338
 
315
- if return_unused_kwargs:
316
- return cls(**config_file), kwargs
317
- else:
318
- return cls(**config_file)
339
+ Examples:
340
+ ```python
341
+ # Deprecated usage:
342
+ config = RBLNAutoConfig.load("/path/to/model")
343
+ # Recommended usage:
344
+ config = RBLNAutoConfig.from_pretrained("/path/to/model")
345
+ ```
346
+ """
347
+ return cls.from_pretrained(path, rbln_config=rbln_config, return_unused_kwargs=return_unused_kwargs, **kwargs)
319
348
 
320
349
 
321
350
  class RBLNModelConfig(RBLNSerializableConfigProtocol):
@@ -524,8 +553,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
524
553
  non_save_attributes = [
525
554
  "_frozen",
526
555
  "_runtime_options",
527
- "torch_dtype",
528
556
  "npu",
557
+ "dtype",
529
558
  "tensor_parallel_size",
530
559
  "create_runtimes",
531
560
  "device",
@@ -650,6 +679,14 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
650
679
 
651
680
  super().__setattr__(key, value)
652
681
 
682
+ @deprecate_kwarg(
683
+ old_name="_torch_dtype",
684
+ new_name="dtype",
685
+ version="0.12.0",
686
+ deprecated_type=torch.dtype,
687
+ value_replacer=RBLNCompileConfig.normalize_dtype,
688
+ raise_if_greater_or_equal_version=False,
689
+ )
653
690
  def __init__(
654
691
  self,
655
692
  cls_name: Optional[str] = None,
@@ -661,7 +698,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
661
698
  tensor_parallel_size: Optional[int] = None,
662
699
  timeout: Optional[int] = None,
663
700
  optimum_rbln_version: Optional[str] = None,
664
- _torch_dtype: Optional[str] = None,
701
+ dtype: Optional[Union[str, torch.dtype]] = None,
665
702
  _compile_cfgs: Optional[List[RBLNCompileConfig]] = None,
666
703
  *,
667
704
  optimize_host_memory: Optional[bool] = None,
@@ -680,7 +717,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
680
717
  tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
681
718
  timeout (Optional[int]): The timeout for the runtime in seconds. If it isn't provided, it will be set to 60 by default.
682
719
  optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
683
- _torch_dtype (Optional[str]): The data type to use for the model.
720
+ dtype (Optional[Union[str, torch.dtype]]): The data type to use for the model.
684
721
  _compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
685
722
  kwargs: Additional keyword arguments.
686
723
 
@@ -710,7 +747,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
710
747
  self.npu = npu
711
748
  self.tensor_parallel_size = tensor_parallel_size
712
749
 
713
- self._torch_dtype = _torch_dtype or "float32"
750
+ if dtype is not None and isinstance(dtype, torch.dtype):
751
+ dtype = RBLNCompileConfig.normalize_dtype(dtype)
752
+ self._dtype = dtype or "float32"
714
753
  self.optimum_rbln_version = optimum_rbln_version
715
754
  if self.optimum_rbln_version is None:
716
755
  self.optimum_rbln_version = __version__
@@ -743,14 +782,24 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
743
782
 
744
783
  @property
745
784
  def torch_dtype(self):
746
- return getattr(torch, self._torch_dtype)
785
+ logger.warning_once("`torch_dtype` is deprecated. Use `dtype` instead.")
786
+ return self.dtype
747
787
 
748
788
  @torch_dtype.setter
749
789
  def torch_dtype(self, torch_dtype: Union[str, torch.dtype]):
750
- if isinstance(torch_dtype, torch.dtype):
751
- torch_dtype = RBLNCompileConfig.normalize_dtype(torch_dtype)
790
+ logger.warning_once("`torch_dtype` is deprecated. Use `dtype` instead.")
791
+ self.dtype = torch_dtype
792
+
793
+ @property
794
+ def dtype(self):
795
+ return getattr(torch, self._dtype)
752
796
 
753
- self._torch_dtype = torch_dtype
797
+ @dtype.setter
798
+ def dtype(self, dtype: Union[str, torch.dtype]):
799
+ if isinstance(dtype, torch.dtype):
800
+ dtype = RBLNCompileConfig.normalize_dtype(dtype)
801
+
802
+ self._dtype = dtype
754
803
 
755
804
  @property
756
805
  def rbln_model_cls_name(self) -> str:
@@ -774,10 +823,15 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
774
823
  if isinstance(value, RBLNSerializableConfigProtocol):
775
824
  # Convert nested RBLNModelConfig to its serializable form
776
825
  serializable_map[key] = value._prepare_for_serialization()
826
+ elif key == "_dtype":
827
+ serializable_map["dtype"] = value
828
+ elif isinstance(value, list) and all(isinstance(item, RBLNSerializableConfigProtocol) for item in value):
829
+ serializable_map[key] = [item._prepare_for_serialization() for item in value]
777
830
  elif key == "_compile_cfgs":
778
831
  serializable_map[key] = [cfg.asdict() for cfg in value]
779
832
  else:
780
833
  serializable_map[key] = value
834
+
781
835
  return serializable_map
782
836
 
783
837
  def __repr__(self):
@@ -825,18 +879,12 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
825
879
  if not isinstance(submodule_config, RBLNModelConfig):
826
880
  raise ValueError(f"`{submodule_name}` must be an instance of `RBLNModelConfig` before freezing.")
827
881
 
828
- if not submodule_config.is_frozen():
829
- raise ValueError(f"`{submodule_name}` config must be frozen before freezing super config.")
830
-
831
882
  self._frozen = True
832
883
 
833
884
  def is_frozen(self):
834
885
  return self._frozen
835
886
 
836
887
  def save(self, path: str):
837
- if not self._frozen:
838
- raise RuntimeError("`RBLNModelConfig` is not frozen. Please call `set_compile_cfgs` first.")
839
-
840
888
  # save as json file without runtime attributes
841
889
  path = Path(path)
842
890
  if path.is_dir():
@@ -847,15 +895,23 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
847
895
  json.dump(serializable_data, jsonf, indent=2)
848
896
 
849
897
  @classmethod
850
- def load(cls, path: str, **kwargs: Any) -> "RBLNModelConfig":
898
+ def from_pretrained(
899
+ cls,
900
+ path: str,
901
+ rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
902
+ return_unused_kwargs: bool = False,
903
+ **kwargs: Optional[Dict[str, Any]],
904
+ ) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
851
905
  """
852
906
  Load a RBLNModelConfig from a path.
853
907
 
854
908
  Args:
855
909
  path (str): Path to the RBLNModelConfig file or directory containing the config file.
910
+ rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
911
+ return_unused_kwargs (bool): Whether to return unused kwargs.
856
912
  kwargs: Additional keyword arguments to override configuration values.
857
- Keys starting with 'rbln_' will have the prefix removed and be used
858
- to update the configuration.
913
+ Keys starting with 'rbln_' will have the prefix removed and be used
914
+ to update the configuration.
859
915
 
860
916
  Returns:
861
917
  RBLNModelConfig: The loaded configuration instance.
@@ -864,17 +920,109 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
864
920
  This method loads the configuration from the specified path and applies any
865
921
  provided overrides. If the loaded configuration class doesn't match the expected
866
922
  class, a warning will be logged.
923
+
924
+ Examples:
925
+ ```python
926
+ config = RBLNResNetForImageClassificationConfig.from_pretrained("/path/to/model")
927
+ ```
867
928
  """
868
929
  cls_reserved, config_file = load_config(path)
869
-
870
930
  if cls_reserved != cls:
871
931
  logger.warning(f"Expected {cls.__name__}, but got {cls_reserved.__name__}.")
872
932
 
933
+ if isinstance(rbln_config, dict):
934
+ for key, value in rbln_config.items():
935
+ if key not in kwargs:
936
+ kwargs[f"rbln_{key}"] = value
937
+
873
938
  rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
874
- rbln_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys}
875
- config_file.update(rbln_kwargs)
939
+ rbln_runtime_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in RUNTIME_KEYWORDS}
940
+ rbln_submodule_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in cls.submodules}
941
+
942
+ rbln_kwargs = {
943
+ key[5:]: kwargs.pop(key)
944
+ for key in rbln_keys
945
+ if key[5:] not in RUNTIME_KEYWORDS and key[5:] not in cls.submodules
946
+ }
947
+
948
+ # Process submodule's rbln_config
949
+ for submodule in cls.submodules:
950
+ if submodule not in config_file:
951
+ raise ValueError(f"Submodule {submodule} not found in rbln_config.json.")
952
+ submodule_config = config_file[submodule]
953
+ submodule_config.update(rbln_runtime_kwargs)
954
+
955
+ update_dict = rbln_submodule_kwargs.pop(submodule, {})
956
+ if update_dict:
957
+ nested_update(submodule_config, update_dict)
958
+ config_file[submodule] = RBLNAutoConfig.load_from_dict(submodule_config)
959
+
960
+ if isinstance(rbln_config, RBLNModelConfig):
961
+ config_file.update(rbln_config._runtime_options)
962
+
963
+ # update submodule runtime
964
+ for submodule in rbln_config.submodules:
965
+ if str(config_file[submodule]) != str(getattr(rbln_config, submodule)):
966
+ raise ValueError(
967
+ f"Passed rbln_config has different attributes for submodule {submodule} than the config_file"
968
+ )
969
+ config_file[submodule] = getattr(rbln_config, submodule)
970
+
971
+ config_file.update(rbln_runtime_kwargs)
972
+ rbln_config = cls(**config_file)
973
+ if len(rbln_kwargs) > 0:
974
+ for key, value in rbln_kwargs.items():
975
+ if getattr(rbln_config, key) != value:
976
+ raise ValueError(
977
+ f"Cannot set the following arguments: {list(rbln_kwargs.keys())} "
978
+ f"Since the value is already set to {getattr(rbln_config, key)}"
979
+ )
980
+ if return_unused_kwargs:
981
+ return rbln_config, kwargs
982
+ else:
983
+ return rbln_config
876
984
 
877
- return cls(**config_file)
985
+ @classmethod
986
+ @deprecate_method(version="0.11.0", new_method="from_pretrained")
987
+ def load(
988
+ cls,
989
+ path: str,
990
+ rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
991
+ return_unused_kwargs: bool = False,
992
+ **kwargs: Optional[Dict[str, Any]],
993
+ ) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
994
+ """
995
+ Load a RBLNModelConfig from a path.
996
+
997
+ Deprecated:
998
+ This method is deprecated and will be removed in version 0.11.0.
999
+ Use `from_pretrained` instead.
1000
+
1001
+ Args:
1002
+ path (str): Path to the RBLNModelConfig file or directory containing the config file.
1003
+ rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
1004
+ return_unused_kwargs (bool): Whether to return unused kwargs.
1005
+ kwargs: Additional keyword arguments to override configuration values.
1006
+ Keys starting with 'rbln_' will have the prefix removed and be used
1007
+ to update the configuration.
1008
+
1009
+ Returns:
1010
+ RBLNModelConfig: The loaded configuration instance.
1011
+
1012
+ Note:
1013
+ This method loads the configuration from the specified path and applies any
1014
+ provided overrides. If the loaded configuration class doesn't match the expected
1015
+ class, a warning will be logged.
1016
+
1017
+ Examples:
1018
+ ```python
1019
+ # Deprecated usage:
1020
+ config = RBLNResNetForImageClassificationConfig.load("/path/to/model")
1021
+ # Recommended usage:
1022
+ config = RBLNResNetForImageClassificationConfig.from_pretrained("/path/to/model")
1023
+ ```
1024
+ """
1025
+ return cls.from_pretrained(path, rbln_config=rbln_config, return_unused_kwargs=return_unused_kwargs, **kwargs)
878
1026
 
879
1027
  @classmethod
880
1028
  def initialize_from_kwargs(
@@ -974,3 +1122,18 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
974
1122
  @timeout.setter
975
1123
  def timeout(self, timeout: int):
976
1124
  self._runtime_options["timeout"] = timeout
1125
+
1126
+
1127
+ def convert_rbln_config_dict(
1128
+ rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None, **kwargs
1129
+ ) -> Tuple[Optional[Union[Dict[str, Any], RBLNModelConfig]], Dict[str, Any]]:
1130
+ # Validate and merge rbln_ prefixed kwargs into rbln_config
1131
+ kwargs_keys = list(kwargs.keys())
1132
+ rbln_kwargs = {key[5:]: kwargs.pop(key) for key in kwargs_keys if key.startswith("rbln_")}
1133
+
1134
+ rbln_config = {} if rbln_config is None else rbln_config
1135
+
1136
+ if isinstance(rbln_config, dict) and len(rbln_kwargs) > 0:
1137
+ rbln_config.update(rbln_kwargs)
1138
+
1139
+ return rbln_config, kwargs
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import TYPE_CHECKING, Dict, Optional, Union
15
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
16
16
 
17
17
  import torch
18
18
  from diffusers import ControlNetModel
@@ -218,7 +218,7 @@ class RBLNControlNetModel(RBLNModel):
218
218
  added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
219
219
  return_dict: bool = True,
220
220
  **kwargs,
221
- ):
221
+ ) -> Union[ControlNetOutput, Tuple]:
222
222
  """
223
223
  Forward pass for the RBLN-optimized ControlNetModel.
224
224
 
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from pathlib import Path
16
- from typing import TYPE_CHECKING, Optional, Union
16
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
17
17
 
18
18
  import torch
19
19
  from diffusers.models.transformers.prior_transformer import PriorTransformer, PriorTransformerOutput
@@ -134,7 +134,7 @@ class RBLNPriorTransformer(RBLNModel):
134
134
  encoder_hidden_states: Optional[torch.Tensor] = None,
135
135
  attention_mask: Optional[torch.Tensor] = None,
136
136
  return_dict: bool = True,
137
- ):
137
+ ) -> Union[PriorTransformerOutput, Tuple]:
138
138
  """
139
139
  Forward pass for the RBLN-optimized PriorTransformer.
140
140
 
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from pathlib import Path
16
- from typing import TYPE_CHECKING, List, Optional, Union
16
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
17
17
 
18
18
  import rebel
19
19
  import torch
@@ -302,7 +302,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
302
302
  condition_mask: Optional[torch.Tensor] = None,
303
303
  padding_mask: Optional[torch.Tensor] = None,
304
304
  return_dict: bool = True,
305
- ):
305
+ ) -> Union[Transformer2DModelOutput, Tuple]:
306
306
  """
307
307
  Forward pass for the RBLN-optimized CosmosTransformer3DModel.
308
308
 
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
15
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
16
16
 
17
17
  import torch
18
18
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
@@ -160,7 +160,7 @@ class RBLNSD3Transformer2DModel(RBLNModel):
160
160
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
161
161
  return_dict: bool = True,
162
162
  **kwargs,
163
- ):
163
+ ) -> Union[Transformer2DModelOutput, Tuple]:
164
164
  """
165
165
  Forward pass for the RBLN-optimized SD3Transformer2DModel.
166
166