optimum-rbln 0.9.5a4__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +196 -52
  4. optimum/rbln/diffusers/models/controlnet.py +2 -2
  5. optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
  6. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
  7. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
  8. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  13. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
  14. optimum/rbln/modeling_base.py +5 -4
  15. optimum/rbln/transformers/__init__.py +8 -0
  16. optimum/rbln/transformers/modeling_attention_utils.py +15 -9
  17. optimum/rbln/transformers/models/__init__.py +10 -0
  18. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  19. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
  20. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
  21. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +2 -2
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +26 -1
  23. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +2 -1
  24. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +45 -21
  25. optimum/rbln/transformers/models/detr/__init__.py +23 -0
  26. optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
  27. optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
  28. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
  29. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +4 -176
  30. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +4 -3
  31. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +10 -7
  32. optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
  33. optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
  34. optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
  35. optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
  36. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +7 -7
  37. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
  38. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +2 -0
  39. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +2 -0
  40. optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
  41. optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
  42. optimum/rbln/utils/deprecation.py +78 -1
  43. optimum/rbln/utils/hub.py +93 -2
  44. optimum/rbln/utils/runtime_utils.py +2 -2
  45. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +1 -1
  46. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +49 -42
  47. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
  48. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
  49. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -86,6 +86,8 @@ _import_structure = {
  "RBLNDPTForDepthEstimationConfig",
  "RBLNDepthAnythingForDepthEstimationConfig",
  "RBLNDepthAnythingForDepthEstimation",
+ "RBLNDetrForObjectDetection",
+ "RBLNDetrForObjectDetectionConfig",
  "RBLNExaoneForCausalLM",
  "RBLNExaoneForCausalLMConfig",
  "RBLNGemmaModel",
@@ -120,6 +122,8 @@ _import_structure = {
  "RBLNLlamaForCausalLMConfig",
  "RBLNLlamaModel",
  "RBLNLlamaModelConfig",
+ "RBLNMixtralForCausalLM",
+ "RBLNMixtralForCausalLMConfig",
  "RBLNOPTForCausalLM",
  "RBLNOPTForCausalLMConfig",
  "RBLNLlavaForConditionalGeneration",
@@ -406,6 +410,8 @@ if TYPE_CHECKING:
  RBLNDecoderOnlyModelForCausalLMConfig,
  RBLNDepthAnythingForDepthEstimation,
  RBLNDepthAnythingForDepthEstimationConfig,
+ RBLNDetrForObjectDetection,
+ RBLNDetrForObjectDetectionConfig,
  RBLNDistilBertForQuestionAnswering,
  RBLNDistilBertForQuestionAnsweringConfig,
  RBLNDPTForDepthEstimation,
@@ -456,6 +462,8 @@ if TYPE_CHECKING:
  RBLNMistralForCausalLMConfig,
  RBLNMistralModel,
  RBLNMistralModelConfig,
+ RBLNMixtralForCausalLM,
+ RBLNMixtralForCausalLMConfig,
  RBLNOPTForCausalLM,
  RBLNOPTForCausalLMConfig,
  RBLNOPTModel,
optimum/rbln/__version__.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '0.9.5a4'
- __version_tuple__ = version_tuple = (0, 9, 5, 'a4')
+ __version__ = version = '0.10.0.post1'
+ __version_tuple__ = version_tuple = (0, 10, 0, 'post1')

  __commit_id__ = commit_id = None
optimum/rbln/configuration_utils.py CHANGED
@@ -24,7 +24,7 @@ import torch
  from packaging.version import Version

  from .__version__ import __version__
- from .utils.deprecation import deprecate_kwarg, warn_deprecated_npu
+ from .utils.deprecation import deprecate_kwarg, deprecate_method, warn_deprecated_npu
  from .utils.logging import get_logger
  from .utils.runtime_utils import ContextRblnConfig

@@ -36,6 +36,30 @@ DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
  TypeInputInfo = List[Tuple[str, Tuple[int], str]]


+ def nested_update(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Recursively merge override dict into base dict.
+ For nested dicts, values are merged recursively instead of being replaced.
+ For non-dict values, override takes precedence.
+ Args:
+ base: The base dictionary to merge into (modified in-place).
+ override: The dictionary with values to merge.
+ Returns:
+ The merged base dictionary.
+ Example:
+ >>> base = {"a": 1, "nested": {"x": 10, "y": 20}}
+ >>> override = {"b": 2, "nested": {"y": 30, "z": 40}}
+ >>> nested_update(base, override)
+ {"a": 1, "b": 2, "nested": {"x": 10, "y": 30, "z": 40}}
+ """
+ for key, value in override.items():
+ if key in base and isinstance(base[key], dict) and isinstance(value, dict):
+ nested_update(base[key], value)
+ else:
+ base[key] = value
+ return base
+
+
  @runtime_checkable
  class RBLNSerializableConfigProtocol(Protocol):
  def _prepare_for_serialization(self) -> Dict[str, Any]: ...
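The merge semantics documented above can be exercised directly; a minimal sketch, assuming `nested_update` is importable from `optimum.rbln.configuration_utils` (the module patched in this hunk):

```python
# Sketch only: mirrors the docstring example; the import path is an assumption.
from optimum.rbln.configuration_utils import nested_update

base = {"a": 1, "nested": {"x": 10, "y": 20}}
override = {"b": 2, "nested": {"y": 30, "z": 40}}

merged = nested_update(base, override)
assert merged == {"a": 1, "b": 2, "nested": {"x": 10, "y": 30, "z": 40}}
assert merged is base  # merged in-place and returned
```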
@@ -216,8 +240,7 @@ class RBLNAutoConfig:
  For example, the parsed contents of `rbln_config.json`.

  Returns:
- RBLNModelConfig: A configuration instance. The specific subclass is
- selected by `config_dict["cls_name"]`.
+ RBLNModelConfig: A configuration instance. The specific subclass is selected by `config_dict["cls_name"]`.

  Raises:
  ValueError: If `cls_name` is missing.
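The docstring above describes `load_from_dict`, which picks the concrete config subclass from the stored `cls_name`. A hedged sketch of that round-trip (the `rbln_config.json` location inside the model directory is assumed):

```python
# Hedged sketch: path and file layout are illustrative.
import json
from pathlib import Path

from optimum.rbln.configuration_utils import RBLNAutoConfig

config_dict = json.loads(Path("/path/to/model/rbln_config.json").read_text())
config = RBLNAutoConfig.load_from_dict(config_dict)  # subclass chosen via config_dict["cls_name"]
```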
@@ -256,12 +279,13 @@

  CONFIG_MAPPING[config.__name__] = config

- @staticmethod
- def load(
+ @classmethod
+ def from_pretrained(
+ cls,
  path: str,
- passed_rbln_config: Optional["RBLNModelConfig"] = None,
- kwargs: Optional[Dict[str, Any]] = None,
+ rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
  return_unused_kwargs: bool = False,
+ **kwargs: Optional[Dict[str, Any]],
  ) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
  """
  Load RBLNModelConfig from a path.
@@ -269,53 +293,58 @@

  Args:
  path (str): Path to the RBLNModelConfig.
- passed_rbln_config (Optional["RBLNModelConfig"]): RBLNModelConfig to pass its runtime options.
+ rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
+ return_unused_kwargs (bool): Whether to return unused kwargs.
+ kwargs: Additional keyword arguments to override configuration values.

  Returns:
  RBLNModelConfig: The loaded RBLNModelConfig.
- """
- if kwargs is None:
- kwargs = {}
- cls, config_file = load_config(path)
-
- rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
- rbln_runtime_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in RUNTIME_KEYWORDS}
- rbln_submodule_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in cls.submodules}
-
- rbln_kwargs = {
- key[5:]: kwargs.pop(key)
- for key in rbln_keys
- if key[5:] not in RUNTIME_KEYWORDS and key[5:] not in cls.submodules
- }

- # Process submodule's rbln_config
- for submodule in cls.submodules:
- if submodule not in config_file:
- raise ValueError(f"Submodule {submodule} not found in rbln_config.json.")
- submodule_config = config_file[submodule]
- submodule_config.update(rbln_submodule_kwargs.pop(submodule, {}))
- config_file[submodule] = RBLNAutoConfig.load_from_dict(submodule_config)
+ Examples:
+ ```python
+ config = RBLNAutoConfig.from_pretrained("/path/to/model")
+ ```
+ """
+ target_cls, _ = load_config(path)
+ return target_cls.from_pretrained(
+ path, rbln_config=rbln_config, return_unused_kwargs=return_unused_kwargs, **kwargs
+ )

- if passed_rbln_config is not None:
- config_file.update(passed_rbln_config._runtime_options)
- # TODO(jongho): Reject if the passed_rbln_config has different attributes from the config_file
+ @classmethod
+ @deprecate_method(version="0.11.0", new_method="from_pretrained")
+ def load(
+ cls,
+ path: str,
+ rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
+ return_unused_kwargs: bool = False,
+ **kwargs: Optional[Dict[str, Any]],
+ ) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
+ """
+ Load RBLNModelConfig from a path.
+ Class name is automatically inferred from the `rbln_config.json` file.

- config_file.update(rbln_runtime_kwargs)
+ Deprecated:
+ This method is deprecated and will be removed in version 0.11.0.
+ Use `from_pretrained` instead.

- rbln_config = cls(**config_file)
+ Args:
+ path (str): Path to the RBLNModelConfig file or directory.
+ rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
+ return_unused_kwargs (bool): Whether to return unused kwargs.
+ kwargs: Additional keyword arguments to override configuration values.

- if len(rbln_kwargs) > 0:
- for key, value in rbln_kwargs.items():
- if getattr(rbln_config, key) != value:
- raise ValueError(
- f"Cannot set the following arguments: {list(rbln_kwargs.keys())} "
- f"Since the value is already set to {getattr(rbln_config, key)}"
- )
+ Returns:
+ RBLNModelConfig: The loaded RBLNModelConfig.

- if return_unused_kwargs:
- return cls(**config_file), kwargs
- else:
- return cls(**config_file)
+ Examples:
+ ```python
+ # Deprecated usage:
+ config = RBLNAutoConfig.load("/path/to/model")
+ # Recommended usage:
+ config = RBLNAutoConfig.from_pretrained("/path/to/model")
+ ```
+ """
+ return cls.from_pretrained(path, rbln_config=rbln_config, return_unused_kwargs=return_unused_kwargs, **kwargs)


  class RBLNModelConfig(RBLNSerializableConfigProtocol):
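Net effect of this hunk: `RBLNAutoConfig.load` is now a thin, deprecated wrapper around `from_pretrained`, which resolves the concrete config class from `rbln_config.json` and forwards any overrides. A hedged usage sketch (path illustrative; `create_runtimes` is one of the runtime keywords quoted elsewhere in this diff):

```python
from optimum.rbln.configuration_utils import RBLNAutoConfig

# New entry point: the concrete RBLNModelConfig subclass is inferred from
# rbln_config.json, and runtime options can be overridden at load time.
config = RBLNAutoConfig.from_pretrained("/path/to/compiled/model", rbln_create_runtimes=False)

# Old entry point still works, but @deprecate_method warns and it is slated
# for removal in 0.11.0.
config = RBLNAutoConfig.load("/path/to/compiled/model")
```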
@@ -866,15 +895,23 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  json.dump(serializable_data, jsonf, indent=2)

  @classmethod
- def load(cls, path: str, **kwargs: Any) -> "RBLNModelConfig":
+ def from_pretrained(
+ cls,
+ path: str,
+ rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
+ return_unused_kwargs: bool = False,
+ **kwargs: Optional[Dict[str, Any]],
+ ) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
  """
  Load a RBLNModelConfig from a path.

  Args:
  path (str): Path to the RBLNModelConfig file or directory containing the config file.
+ rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
+ return_unused_kwargs (bool): Whether to return unused kwargs.
  kwargs: Additional keyword arguments to override configuration values.
- Keys starting with 'rbln_' will have the prefix removed and be used
- to update the configuration.
+ Keys starting with 'rbln_' will have the prefix removed and be used
+ to update the configuration.

  Returns:
  RBLNModelConfig: The loaded configuration instance.
@@ -883,17 +920,109 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  This method loads the configuration from the specified path and applies any
  provided overrides. If the loaded configuration class doesn't match the expected
  class, a warning will be logged.
+
+ Examples:
+ ```python
+ config = RBLNResNetForImageClassificationConfig.from_pretrained("/path/to/model")
+ ```
  """
  cls_reserved, config_file = load_config(path)
-
  if cls_reserved != cls:
  logger.warning(f"Expected {cls.__name__}, but got {cls_reserved.__name__}.")

+ if isinstance(rbln_config, dict):
+ for key, value in rbln_config.items():
+ if key not in kwargs:
+ kwargs[f"rbln_{key}"] = value
+
  rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
- rbln_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys}
- config_file.update(rbln_kwargs)
+ rbln_runtime_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in RUNTIME_KEYWORDS}
+ rbln_submodule_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in cls.submodules}
+
+ rbln_kwargs = {
+ key[5:]: kwargs.pop(key)
+ for key in rbln_keys
+ if key[5:] not in RUNTIME_KEYWORDS and key[5:] not in cls.submodules
+ }
+
+ # Process submodule's rbln_config
+ for submodule in cls.submodules:
+ if submodule not in config_file:
+ raise ValueError(f"Submodule {submodule} not found in rbln_config.json.")
+ submodule_config = config_file[submodule]
+ submodule_config.update(rbln_runtime_kwargs)

- return cls(**config_file)
+ update_dict = rbln_submodule_kwargs.pop(submodule, {})
+ if update_dict:
+ nested_update(submodule_config, update_dict)
+ config_file[submodule] = RBLNAutoConfig.load_from_dict(submodule_config)
+
+ if isinstance(rbln_config, RBLNModelConfig):
+ config_file.update(rbln_config._runtime_options)
+
+ # update submodule runtime
+ for submodule in rbln_config.submodules:
+ if str(config_file[submodule]) != str(getattr(rbln_config, submodule)):
+ raise ValueError(
+ f"Passed rbln_config has different attributes for submodule {submodule} than the config_file"
+ )
+ config_file[submodule] = getattr(rbln_config, submodule)
+
+ config_file.update(rbln_runtime_kwargs)
+ rbln_config = cls(**config_file)
+ if len(rbln_kwargs) > 0:
+ for key, value in rbln_kwargs.items():
+ if getattr(rbln_config, key) != value:
+ raise ValueError(
+ f"Cannot set the following arguments: {list(rbln_kwargs.keys())} "
+ f"Since the value is already set to {getattr(rbln_config, key)}"
+ )
+ if return_unused_kwargs:
+ return rbln_config, kwargs
+ else:
+ return rbln_config
+
+ @classmethod
+ @deprecate_method(version="0.11.0", new_method="from_pretrained")
+ def load(
+ cls,
+ path: str,
+ rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
+ return_unused_kwargs: bool = False,
+ **kwargs: Optional[Dict[str, Any]],
+ ) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
+ """
+ Load a RBLNModelConfig from a path.
+
+ Deprecated:
+ This method is deprecated and will be removed in version 0.11.0.
+ Use `from_pretrained` instead.
+
+ Args:
+ path (str): Path to the RBLNModelConfig file or directory containing the config file.
+ rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
+ return_unused_kwargs (bool): Whether to return unused kwargs.
+ kwargs: Additional keyword arguments to override configuration values.
+ Keys starting with 'rbln_' will have the prefix removed and be used
+ to update the configuration.
+
+ Returns:
+ RBLNModelConfig: The loaded configuration instance.
+
+ Note:
+ This method loads the configuration from the specified path and applies any
+ provided overrides. If the loaded configuration class doesn't match the expected
+ class, a warning will be logged.
+
+ Examples:
+ ```python
+ # Deprecated usage:
+ config = RBLNResNetForImageClassificationConfig.load("/path/to/model")
+ # Recommended usage:
+ config = RBLNResNetForImageClassificationConfig.from_pretrained("/path/to/model")
+ ```
+ """
+ return cls.from_pretrained(path, rbln_config=rbln_config, return_unused_kwargs=return_unused_kwargs, **kwargs)

  @classmethod
  def initialize_from_kwargs(
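The rewritten `RBLNModelConfig.from_pretrained` accepts overrides either as a plain `rbln_config` dict or as `rbln_`-prefixed kwargs, propagates runtime keywords to submodules, and raises if a non-runtime value conflicts with what is already stored. A hedged sketch using the ResNet config class named in the docstring above (import path and keys are assumptions):

```python
# Import path assumed; the class name comes from the docstring example above.
from optimum.rbln import RBLNResNetForImageClassificationConfig

# Equivalent override styles: a plain dict, or rbln_-prefixed kwargs.
config = RBLNResNetForImageClassificationConfig.from_pretrained(
    "/path/to/model", rbln_config={"create_runtimes": False}
)
config = RBLNResNetForImageClassificationConfig.from_pretrained(
    "/path/to/model", rbln_create_runtimes=False
)
# Overriding a compile-time attribute with a value that differs from the stored
# one raises ValueError instead of silently diverging from the compiled artifact.
```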
@@ -993,3 +1122,18 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  @timeout.setter
  def timeout(self, timeout: int):
  self._runtime_options["timeout"] = timeout
+
+
+ def convert_rbln_config_dict(
+ rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None, **kwargs
+ ) -> Tuple[Optional[Union[Dict[str, Any], RBLNModelConfig]], Dict[str, Any]]:
+ # Validate and merge rbln_ prefixed kwargs into rbln_config
+ kwargs_keys = list(kwargs.keys())
+ rbln_kwargs = {key[5:]: kwargs.pop(key) for key in kwargs_keys if key.startswith("rbln_")}
+
+ rbln_config = {} if rbln_config is None else rbln_config
+
+ if isinstance(rbln_config, dict) and len(rbln_kwargs) > 0:
+ rbln_config.update(rbln_kwargs)
+
+ return rbln_config, kwargs
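`convert_rbln_config_dict` is a small module-level helper that folds `rbln_`-prefixed kwargs into the `rbln_config` dict and hands back the remaining kwargs untouched. A sketch of its contract (keyword names illustrative, import path assumed):

```python
from optimum.rbln.configuration_utils import convert_rbln_config_dict

rbln_config, remaining = convert_rbln_config_dict(
    rbln_config={"create_runtimes": False},
    rbln_tensor_parallel_size=4,  # rbln_-prefixed: stripped and merged into rbln_config
    torch_dtype="float16",        # not rbln_-prefixed: returned for the underlying loader
)
assert rbln_config == {"create_runtimes": False, "tensor_parallel_size": 4}
assert remaining == {"torch_dtype": "float16"}
```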
optimum/rbln/diffusers/models/controlnet.py CHANGED
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, Dict, Optional, Union
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

  import torch
  from diffusers import ControlNetModel
@@ -218,7 +218,7 @@ class RBLNControlNetModel(RBLNModel):
  added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
  return_dict: bool = True,
  **kwargs,
- ):
+ ) -> Union[ControlNetOutput, Tuple]:
  """
  Forward pass for the RBLN-optimized ControlNetModel.

optimum/rbln/diffusers/models/transformers/prior_transformer.py CHANGED
@@ -13,7 +13,7 @@
  # limitations under the License.

  from pathlib import Path
- from typing import TYPE_CHECKING, Optional, Union
+ from typing import TYPE_CHECKING, Optional, Tuple, Union

  import torch
  from diffusers.models.transformers.prior_transformer import PriorTransformer, PriorTransformerOutput
@@ -134,7 +134,7 @@ class RBLNPriorTransformer(RBLNModel):
  encoder_hidden_states: Optional[torch.Tensor] = None,
  attention_mask: Optional[torch.Tensor] = None,
  return_dict: bool = True,
- ):
+ ) -> Union[PriorTransformerOutput, Tuple]:
  """
  Forward pass for the RBLN-optimized PriorTransformer.

optimum/rbln/diffusers/models/transformers/transformer_cosmos.py CHANGED
@@ -13,7 +13,7 @@
  # limitations under the License.

  from pathlib import Path
- from typing import TYPE_CHECKING, List, Optional, Union
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union

  import rebel
  import torch
@@ -302,7 +302,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
  condition_mask: Optional[torch.Tensor] = None,
  padding_mask: Optional[torch.Tensor] = None,
  return_dict: bool = True,
- ):
+ ) -> Union[Transformer2DModelOutput, Tuple]:
  """
  Forward pass for the RBLN-optimized CosmosTransformer3DModel.

optimum/rbln/diffusers/models/transformers/transformer_sd3.py CHANGED
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

  import torch
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
@@ -160,7 +160,7 @@ class RBLNSD3Transformer2DModel(RBLNModel):
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
  return_dict: bool = True,
  **kwargs,
- ):
+ ) -> Union[Transformer2DModelOutput, Tuple]:
  """
  Forward pass for the RBLN-optimized SD3Transformer2DModel.

optimum/rbln/diffusers/pipelines/auto_pipeline.py CHANGED
@@ -176,7 +176,7 @@ class RBLNAutoPipelineBase:
  export: bool = None,
  rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None,
  **kwargs: Any,
- ):
+ ) -> RBLNBaseModel:
  """
  Load an RBLN-accelerated Diffusers pipeline from a pretrained checkpoint or a compiled RBLN artifact.

@@ -201,8 +201,7 @@ class RBLNAutoPipelineBase:
  - Remaining arguments are forwarded to the Diffusers loader.

  Returns:
- RBLNBaseModel: An instantiated RBLN model wrapping the Diffusers pipeline, ready for
- inference on RBLN NPUs.
+ RBLNBaseModel: An instantiated RBLN model wrapping the Diffusers pipeline, ready for inference on RBLN NPUs.

  """
  rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
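For context, the annotated `from_pretrained` mirrors Diffusers' auto-pipeline loaders. A hedged sketch, assuming a concrete subclass such as `RBLNAutoPipelineForText2Image` is exported (class name, checkpoint, and keys are assumptions):

```python
from optimum.rbln import RBLNAutoPipelineForText2Image  # class name assumed

pipe = RBLNAutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",        # illustrative checkpoint
    export=True,                             # compile from the PyTorch weights
    rbln_config={"create_runtimes": False},  # compile only, no NPU runtime
)
```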
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py CHANGED
@@ -26,7 +26,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Callable, Dict, List, Optional, Union
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

  import torch
  import torch.nn.functional as F
@@ -260,7 +260,7 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
  callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
  **kwargs,
- ):
+ ) -> Union[StableDiffusionPipelineOutput, Tuple]:
  r"""
  The call function to the pipeline for generation.

@@ -321,14 +321,7 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
  output_type (`str`, *optional*, defaults to `"pil"`):
  The output format of the generated image. Choose between `PIL.Image` or `np.array`.
  return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
- plain tuple.
- callback (`Callable`, *optional*):
- A function that calls every `callback_steps` steps during inference. The function is called with the
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
- callback_steps (`int`, *optional*, defaults to 1):
- The frequency at which the `callback` function is called. If not specified, the callback is called at
- every step.
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple.
  cross_attention_kwargs (`dict`, *optional*):
  A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
  [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -356,8 +349,6 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
  `._callback_tensor_inputs` attribute of your pipeine class.

- Examples:
-
  Returns:
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
  If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py CHANGED
@@ -26,7 +26,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Callable, Dict, List, Optional, Union
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

  import torch
  import torch.nn.functional as F
@@ -253,7 +253,7 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
  callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
  **kwargs,
- ):
+ ) -> Union[StableDiffusionPipelineOutput, Tuple]:
  r"""
  The call function to the pipeline for generation.

@@ -347,8 +347,6 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
  `._callback_tensor_inputs` attribute of your pipeine class.

- Examples:
-
  Returns:
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
  If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py CHANGED
@@ -294,7 +294,7 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
  callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
  **kwargs,
- ):
+ ) -> Union[StableDiffusionXLPipelineOutput, Tuple]:
  r"""
  The call function to the pipeline for generation.

@@ -431,8 +431,6 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
  `._callback_tensor_inputs` attribute of your pipeine class.

- Examples:
-
  Returns:
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
  If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py CHANGED
@@ -309,7 +309,7 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableD
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
  callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
  **kwargs,
- ):
+ ) -> Union[StableDiffusionXLPipelineOutput, Tuple]:
  r"""
  Function invoked when calling the pipeline for generation.

@@ -465,8 +465,6 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableD
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
  `._callback_tensor_inputs` attribute of your pipeine class.

- Examples:
-
  Returns:
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py CHANGED
@@ -203,7 +203,7 @@ class RBLNRetinaFaceFilter(RetinaFaceFilter):
  f"If you only need to compile the model without loading it to NPU, you can use:\n"
  f" from_pretrained(..., rbln_create_runtimes=False) or\n"
  f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
- f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
+ f"To check your NPU status, run the 'rbln-smi' command in your terminal.\n"
  f"Make sure your NPU is properly installed and operational."
  )
  raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
@@ -278,7 +278,7 @@ class RBLNVideoSafetyModel(VideoSafetyModel):
  f"If you only need to compile the model without loading it to NPU, you can use:\n"
  f" from_pretrained(..., rbln_create_runtimes=False) or\n"
  f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
- f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
+ f"To check your NPU status, run the 'rbln-smi' command in your terminal.\n"
  f"Make sure your NPU is properly installed and operational."
  )
  raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
optimum/rbln/modeling_base.py CHANGED
@@ -24,7 +24,7 @@ import torch
  from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
  from transformers.utils.hub import PushToHubMixin

- from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
+ from .configuration_utils import RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
  from .utils.hub import pull_compiled_model_from_hub, validate_files
  from .utils.logging import get_logger
  from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok
@@ -206,8 +206,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  f"does not match the expected model class name ({cls.__name__})."
  )

- rbln_config, kwargs = RBLNAutoConfig.load(
- model_path_subfolder, passed_rbln_config=rbln_config, kwargs=kwargs, return_unused_kwargs=True
+ config_cls = cls.get_rbln_config_class()
+ rbln_config, kwargs = config_cls.from_pretrained(
+ model_path_subfolder, rbln_config=rbln_config, return_unused_kwargs=True, **kwargs
  )

  if rbln_config.rbln_model_cls_name != cls.__name__:
@@ -306,7 +307,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  f"If you only need to compile the model without loading it to NPU, you can use:\n"
  f" from_pretrained(..., rbln_create_runtimes=False) or\n"
  f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
- f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
+ f"To check your NPU status, run the 'rbln-smi' command in your terminal.\n"
  f"Make sure your NPU is properly installed and operational."
  )
  raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
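The error message above spells out two equivalent ways to compile without creating NPU runtimes; a hedged sketch with an illustrative model class and checkpoint:

```python
# Both calls follow the guidance quoted in the error message; the class and
# checkpoint id are illustrative, not prescribed by this diff.
from optimum.rbln import RBLNResNetForImageClassification

model = RBLNResNetForImageClassification.from_pretrained(
    "microsoft/resnet-50", export=True, rbln_create_runtimes=False
)
model = RBLNResNetForImageClassification.from_pretrained(
    "microsoft/resnet-50", export=True, rbln_config={"create_runtimes": False}
)
```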
optimum/rbln/transformers/__init__.py CHANGED
@@ -68,6 +68,8 @@ _import_structure = {
  "RBLNDecoderOnlyModelForCausalLMConfig",
  "RBLNDecoderOnlyModelConfig",
  "RBLNDecoderOnlyModel",
+ "RBLNDetrForObjectDetection",
+ "RBLNDetrForObjectDetectionConfig",
  "RBLNDistilBertForQuestionAnswering",
  "RBLNDistilBertForQuestionAnsweringConfig",
  "RBLNDPTForDepthEstimation",
@@ -130,6 +132,8 @@ _import_structure = {
  "RBLNMistralForCausalLMConfig",
  "RBLNMistralModel",
  "RBLNMistralModelConfig",
+ "RBLNMixtralForCausalLM",
+ "RBLNMixtralForCausalLMConfig",
  "RBLNOPTForCausalLM",
  "RBLNOPTForCausalLMConfig",
  "RBLNOPTModel",
@@ -246,6 +250,8 @@ if TYPE_CHECKING:
  RBLNDecoderOnlyModelForCausalLMConfig,
  RBLNDepthAnythingForDepthEstimation,
  RBLNDepthAnythingForDepthEstimationConfig,
+ RBLNDetrForObjectDetection,
+ RBLNDetrForObjectDetectionConfig,
  RBLNDistilBertForQuestionAnswering,
  RBLNDistilBertForQuestionAnsweringConfig,
  RBLNDPTForDepthEstimation,
@@ -296,6 +302,8 @@ if TYPE_CHECKING:
  RBLNMistralForCausalLMConfig,
  RBLNMistralModel,
  RBLNMistralModelConfig,
+ RBLNMixtralForCausalLM,
+ RBLNMixtralForCausalLMConfig,
  RBLNOPTForCausalLM,
  RBLNOPTForCausalLMConfig,
  RBLNOPTModel,