optimum-rbln 0.9.5a4__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +8 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +196 -52
- optimum/rbln/diffusers/models/controlnet.py +2 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
- optimum/rbln/modeling_base.py +5 -4
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/modeling_attention_utils.py +15 -9
- optimum/rbln/transformers/models/__init__.py +10 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +2 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +26 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +2 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +45 -21
- optimum/rbln/transformers/models/detr/__init__.py +23 -0
- optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
- optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +4 -176
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +4 -3
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +10 -7
- optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
- optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
- optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +7 -7
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +2 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +2 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
- optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
- optimum/rbln/utils/deprecation.py +78 -1
- optimum/rbln/utils/hub.py +93 -2
- optimum/rbln/utils/runtime_utils.py +2 -2
- {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +49 -42
- {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
|
@@ -86,6 +86,8 @@ _import_structure = {
|
|
|
86
86
|
"RBLNDPTForDepthEstimationConfig",
|
|
87
87
|
"RBLNDepthAnythingForDepthEstimationConfig",
|
|
88
88
|
"RBLNDepthAnythingForDepthEstimation",
|
|
89
|
+
"RBLNDetrForObjectDetection",
|
|
90
|
+
"RBLNDetrForObjectDetectionConfig",
|
|
89
91
|
"RBLNExaoneForCausalLM",
|
|
90
92
|
"RBLNExaoneForCausalLMConfig",
|
|
91
93
|
"RBLNGemmaModel",
|
|
@@ -120,6 +122,8 @@ _import_structure = {
|
|
|
120
122
|
"RBLNLlamaForCausalLMConfig",
|
|
121
123
|
"RBLNLlamaModel",
|
|
122
124
|
"RBLNLlamaModelConfig",
|
|
125
|
+
"RBLNMixtralForCausalLM",
|
|
126
|
+
"RBLNMixtralForCausalLMConfig",
|
|
123
127
|
"RBLNOPTForCausalLM",
|
|
124
128
|
"RBLNOPTForCausalLMConfig",
|
|
125
129
|
"RBLNLlavaForConditionalGeneration",
|
|
@@ -406,6 +410,8 @@ if TYPE_CHECKING:
|
|
|
406
410
|
RBLNDecoderOnlyModelForCausalLMConfig,
|
|
407
411
|
RBLNDepthAnythingForDepthEstimation,
|
|
408
412
|
RBLNDepthAnythingForDepthEstimationConfig,
|
|
413
|
+
RBLNDetrForObjectDetection,
|
|
414
|
+
RBLNDetrForObjectDetectionConfig,
|
|
409
415
|
RBLNDistilBertForQuestionAnswering,
|
|
410
416
|
RBLNDistilBertForQuestionAnsweringConfig,
|
|
411
417
|
RBLNDPTForDepthEstimation,
|
|
@@ -456,6 +462,8 @@ if TYPE_CHECKING:
|
|
|
456
462
|
RBLNMistralForCausalLMConfig,
|
|
457
463
|
RBLNMistralModel,
|
|
458
464
|
RBLNMistralModelConfig,
|
|
465
|
+
RBLNMixtralForCausalLM,
|
|
466
|
+
RBLNMixtralForCausalLMConfig,
|
|
459
467
|
RBLNOPTForCausalLM,
|
|
460
468
|
RBLNOPTForCausalLMConfig,
|
|
461
469
|
RBLNOPTModel,
|
optimum/rbln/__version__.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.9.5a4'
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 9, 5, 'a4')
|
|
31
|
+
__version__ = version = '0.10.0.post1'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 10, 0, 'post1')
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
optimum/rbln/configuration_utils.py
CHANGED
|
@@ -24,7 +24,7 @@ import torch
|
|
|
24
24
|
from packaging.version import Version
|
|
25
25
|
|
|
26
26
|
from .__version__ import __version__
|
|
27
|
-
from .utils.deprecation import deprecate_kwarg, warn_deprecated_npu
|
|
27
|
+
from .utils.deprecation import deprecate_kwarg, deprecate_method, warn_deprecated_npu
|
|
28
28
|
from .utils.logging import get_logger
|
|
29
29
|
from .utils.runtime_utils import ContextRblnConfig
|
|
30
30
|
|
|
@@ -36,6 +36,30 @@ DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
|
|
|
36
36
|
TypeInputInfo = List[Tuple[str, Tuple[int], str]]
|
|
37
37
|
|
|
38
38
|
|
|
39
|
+
def nested_update(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
|
|
40
|
+
"""
|
|
41
|
+
Recursively merge override dict into base dict.
|
|
42
|
+
For nested dicts, values are merged recursively instead of being replaced.
|
|
43
|
+
For non-dict values, override takes precedence.
|
|
44
|
+
Args:
|
|
45
|
+
base: The base dictionary to merge into (modified in-place).
|
|
46
|
+
override: The dictionary with values to merge.
|
|
47
|
+
Returns:
|
|
48
|
+
The merged base dictionary.
|
|
49
|
+
Example:
|
|
50
|
+
>>> base = {"a": 1, "nested": {"x": 10, "y": 20}}
|
|
51
|
+
>>> override = {"b": 2, "nested": {"y": 30, "z": 40}}
|
|
52
|
+
>>> nested_update(base, override)
|
|
53
|
+
{"a": 1, "b": 2, "nested": {"x": 10, "y": 30, "z": 40}}
|
|
54
|
+
"""
|
|
55
|
+
for key, value in override.items():
|
|
56
|
+
if key in base and isinstance(base[key], dict) and isinstance(value, dict):
|
|
57
|
+
nested_update(base[key], value)
|
|
58
|
+
else:
|
|
59
|
+
base[key] = value
|
|
60
|
+
return base
|
|
61
|
+
|
|
62
|
+
|
|
39
63
|
@runtime_checkable
|
|
40
64
|
class RBLNSerializableConfigProtocol(Protocol):
|
|
41
65
|
def _prepare_for_serialization(self) -> Dict[str, Any]: ...
|
|
@@ -216,8 +240,7 @@ class RBLNAutoConfig:
|
|
|
216
240
|
For example, the parsed contents of `rbln_config.json`.
|
|
217
241
|
|
|
218
242
|
Returns:
|
|
219
|
-
RBLNModelConfig: A configuration instance. The specific subclass is
|
|
220
|
-
selected by `config_dict["cls_name"]`.
|
|
243
|
+
RBLNModelConfig: A configuration instance. The specific subclass is selected by `config_dict["cls_name"]`.
|
|
221
244
|
|
|
222
245
|
Raises:
|
|
223
246
|
ValueError: If `cls_name` is missing.
|
|
@@ -256,12 +279,13 @@ class RBLNAutoConfig:
|
|
|
256
279
|
|
|
257
280
|
CONFIG_MAPPING[config.__name__] = config
|
|
258
281
|
|
|
259
|
-
@
|
|
260
|
-
def
|
|
282
|
+
@classmethod
|
|
283
|
+
def from_pretrained(
|
|
284
|
+
cls,
|
|
261
285
|
path: str,
|
|
262
|
-
|
|
263
|
-
kwargs: Optional[Dict[str, Any]] = None,
|
|
286
|
+
rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
|
|
264
287
|
return_unused_kwargs: bool = False,
|
|
288
|
+
**kwargs: Optional[Dict[str, Any]],
|
|
265
289
|
) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
|
|
266
290
|
"""
|
|
267
291
|
Load RBLNModelConfig from a path.
|
|
@@ -269,53 +293,58 @@ class RBLNAutoConfig:
|
|
|
269
293
|
|
|
270
294
|
Args:
|
|
271
295
|
path (str): Path to the RBLNModelConfig.
|
|
272
|
-
|
|
296
|
+
rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
|
|
297
|
+
return_unused_kwargs (bool): Whether to return unused kwargs.
|
|
298
|
+
kwargs: Additional keyword arguments to override configuration values.
|
|
273
299
|
|
|
274
300
|
Returns:
|
|
275
301
|
RBLNModelConfig: The loaded RBLNModelConfig.
|
|
276
|
-
"""
|
|
277
|
-
if kwargs is None:
|
|
278
|
-
kwargs = {}
|
|
279
|
-
cls, config_file = load_config(path)
|
|
280
|
-
|
|
281
|
-
rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
|
|
282
|
-
rbln_runtime_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in RUNTIME_KEYWORDS}
|
|
283
|
-
rbln_submodule_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in cls.submodules}
|
|
284
|
-
|
|
285
|
-
rbln_kwargs = {
|
|
286
|
-
key[5:]: kwargs.pop(key)
|
|
287
|
-
for key in rbln_keys
|
|
288
|
-
if key[5:] not in RUNTIME_KEYWORDS and key[5:] not in cls.submodules
|
|
289
|
-
}
|
|
290
302
|
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
303
|
+
Examples:
|
|
304
|
+
```python
|
|
305
|
+
config = RBLNAutoConfig.from_pretrained("/path/to/model")
|
|
306
|
+
```
|
|
307
|
+
"""
|
|
308
|
+
target_cls, _ = load_config(path)
|
|
309
|
+
return target_cls.from_pretrained(
|
|
310
|
+
path, rbln_config=rbln_config, return_unused_kwargs=return_unused_kwargs, **kwargs
|
|
311
|
+
)
|
|
298
312
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
313
|
+
@classmethod
|
|
314
|
+
@deprecate_method(version="0.11.0", new_method="from_pretrained")
|
|
315
|
+
def load(
|
|
316
|
+
cls,
|
|
317
|
+
path: str,
|
|
318
|
+
rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
|
|
319
|
+
return_unused_kwargs: bool = False,
|
|
320
|
+
**kwargs: Optional[Dict[str, Any]],
|
|
321
|
+
) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
|
|
322
|
+
"""
|
|
323
|
+
Load RBLNModelConfig from a path.
|
|
324
|
+
Class name is automatically inferred from the `rbln_config.json` file.
|
|
302
325
|
|
|
303
|
-
|
|
326
|
+
Deprecated:
|
|
327
|
+
This method is deprecated and will be removed in version 0.11.0.
|
|
328
|
+
Use `from_pretrained` instead.
|
|
304
329
|
|
|
305
|
-
|
|
330
|
+
Args:
|
|
331
|
+
path (str): Path to the RBLNModelConfig file or directory.
|
|
332
|
+
rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
|
|
333
|
+
return_unused_kwargs (bool): Whether to return unused kwargs.
|
|
334
|
+
kwargs: Additional keyword arguments to override configuration values.
|
|
306
335
|
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
if getattr(rbln_config, key) != value:
|
|
310
|
-
raise ValueError(
|
|
311
|
-
f"Cannot set the following arguments: {list(rbln_kwargs.keys())} "
|
|
312
|
-
f"Since the value is already set to {getattr(rbln_config, key)}"
|
|
313
|
-
)
|
|
336
|
+
Returns:
|
|
337
|
+
RBLNModelConfig: The loaded RBLNModelConfig.
|
|
314
338
|
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
339
|
+
Examples:
|
|
340
|
+
```python
|
|
341
|
+
# Deprecated usage:
|
|
342
|
+
config = RBLNAutoConfig.load("/path/to/model")
|
|
343
|
+
# Recommended usage:
|
|
344
|
+
config = RBLNAutoConfig.from_pretrained("/path/to/model")
|
|
345
|
+
```
|
|
346
|
+
"""
|
|
347
|
+
return cls.from_pretrained(path, rbln_config=rbln_config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
|
319
348
|
|
|
320
349
|
|
|
321
350
|
class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
@@ -866,15 +895,23 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
866
895
|
json.dump(serializable_data, jsonf, indent=2)
|
|
867
896
|
|
|
868
897
|
@classmethod
|
|
869
|
-
def
|
|
898
|
+
def from_pretrained(
|
|
899
|
+
cls,
|
|
900
|
+
path: str,
|
|
901
|
+
rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
|
|
902
|
+
return_unused_kwargs: bool = False,
|
|
903
|
+
**kwargs: Optional[Dict[str, Any]],
|
|
904
|
+
) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
|
|
870
905
|
"""
|
|
871
906
|
Load a RBLNModelConfig from a path.
|
|
872
907
|
|
|
873
908
|
Args:
|
|
874
909
|
path (str): Path to the RBLNModelConfig file or directory containing the config file.
|
|
910
|
+
rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
|
|
911
|
+
return_unused_kwargs (bool): Whether to return unused kwargs.
|
|
875
912
|
kwargs: Additional keyword arguments to override configuration values.
|
|
876
|
-
|
|
877
|
-
|
|
913
|
+
Keys starting with 'rbln_' will have the prefix removed and be used
|
|
914
|
+
to update the configuration.
|
|
878
915
|
|
|
879
916
|
Returns:
|
|
880
917
|
RBLNModelConfig: The loaded configuration instance.
|
|
@@ -883,17 +920,109 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
883
920
|
This method loads the configuration from the specified path and applies any
|
|
884
921
|
provided overrides. If the loaded configuration class doesn't match the expected
|
|
885
922
|
class, a warning will be logged.
|
|
923
|
+
|
|
924
|
+
Examples:
|
|
925
|
+
```python
|
|
926
|
+
config = RBLNResNetForImageClassificationConfig.from_pretrained("/path/to/model")
|
|
927
|
+
```
|
|
886
928
|
"""
|
|
887
929
|
cls_reserved, config_file = load_config(path)
|
|
888
|
-
|
|
889
930
|
if cls_reserved != cls:
|
|
890
931
|
logger.warning(f"Expected {cls.__name__}, but got {cls_reserved.__name__}.")
|
|
891
932
|
|
|
933
|
+
if isinstance(rbln_config, dict):
|
|
934
|
+
for key, value in rbln_config.items():
|
|
935
|
+
if key not in kwargs:
|
|
936
|
+
kwargs[f"rbln_{key}"] = value
|
|
937
|
+
|
|
892
938
|
rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
|
|
893
|
-
|
|
894
|
-
|
|
939
|
+
rbln_runtime_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in RUNTIME_KEYWORDS}
|
|
940
|
+
rbln_submodule_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in cls.submodules}
|
|
941
|
+
|
|
942
|
+
rbln_kwargs = {
|
|
943
|
+
key[5:]: kwargs.pop(key)
|
|
944
|
+
for key in rbln_keys
|
|
945
|
+
if key[5:] not in RUNTIME_KEYWORDS and key[5:] not in cls.submodules
|
|
946
|
+
}
|
|
947
|
+
|
|
948
|
+
# Process submodule's rbln_config
|
|
949
|
+
for submodule in cls.submodules:
|
|
950
|
+
if submodule not in config_file:
|
|
951
|
+
raise ValueError(f"Submodule {submodule} not found in rbln_config.json.")
|
|
952
|
+
submodule_config = config_file[submodule]
|
|
953
|
+
submodule_config.update(rbln_runtime_kwargs)
|
|
895
954
|
|
|
896
|
-
|
|
955
|
+
update_dict = rbln_submodule_kwargs.pop(submodule, {})
|
|
956
|
+
if update_dict:
|
|
957
|
+
nested_update(submodule_config, update_dict)
|
|
958
|
+
config_file[submodule] = RBLNAutoConfig.load_from_dict(submodule_config)
|
|
959
|
+
|
|
960
|
+
if isinstance(rbln_config, RBLNModelConfig):
|
|
961
|
+
config_file.update(rbln_config._runtime_options)
|
|
962
|
+
|
|
963
|
+
# update submodule runtime
|
|
964
|
+
for submodule in rbln_config.submodules:
|
|
965
|
+
if str(config_file[submodule]) != str(getattr(rbln_config, submodule)):
|
|
966
|
+
raise ValueError(
|
|
967
|
+
f"Passed rbln_config has different attributes for submodule {submodule} than the config_file"
|
|
968
|
+
)
|
|
969
|
+
config_file[submodule] = getattr(rbln_config, submodule)
|
|
970
|
+
|
|
971
|
+
config_file.update(rbln_runtime_kwargs)
|
|
972
|
+
rbln_config = cls(**config_file)
|
|
973
|
+
if len(rbln_kwargs) > 0:
|
|
974
|
+
for key, value in rbln_kwargs.items():
|
|
975
|
+
if getattr(rbln_config, key) != value:
|
|
976
|
+
raise ValueError(
|
|
977
|
+
f"Cannot set the following arguments: {list(rbln_kwargs.keys())} "
|
|
978
|
+
f"Since the value is already set to {getattr(rbln_config, key)}"
|
|
979
|
+
)
|
|
980
|
+
if return_unused_kwargs:
|
|
981
|
+
return rbln_config, kwargs
|
|
982
|
+
else:
|
|
983
|
+
return rbln_config
|
|
984
|
+
|
|
985
|
+
@classmethod
|
|
986
|
+
@deprecate_method(version="0.11.0", new_method="from_pretrained")
|
|
987
|
+
def load(
|
|
988
|
+
cls,
|
|
989
|
+
path: str,
|
|
990
|
+
rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
|
|
991
|
+
return_unused_kwargs: bool = False,
|
|
992
|
+
**kwargs: Optional[Dict[str, Any]],
|
|
993
|
+
) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
|
|
994
|
+
"""
|
|
995
|
+
Load a RBLNModelConfig from a path.
|
|
996
|
+
|
|
997
|
+
Deprecated:
|
|
998
|
+
This method is deprecated and will be removed in version 0.11.0.
|
|
999
|
+
Use `from_pretrained` instead.
|
|
1000
|
+
|
|
1001
|
+
Args:
|
|
1002
|
+
path (str): Path to the RBLNModelConfig file or directory containing the config file.
|
|
1003
|
+
rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
|
|
1004
|
+
return_unused_kwargs (bool): Whether to return unused kwargs.
|
|
1005
|
+
kwargs: Additional keyword arguments to override configuration values.
|
|
1006
|
+
Keys starting with 'rbln_' will have the prefix removed and be used
|
|
1007
|
+
to update the configuration.
|
|
1008
|
+
|
|
1009
|
+
Returns:
|
|
1010
|
+
RBLNModelConfig: The loaded configuration instance.
|
|
1011
|
+
|
|
1012
|
+
Note:
|
|
1013
|
+
This method loads the configuration from the specified path and applies any
|
|
1014
|
+
provided overrides. If the loaded configuration class doesn't match the expected
|
|
1015
|
+
class, a warning will be logged.
|
|
1016
|
+
|
|
1017
|
+
Examples:
|
|
1018
|
+
```python
|
|
1019
|
+
# Deprecated usage:
|
|
1020
|
+
config = RBLNResNetForImageClassificationConfig.load("/path/to/model")
|
|
1021
|
+
# Recommended usage:
|
|
1022
|
+
config = RBLNResNetForImageClassificationConfig.from_pretrained("/path/to/model")
|
|
1023
|
+
```
|
|
1024
|
+
"""
|
|
1025
|
+
return cls.from_pretrained(path, rbln_config=rbln_config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
|
897
1026
|
|
|
898
1027
|
@classmethod
|
|
899
1028
|
def initialize_from_kwargs(
|
|
@@ -993,3 +1122,18 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
993
1122
|
@timeout.setter
|
|
994
1123
|
def timeout(self, timeout: int):
|
|
995
1124
|
self._runtime_options["timeout"] = timeout
|
|
1125
|
+
|
|
1126
|
+
|
|
1127
|
+
def convert_rbln_config_dict(
|
|
1128
|
+
rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None, **kwargs
|
|
1129
|
+
) -> Tuple[Optional[Union[Dict[str, Any], RBLNModelConfig]], Dict[str, Any]]:
|
|
1130
|
+
# Validate and merge rbln_ prefixed kwargs into rbln_config
|
|
1131
|
+
kwargs_keys = list(kwargs.keys())
|
|
1132
|
+
rbln_kwargs = {key[5:]: kwargs.pop(key) for key in kwargs_keys if key.startswith("rbln_")}
|
|
1133
|
+
|
|
1134
|
+
rbln_config = {} if rbln_config is None else rbln_config
|
|
1135
|
+
|
|
1136
|
+
if isinstance(rbln_config, dict) and len(rbln_kwargs) > 0:
|
|
1137
|
+
rbln_config.update(rbln_kwargs)
|
|
1138
|
+
|
|
1139
|
+
return rbln_config, kwargs
|
|
optimum/rbln/diffusers/models/controlnet.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import TYPE_CHECKING, Dict, Optional, Union
|
|
15
|
+
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
from diffusers import ControlNetModel
|
|
@@ -218,7 +218,7 @@ class RBLNControlNetModel(RBLNModel):
|
|
|
218
218
|
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
|
219
219
|
return_dict: bool = True,
|
|
220
220
|
**kwargs,
|
|
221
|
-
):
|
|
221
|
+
) -> Union[ControlNetOutput, Tuple]:
|
|
222
222
|
"""
|
|
223
223
|
Forward pass for the RBLN-optimized ControlNetModel.
|
|
224
224
|
|
|
optimum/rbln/diffusers/models/transformers/prior_transformer.py
CHANGED
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
from pathlib import Path
|
|
16
|
-
from typing import TYPE_CHECKING, Optional, Union
|
|
16
|
+
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
19
|
from diffusers.models.transformers.prior_transformer import PriorTransformer, PriorTransformerOutput
|
|
@@ -134,7 +134,7 @@ class RBLNPriorTransformer(RBLNModel):
|
|
|
134
134
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
135
135
|
attention_mask: Optional[torch.Tensor] = None,
|
|
136
136
|
return_dict: bool = True,
|
|
137
|
-
):
|
|
137
|
+
) -> Union[PriorTransformerOutput, Tuple]:
|
|
138
138
|
"""
|
|
139
139
|
Forward pass for the RBLN-optimized PriorTransformer.
|
|
140
140
|
|
|
optimum/rbln/diffusers/models/transformers/transformer_cosmos.py
CHANGED
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
from pathlib import Path
|
|
16
|
-
from typing import TYPE_CHECKING, List, Optional, Union
|
|
16
|
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
|
17
17
|
|
|
18
18
|
import rebel
|
|
19
19
|
import torch
|
|
@@ -302,7 +302,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
|
|
|
302
302
|
condition_mask: Optional[torch.Tensor] = None,
|
|
303
303
|
padding_mask: Optional[torch.Tensor] = None,
|
|
304
304
|
return_dict: bool = True,
|
|
305
|
-
):
|
|
305
|
+
) -> Union[Transformer2DModelOutput, Tuple]:
|
|
306
306
|
"""
|
|
307
307
|
Forward pass for the RBLN-optimized CosmosTransformer3DModel.
|
|
308
308
|
|
|
optimum/rbln/diffusers/models/transformers/transformer_sd3.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
|
15
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
|
@@ -160,7 +160,7 @@ class RBLNSD3Transformer2DModel(RBLNModel):
|
|
|
160
160
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
|
161
161
|
return_dict: bool = True,
|
|
162
162
|
**kwargs,
|
|
163
|
-
):
|
|
163
|
+
) -> Union[Transformer2DModelOutput, Tuple]:
|
|
164
164
|
"""
|
|
165
165
|
Forward pass for the RBLN-optimized SD3Transformer2DModel.
|
|
166
166
|
|
|
optimum/rbln/diffusers/pipelines/auto_pipeline.py
CHANGED
|
@@ -176,7 +176,7 @@ class RBLNAutoPipelineBase:
|
|
|
176
176
|
export: bool = None,
|
|
177
177
|
rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None,
|
|
178
178
|
**kwargs: Any,
|
|
179
|
-
):
|
|
179
|
+
) -> RBLNBaseModel:
|
|
180
180
|
"""
|
|
181
181
|
Load an RBLN-accelerated Diffusers pipeline from a pretrained checkpoint or a compiled RBLN artifact.
|
|
182
182
|
|
|
@@ -201,8 +201,7 @@ class RBLNAutoPipelineBase:
|
|
|
201
201
|
- Remaining arguments are forwarded to the Diffusers loader.
|
|
202
202
|
|
|
203
203
|
Returns:
|
|
204
|
-
RBLNBaseModel: An instantiated RBLN model wrapping the Diffusers pipeline, ready for
|
|
205
|
-
inference on RBLN NPUs.
|
|
204
|
+
RBLNBaseModel: An instantiated RBLN model wrapping the Diffusers pipeline, ready for inference on RBLN NPUs.
|
|
206
205
|
|
|
207
206
|
"""
|
|
208
207
|
rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
|
|
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py
CHANGED
|
@@ -26,7 +26,7 @@
|
|
|
26
26
|
# See the License for the specific language governing permissions and
|
|
27
27
|
# limitations under the License.
|
|
28
28
|
|
|
29
|
-
from typing import Any, Callable, Dict, List, Optional, Union
|
|
29
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
30
30
|
|
|
31
31
|
import torch
|
|
32
32
|
import torch.nn.functional as F
|
|
@@ -260,7 +260,7 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
|
|
|
260
260
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
|
261
261
|
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
|
262
262
|
**kwargs,
|
|
263
|
-
):
|
|
263
|
+
) -> Union[StableDiffusionPipelineOutput, Tuple]:
|
|
264
264
|
r"""
|
|
265
265
|
The call function to the pipeline for generation.
|
|
266
266
|
|
|
@@ -321,14 +321,7 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
|
|
|
321
321
|
output_type (`str`, *optional*, defaults to `"pil"`):
|
|
322
322
|
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
|
323
323
|
return_dict (`bool`, *optional*, defaults to `True`):
|
|
324
|
-
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
|
325
|
-
plain tuple.
|
|
326
|
-
callback (`Callable`, *optional*):
|
|
327
|
-
A function that calls every `callback_steps` steps during inference. The function is called with the
|
|
328
|
-
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
|
329
|
-
callback_steps (`int`, *optional*, defaults to 1):
|
|
330
|
-
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
|
331
|
-
every step.
|
|
324
|
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple.
|
|
332
325
|
cross_attention_kwargs (`dict`, *optional*):
|
|
333
326
|
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
|
334
327
|
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
|
@@ -356,8 +349,6 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
|
|
|
356
349
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
|
357
350
|
`._callback_tensor_inputs` attribute of your pipeine class.
|
|
358
351
|
|
|
359
|
-
Examples:
|
|
360
|
-
|
|
361
352
|
Returns:
|
|
362
353
|
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
|
363
354
|
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
|
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
CHANGED
|
@@ -26,7 +26,7 @@
|
|
|
26
26
|
# See the License for the specific language governing permissions and
|
|
27
27
|
# limitations under the License.
|
|
28
28
|
|
|
29
|
-
from typing import Any, Callable, Dict, List, Optional, Union
|
|
29
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
30
30
|
|
|
31
31
|
import torch
|
|
32
32
|
import torch.nn.functional as F
|
|
@@ -253,7 +253,7 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
|
|
|
253
253
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
|
254
254
|
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
|
255
255
|
**kwargs,
|
|
256
|
-
):
|
|
256
|
+
) -> Union[StableDiffusionPipelineOutput, Tuple]:
|
|
257
257
|
r"""
|
|
258
258
|
The call function to the pipeline for generation.
|
|
259
259
|
|
|
@@ -347,8 +347,6 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
|
|
|
347
347
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
|
348
348
|
`._callback_tensor_inputs` attribute of your pipeine class.
|
|
349
349
|
|
|
350
|
-
Examples:
|
|
351
|
-
|
|
352
350
|
Returns:
|
|
353
351
|
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
|
354
352
|
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
|
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
CHANGED
|
@@ -294,7 +294,7 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
|
|
|
294
294
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
|
295
295
|
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
|
296
296
|
**kwargs,
|
|
297
|
-
):
|
|
297
|
+
) -> Union[StableDiffusionXLPipelineOutput, Tuple]:
|
|
298
298
|
r"""
|
|
299
299
|
The call function to the pipeline for generation.
|
|
300
300
|
|
|
@@ -431,8 +431,6 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
|
|
|
431
431
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
|
432
432
|
`._callback_tensor_inputs` attribute of your pipeine class.
|
|
433
433
|
|
|
434
|
-
Examples:
|
|
435
|
-
|
|
436
434
|
Returns:
|
|
437
435
|
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
|
438
436
|
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
|
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
CHANGED
|
@@ -309,7 +309,7 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableD
|
|
|
309
309
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
|
310
310
|
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
|
311
311
|
**kwargs,
|
|
312
|
-
):
|
|
312
|
+
) -> Union[StableDiffusionXLPipelineOutput, Tuple]:
|
|
313
313
|
r"""
|
|
314
314
|
Function invoked when calling the pipeline for generation.
|
|
315
315
|
|
|
@@ -465,8 +465,6 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableD
|
|
|
465
465
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
|
466
466
|
`._callback_tensor_inputs` attribute of your pipeine class.
|
|
467
467
|
|
|
468
|
-
Examples:
|
|
469
|
-
|
|
470
468
|
Returns:
|
|
471
469
|
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
|
472
470
|
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
|
|
optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py
CHANGED
|
@@ -203,7 +203,7 @@ class RBLNRetinaFaceFilter(RetinaFaceFilter):
|
|
|
203
203
|
f"If you only need to compile the model without loading it to NPU, you can use:\n"
|
|
204
204
|
f" from_pretrained(..., rbln_create_runtimes=False) or\n"
|
|
205
205
|
f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
|
|
206
|
-
f"To check your NPU status, run the 'rbln-
|
|
206
|
+
f"To check your NPU status, run the 'rbln-smi' command in your terminal.\n"
|
|
207
207
|
f"Make sure your NPU is properly installed and operational."
|
|
208
208
|
)
|
|
209
209
|
raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
|
|
@@ -278,7 +278,7 @@ class RBLNVideoSafetyModel(VideoSafetyModel):
|
|
|
278
278
|
f"If you only need to compile the model without loading it to NPU, you can use:\n"
|
|
279
279
|
f" from_pretrained(..., rbln_create_runtimes=False) or\n"
|
|
280
280
|
f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
|
|
281
|
-
f"To check your NPU status, run the 'rbln-
|
|
281
|
+
f"To check your NPU status, run the 'rbln-smi' command in your terminal.\n"
|
|
282
282
|
f"Make sure your NPU is properly installed and operational."
|
|
283
283
|
)
|
|
284
284
|
raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
|
optimum/rbln/modeling_base.py
CHANGED
|
@@ -24,7 +24,7 @@ import torch
|
|
|
24
24
|
from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
|
|
25
25
|
from transformers.utils.hub import PushToHubMixin
|
|
26
26
|
|
|
27
|
-
from .configuration_utils import
|
|
27
|
+
from .configuration_utils import RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
|
|
28
28
|
from .utils.hub import pull_compiled_model_from_hub, validate_files
|
|
29
29
|
from .utils.logging import get_logger
|
|
30
30
|
from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok
|
|
@@ -206,8 +206,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
206
206
|
f"does not match the expected model class name ({cls.__name__})."
|
|
207
207
|
)
|
|
208
208
|
|
|
209
|
-
|
|
210
|
-
|
|
209
|
+
config_cls = cls.get_rbln_config_class()
|
|
210
|
+
rbln_config, kwargs = config_cls.from_pretrained(
|
|
211
|
+
model_path_subfolder, rbln_config=rbln_config, return_unused_kwargs=True, **kwargs
|
|
211
212
|
)
|
|
212
213
|
|
|
213
214
|
if rbln_config.rbln_model_cls_name != cls.__name__:
|
|
@@ -306,7 +307,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
306
307
|
f"If you only need to compile the model without loading it to NPU, you can use:\n"
|
|
307
308
|
f" from_pretrained(..., rbln_create_runtimes=False) or\n"
|
|
308
309
|
f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
|
|
309
|
-
f"To check your NPU status, run the 'rbln-
|
|
310
|
+
f"To check your NPU status, run the 'rbln-smi' command in your terminal.\n"
|
|
310
311
|
f"Make sure your NPU is properly installed and operational."
|
|
311
312
|
)
|
|
312
313
|
raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
|
|
optimum/rbln/transformers/__init__.py
CHANGED
|
@@ -68,6 +68,8 @@ _import_structure = {
|
|
|
68
68
|
"RBLNDecoderOnlyModelForCausalLMConfig",
|
|
69
69
|
"RBLNDecoderOnlyModelConfig",
|
|
70
70
|
"RBLNDecoderOnlyModel",
|
|
71
|
+
"RBLNDetrForObjectDetection",
|
|
72
|
+
"RBLNDetrForObjectDetectionConfig",
|
|
71
73
|
"RBLNDistilBertForQuestionAnswering",
|
|
72
74
|
"RBLNDistilBertForQuestionAnsweringConfig",
|
|
73
75
|
"RBLNDPTForDepthEstimation",
|
|
@@ -130,6 +132,8 @@ _import_structure = {
|
|
|
130
132
|
"RBLNMistralForCausalLMConfig",
|
|
131
133
|
"RBLNMistralModel",
|
|
132
134
|
"RBLNMistralModelConfig",
|
|
135
|
+
"RBLNMixtralForCausalLM",
|
|
136
|
+
"RBLNMixtralForCausalLMConfig",
|
|
133
137
|
"RBLNOPTForCausalLM",
|
|
134
138
|
"RBLNOPTForCausalLMConfig",
|
|
135
139
|
"RBLNOPTModel",
|
|
@@ -246,6 +250,8 @@ if TYPE_CHECKING:
|
|
|
246
250
|
RBLNDecoderOnlyModelForCausalLMConfig,
|
|
247
251
|
RBLNDepthAnythingForDepthEstimation,
|
|
248
252
|
RBLNDepthAnythingForDepthEstimationConfig,
|
|
253
|
+
RBLNDetrForObjectDetection,
|
|
254
|
+
RBLNDetrForObjectDetectionConfig,
|
|
249
255
|
RBLNDistilBertForQuestionAnswering,
|
|
250
256
|
RBLNDistilBertForQuestionAnsweringConfig,
|
|
251
257
|
RBLNDPTForDepthEstimation,
|
|
@@ -296,6 +302,8 @@ if TYPE_CHECKING:
|
|
|
296
302
|
RBLNMistralForCausalLMConfig,
|
|
297
303
|
RBLNMistralModel,
|
|
298
304
|
RBLNMistralModelConfig,
|
|
305
|
+
RBLNMixtralForCausalLM,
|
|
306
|
+
RBLNMixtralForCausalLMConfig,
|
|
299
307
|
RBLNOPTForCausalLM,
|
|
300
308
|
RBLNOPTForCausalLMConfig,
|
|
301
309
|
RBLNOPTModel,
|