optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +44 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +230 -67
- optimum/rbln/diffusers/models/controlnet.py +2 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
- optimum/rbln/modeling_base.py +11 -10
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +44 -0
- optimum/rbln/transformers/modeling_attention_utils.py +124 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +38 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
- optimum/rbln/transformers/models/detr/__init__.py +23 -0
- optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
- optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
- optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
- optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/deprecation.py +78 -1
- optimum/rbln/utils/hub.py +93 -2
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +12 -8
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
|
@@ -86,11 +86,17 @@ _import_structure = {
|
|
|
86
86
|
"RBLNDPTForDepthEstimationConfig",
|
|
87
87
|
"RBLNDepthAnythingForDepthEstimationConfig",
|
|
88
88
|
"RBLNDepthAnythingForDepthEstimation",
|
|
89
|
+
"RBLNDetrForObjectDetection",
|
|
90
|
+
"RBLNDetrForObjectDetectionConfig",
|
|
89
91
|
"RBLNExaoneForCausalLM",
|
|
90
92
|
"RBLNExaoneForCausalLMConfig",
|
|
91
93
|
"RBLNGemmaModel",
|
|
92
94
|
"RBLNGemmaModelConfig",
|
|
93
95
|
"RBLNGemmaForCausalLM",
|
|
96
|
+
"RBLNGemma2ForCausalLM",
|
|
97
|
+
"RBLNGemma2ForCausalLMConfig",
|
|
98
|
+
"RBLNGemma2Model",
|
|
99
|
+
"RBLNGemma2ModelConfig",
|
|
94
100
|
"RBLNGemmaForCausalLMConfig",
|
|
95
101
|
"RBLNGemma3ForCausalLM",
|
|
96
102
|
"RBLNGemma3ForCausalLMConfig",
|
|
@@ -100,6 +106,8 @@ _import_structure = {
|
|
|
100
106
|
"RBLNGPT2ModelConfig",
|
|
101
107
|
"RBLNGPT2LMHeadModel",
|
|
102
108
|
"RBLNGPT2LMHeadModelConfig",
|
|
109
|
+
"RBLNGptOssForCausalLM",
|
|
110
|
+
"RBLNGptOssForCausalLMConfig",
|
|
103
111
|
"RBLNGroundingDinoDecoder",
|
|
104
112
|
"RBLNGroundingDinoDecoderConfig",
|
|
105
113
|
"RBLNGroundingDinoForObjectDetection",
|
|
@@ -114,6 +122,8 @@ _import_structure = {
|
|
|
114
122
|
"RBLNLlamaForCausalLMConfig",
|
|
115
123
|
"RBLNLlamaModel",
|
|
116
124
|
"RBLNLlamaModelConfig",
|
|
125
|
+
"RBLNMixtralForCausalLM",
|
|
126
|
+
"RBLNMixtralForCausalLMConfig",
|
|
117
127
|
"RBLNOPTForCausalLM",
|
|
118
128
|
"RBLNOPTForCausalLMConfig",
|
|
119
129
|
"RBLNLlavaForConditionalGeneration",
|
|
@@ -140,14 +150,24 @@ _import_structure = {
|
|
|
140
150
|
"RBLNPixtralVisionModelConfig",
|
|
141
151
|
"RBLNPhiModel",
|
|
142
152
|
"RBLNPhiModelConfig",
|
|
153
|
+
"RBLNPaliGemmaForConditionalGeneration",
|
|
154
|
+
"RBLNPaliGemmaForConditionalGenerationConfig",
|
|
155
|
+
"RBLNPaliGemmaModel",
|
|
156
|
+
"RBLNPaliGemmaModelConfig",
|
|
143
157
|
"RBLNQwen2ForCausalLM",
|
|
144
158
|
"RBLNQwen2ForCausalLMConfig",
|
|
145
159
|
"RBLNQwen2_5_VisionTransformerPretrainedModel",
|
|
146
160
|
"RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
|
|
147
161
|
"RBLNQwen2_5_VLForConditionalGeneration",
|
|
148
162
|
"RBLNQwen2_5_VLForConditionalGenerationConfig",
|
|
163
|
+
"RBLNQwen3MoeForCausalLM",
|
|
164
|
+
"RBLNQwen3MoeForCausalLMConfig",
|
|
165
|
+
"RBLNQwen2_5_VLModel",
|
|
166
|
+
"RBLNQwen2_5_VLModelConfig",
|
|
149
167
|
"RBLNQwen2Model",
|
|
150
168
|
"RBLNQwen2ModelConfig",
|
|
169
|
+
"RBLNQwen2MoeForCausalLM",
|
|
170
|
+
"RBLNQwen2MoeForCausalLMConfig",
|
|
151
171
|
"RBLNQwen3ForCausalLM",
|
|
152
172
|
"RBLNQwen3ForCausalLMConfig",
|
|
153
173
|
"RBLNQwen3Model",
|
|
@@ -156,6 +176,8 @@ _import_structure = {
|
|
|
156
176
|
"RBLNQwen2VisionTransformerPretrainedModelConfig",
|
|
157
177
|
"RBLNQwen2VLForConditionalGeneration",
|
|
158
178
|
"RBLNQwen2VLForConditionalGenerationConfig",
|
|
179
|
+
"RBLNQwen2VLModel",
|
|
180
|
+
"RBLNQwen2VLModelConfig",
|
|
159
181
|
"RBLNResNetForImageClassification",
|
|
160
182
|
"RBLNResNetForImageClassificationConfig",
|
|
161
183
|
"RBLNRobertaForMaskedLM",
|
|
@@ -388,12 +410,18 @@ if TYPE_CHECKING:
|
|
|
388
410
|
RBLNDecoderOnlyModelForCausalLMConfig,
|
|
389
411
|
RBLNDepthAnythingForDepthEstimation,
|
|
390
412
|
RBLNDepthAnythingForDepthEstimationConfig,
|
|
413
|
+
RBLNDetrForObjectDetection,
|
|
414
|
+
RBLNDetrForObjectDetectionConfig,
|
|
391
415
|
RBLNDistilBertForQuestionAnswering,
|
|
392
416
|
RBLNDistilBertForQuestionAnsweringConfig,
|
|
393
417
|
RBLNDPTForDepthEstimation,
|
|
394
418
|
RBLNDPTForDepthEstimationConfig,
|
|
395
419
|
RBLNExaoneForCausalLM,
|
|
396
420
|
RBLNExaoneForCausalLMConfig,
|
|
421
|
+
RBLNGemma2ForCausalLM,
|
|
422
|
+
RBLNGemma2ForCausalLMConfig,
|
|
423
|
+
RBLNGemma2Model,
|
|
424
|
+
RBLNGemma2ModelConfig,
|
|
397
425
|
RBLNGemma3ForCausalLM,
|
|
398
426
|
RBLNGemma3ForCausalLMConfig,
|
|
399
427
|
RBLNGemma3ForConditionalGeneration,
|
|
@@ -406,6 +434,8 @@ if TYPE_CHECKING:
|
|
|
406
434
|
RBLNGPT2LMHeadModelConfig,
|
|
407
435
|
RBLNGPT2Model,
|
|
408
436
|
RBLNGPT2ModelConfig,
|
|
437
|
+
RBLNGptOssForCausalLM,
|
|
438
|
+
RBLNGptOssForCausalLMConfig,
|
|
409
439
|
RBLNGroundingDinoDecoder,
|
|
410
440
|
RBLNGroundingDinoDecoderConfig,
|
|
411
441
|
RBLNGroundingDinoEncoder,
|
|
@@ -432,10 +462,16 @@ if TYPE_CHECKING:
|
|
|
432
462
|
RBLNMistralForCausalLMConfig,
|
|
433
463
|
RBLNMistralModel,
|
|
434
464
|
RBLNMistralModelConfig,
|
|
465
|
+
RBLNMixtralForCausalLM,
|
|
466
|
+
RBLNMixtralForCausalLMConfig,
|
|
435
467
|
RBLNOPTForCausalLM,
|
|
436
468
|
RBLNOPTForCausalLMConfig,
|
|
437
469
|
RBLNOPTModel,
|
|
438
470
|
RBLNOPTModelConfig,
|
|
471
|
+
RBLNPaliGemmaForConditionalGeneration,
|
|
472
|
+
RBLNPaliGemmaForConditionalGenerationConfig,
|
|
473
|
+
RBLNPaliGemmaModel,
|
|
474
|
+
RBLNPaliGemmaModelConfig,
|
|
439
475
|
RBLNPegasusForConditionalGeneration,
|
|
440
476
|
RBLNPegasusForConditionalGenerationConfig,
|
|
441
477
|
RBLNPegasusModel,
|
|
@@ -450,18 +486,26 @@ if TYPE_CHECKING:
|
|
|
450
486
|
RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
|
|
451
487
|
RBLNQwen2_5_VLForConditionalGeneration,
|
|
452
488
|
RBLNQwen2_5_VLForConditionalGenerationConfig,
|
|
489
|
+
RBLNQwen2_5_VLModel,
|
|
490
|
+
RBLNQwen2_5_VLModelConfig,
|
|
453
491
|
RBLNQwen2ForCausalLM,
|
|
454
492
|
RBLNQwen2ForCausalLMConfig,
|
|
455
493
|
RBLNQwen2Model,
|
|
456
494
|
RBLNQwen2ModelConfig,
|
|
495
|
+
RBLNQwen2MoeForCausalLM,
|
|
496
|
+
RBLNQwen2MoeForCausalLMConfig,
|
|
457
497
|
RBLNQwen2VisionTransformerPretrainedModel,
|
|
458
498
|
RBLNQwen2VisionTransformerPretrainedModelConfig,
|
|
459
499
|
RBLNQwen2VLForConditionalGeneration,
|
|
460
500
|
RBLNQwen2VLForConditionalGenerationConfig,
|
|
501
|
+
RBLNQwen2VLModel,
|
|
502
|
+
RBLNQwen2VLModelConfig,
|
|
461
503
|
RBLNQwen3ForCausalLM,
|
|
462
504
|
RBLNQwen3ForCausalLMConfig,
|
|
463
505
|
RBLNQwen3Model,
|
|
464
506
|
RBLNQwen3ModelConfig,
|
|
507
|
+
RBLNQwen3MoeForCausalLM,
|
|
508
|
+
RBLNQwen3MoeForCausalLMConfig,
|
|
465
509
|
RBLNResNetForImageClassification,
|
|
466
510
|
RBLNResNetForImageClassificationConfig,
|
|
467
511
|
RBLNRobertaForMaskedLM,
|
optimum/rbln/__version__.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.
|
|
32
|
-
__version_tuple__ = version_tuple = (0,
|
|
31
|
+
__version__ = version = '0.10.0.post1'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 10, 0, 'post1')
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
optimum/rbln/configuration_utils.py
CHANGED
|
@@ -24,7 +24,7 @@ import torch
|
|
|
24
24
|
from packaging.version import Version
|
|
25
25
|
|
|
26
26
|
from .__version__ import __version__
|
|
27
|
-
from .utils.deprecation import warn_deprecated_npu
|
|
27
|
+
from .utils.deprecation import deprecate_kwarg, deprecate_method, warn_deprecated_npu
|
|
28
28
|
from .utils.logging import get_logger
|
|
29
29
|
from .utils.runtime_utils import ContextRblnConfig
|
|
30
30
|
|
|
@@ -36,6 +36,30 @@ DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
|
|
|
36
36
|
TypeInputInfo = List[Tuple[str, Tuple[int], str]]
|
|
37
37
|
|
|
38
38
|
|
|
39
|
+
def nested_update(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
|
|
40
|
+
"""
|
|
41
|
+
Recursively merge override dict into base dict.
|
|
42
|
+
For nested dicts, values are merged recursively instead of being replaced.
|
|
43
|
+
For non-dict values, override takes precedence.
|
|
44
|
+
Args:
|
|
45
|
+
base: The base dictionary to merge into (modified in-place).
|
|
46
|
+
override: The dictionary with values to merge.
|
|
47
|
+
Returns:
|
|
48
|
+
The merged base dictionary.
|
|
49
|
+
Example:
|
|
50
|
+
>>> base = {"a": 1, "nested": {"x": 10, "y": 20}}
|
|
51
|
+
>>> override = {"b": 2, "nested": {"y": 30, "z": 40}}
|
|
52
|
+
>>> nested_update(base, override)
|
|
53
|
+
{"a": 1, "b": 2, "nested": {"x": 10, "y": 30, "z": 40}}
|
|
54
|
+
"""
|
|
55
|
+
for key, value in override.items():
|
|
56
|
+
if key in base and isinstance(base[key], dict) and isinstance(value, dict):
|
|
57
|
+
nested_update(base[key], value)
|
|
58
|
+
else:
|
|
59
|
+
base[key] = value
|
|
60
|
+
return base
|
|
61
|
+
|
|
62
|
+
|
|
39
63
|
@runtime_checkable
|
|
40
64
|
class RBLNSerializableConfigProtocol(Protocol):
|
|
41
65
|
def _prepare_for_serialization(self) -> Dict[str, Any]: ...
|
|
@@ -92,7 +116,7 @@ class RBLNCompileConfig:
|
|
|
92
116
|
and isinstance(item[0], str) # name
|
|
93
117
|
and isinstance(item[1], (tuple, list)) # shape
|
|
94
118
|
and all(isinstance(x, int) for x in item[1])
|
|
95
|
-
and isinstance(item[2], str) # dtype
|
|
119
|
+
and (isinstance(item[2], str) or isinstance(item[2], torch.dtype)) # dtype
|
|
96
120
|
for item in input_info
|
|
97
121
|
)
|
|
98
122
|
|
|
@@ -216,8 +240,7 @@ class RBLNAutoConfig:
|
|
|
216
240
|
For example, the parsed contents of `rbln_config.json`.
|
|
217
241
|
|
|
218
242
|
Returns:
|
|
219
|
-
RBLNModelConfig: A configuration instance. The specific subclass is
|
|
220
|
-
selected by `config_dict["cls_name"]`.
|
|
243
|
+
RBLNModelConfig: A configuration instance. The specific subclass is selected by `config_dict["cls_name"]`.
|
|
221
244
|
|
|
222
245
|
Raises:
|
|
223
246
|
ValueError: If `cls_name` is missing.
|
|
@@ -256,12 +279,13 @@ class RBLNAutoConfig:
|
|
|
256
279
|
|
|
257
280
|
CONFIG_MAPPING[config.__name__] = config
|
|
258
281
|
|
|
259
|
-
@
|
|
260
|
-
def
|
|
282
|
+
@classmethod
|
|
283
|
+
def from_pretrained(
|
|
284
|
+
cls,
|
|
261
285
|
path: str,
|
|
262
|
-
|
|
263
|
-
kwargs: Optional[Dict[str, Any]] = None,
|
|
286
|
+
rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
|
|
264
287
|
return_unused_kwargs: bool = False,
|
|
288
|
+
**kwargs: Optional[Dict[str, Any]],
|
|
265
289
|
) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
|
|
266
290
|
"""
|
|
267
291
|
Load RBLNModelConfig from a path.
|
|
@@ -269,53 +293,58 @@ class RBLNAutoConfig:
|
|
|
269
293
|
|
|
270
294
|
Args:
|
|
271
295
|
path (str): Path to the RBLNModelConfig.
|
|
272
|
-
|
|
296
|
+
rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
|
|
297
|
+
return_unused_kwargs (bool): Whether to return unused kwargs.
|
|
298
|
+
kwargs: Additional keyword arguments to override configuration values.
|
|
273
299
|
|
|
274
300
|
Returns:
|
|
275
301
|
RBLNModelConfig: The loaded RBLNModelConfig.
|
|
276
|
-
"""
|
|
277
|
-
if kwargs is None:
|
|
278
|
-
kwargs = {}
|
|
279
|
-
cls, config_file = load_config(path)
|
|
280
|
-
|
|
281
|
-
rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
|
|
282
|
-
rbln_runtime_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in RUNTIME_KEYWORDS}
|
|
283
|
-
rbln_submodule_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in cls.submodules}
|
|
284
302
|
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
raise ValueError(f"Submodule {submodule} not found in rbln_config.json.")
|
|
295
|
-
submodule_config = config_file[submodule]
|
|
296
|
-
submodule_config.update(rbln_submodule_kwargs.pop(submodule, {}))
|
|
297
|
-
config_file[submodule] = RBLNAutoConfig.load_from_dict(submodule_config)
|
|
303
|
+
Examples:
|
|
304
|
+
```python
|
|
305
|
+
config = RBLNAutoConfig.from_pretrained("/path/to/model")
|
|
306
|
+
```
|
|
307
|
+
"""
|
|
308
|
+
target_cls, _ = load_config(path)
|
|
309
|
+
return target_cls.from_pretrained(
|
|
310
|
+
path, rbln_config=rbln_config, return_unused_kwargs=return_unused_kwargs, **kwargs
|
|
311
|
+
)
|
|
298
312
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
313
|
+
@classmethod
|
|
314
|
+
@deprecate_method(version="0.11.0", new_method="from_pretrained")
|
|
315
|
+
def load(
|
|
316
|
+
cls,
|
|
317
|
+
path: str,
|
|
318
|
+
rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
|
|
319
|
+
return_unused_kwargs: bool = False,
|
|
320
|
+
**kwargs: Optional[Dict[str, Any]],
|
|
321
|
+
) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
|
|
322
|
+
"""
|
|
323
|
+
Load RBLNModelConfig from a path.
|
|
324
|
+
Class name is automatically inferred from the `rbln_config.json` file.
|
|
302
325
|
|
|
303
|
-
|
|
326
|
+
Deprecated:
|
|
327
|
+
This method is deprecated and will be removed in version 0.11.0.
|
|
328
|
+
Use `from_pretrained` instead.
|
|
304
329
|
|
|
305
|
-
|
|
330
|
+
Args:
|
|
331
|
+
path (str): Path to the RBLNModelConfig file or directory.
|
|
332
|
+
rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
|
|
333
|
+
return_unused_kwargs (bool): Whether to return unused kwargs.
|
|
334
|
+
kwargs: Additional keyword arguments to override configuration values.
|
|
306
335
|
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
if getattr(rbln_config, key) != value:
|
|
310
|
-
raise ValueError(
|
|
311
|
-
f"Cannot set the following arguments: {list(rbln_kwargs.keys())} "
|
|
312
|
-
f"Since the value is already set to {getattr(rbln_config, key)}"
|
|
313
|
-
)
|
|
336
|
+
Returns:
|
|
337
|
+
RBLNModelConfig: The loaded RBLNModelConfig.
|
|
314
338
|
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
339
|
+
Examples:
|
|
340
|
+
```python
|
|
341
|
+
# Deprecated usage:
|
|
342
|
+
config = RBLNAutoConfig.load("/path/to/model")
|
|
343
|
+
# Recommended usage:
|
|
344
|
+
config = RBLNAutoConfig.from_pretrained("/path/to/model")
|
|
345
|
+
```
|
|
346
|
+
"""
|
|
347
|
+
return cls.from_pretrained(path, rbln_config=rbln_config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
|
319
348
|
|
|
320
349
|
|
|
321
350
|
class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
@@ -524,8 +553,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
524
553
|
non_save_attributes = [
|
|
525
554
|
"_frozen",
|
|
526
555
|
"_runtime_options",
|
|
527
|
-
"torch_dtype",
|
|
528
556
|
"npu",
|
|
557
|
+
"dtype",
|
|
529
558
|
"tensor_parallel_size",
|
|
530
559
|
"create_runtimes",
|
|
531
560
|
"device",
|
|
@@ -650,6 +679,14 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
650
679
|
|
|
651
680
|
super().__setattr__(key, value)
|
|
652
681
|
|
|
682
|
+
@deprecate_kwarg(
|
|
683
|
+
old_name="_torch_dtype",
|
|
684
|
+
new_name="dtype",
|
|
685
|
+
version="0.12.0",
|
|
686
|
+
deprecated_type=torch.dtype,
|
|
687
|
+
value_replacer=RBLNCompileConfig.normalize_dtype,
|
|
688
|
+
raise_if_greater_or_equal_version=False,
|
|
689
|
+
)
|
|
653
690
|
def __init__(
|
|
654
691
|
self,
|
|
655
692
|
cls_name: Optional[str] = None,
|
|
@@ -661,7 +698,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
661
698
|
tensor_parallel_size: Optional[int] = None,
|
|
662
699
|
timeout: Optional[int] = None,
|
|
663
700
|
optimum_rbln_version: Optional[str] = None,
|
|
664
|
-
|
|
701
|
+
dtype: Optional[Union[str, torch.dtype]] = None,
|
|
665
702
|
_compile_cfgs: Optional[List[RBLNCompileConfig]] = None,
|
|
666
703
|
*,
|
|
667
704
|
optimize_host_memory: Optional[bool] = None,
|
|
@@ -680,7 +717,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
680
717
|
tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
|
|
681
718
|
timeout (Optional[int]): The timeout for the runtime in seconds. If it isn't provided, it will be set to 60 by default.
|
|
682
719
|
optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
|
|
683
|
-
|
|
720
|
+
dtype (Optional[Union[str, torch.dtype]]): The data type to use for the model.
|
|
684
721
|
_compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
|
|
685
722
|
kwargs: Additional keyword arguments.
|
|
686
723
|
|
|
@@ -710,7 +747,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
710
747
|
self.npu = npu
|
|
711
748
|
self.tensor_parallel_size = tensor_parallel_size
|
|
712
749
|
|
|
713
|
-
|
|
750
|
+
if dtype is not None and isinstance(dtype, torch.dtype):
|
|
751
|
+
dtype = RBLNCompileConfig.normalize_dtype(dtype)
|
|
752
|
+
self._dtype = dtype or "float32"
|
|
714
753
|
self.optimum_rbln_version = optimum_rbln_version
|
|
715
754
|
if self.optimum_rbln_version is None:
|
|
716
755
|
self.optimum_rbln_version = __version__
|
|
@@ -743,14 +782,24 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
743
782
|
|
|
744
783
|
@property
|
|
745
784
|
def torch_dtype(self):
|
|
746
|
-
|
|
785
|
+
logger.warning_once("`torch_dtype` is deprecated. Use `dtype` instead.")
|
|
786
|
+
return self.dtype
|
|
747
787
|
|
|
748
788
|
@torch_dtype.setter
|
|
749
789
|
def torch_dtype(self, torch_dtype: Union[str, torch.dtype]):
|
|
750
|
-
|
|
751
|
-
|
|
790
|
+
logger.warning_once("`torch_dtype` is deprecated. Use `dtype` instead.")
|
|
791
|
+
self.dtype = torch_dtype
|
|
792
|
+
|
|
793
|
+
@property
|
|
794
|
+
def dtype(self):
|
|
795
|
+
return getattr(torch, self._dtype)
|
|
752
796
|
|
|
753
|
-
|
|
797
|
+
@dtype.setter
|
|
798
|
+
def dtype(self, dtype: Union[str, torch.dtype]):
|
|
799
|
+
if isinstance(dtype, torch.dtype):
|
|
800
|
+
dtype = RBLNCompileConfig.normalize_dtype(dtype)
|
|
801
|
+
|
|
802
|
+
self._dtype = dtype
|
|
754
803
|
|
|
755
804
|
@property
|
|
756
805
|
def rbln_model_cls_name(self) -> str:
|
|
@@ -774,10 +823,15 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
774
823
|
if isinstance(value, RBLNSerializableConfigProtocol):
|
|
775
824
|
# Convert nested RBLNModelConfig to its serializable form
|
|
776
825
|
serializable_map[key] = value._prepare_for_serialization()
|
|
826
|
+
elif key == "_dtype":
|
|
827
|
+
serializable_map["dtype"] = value
|
|
828
|
+
elif isinstance(value, list) and all(isinstance(item, RBLNSerializableConfigProtocol) for item in value):
|
|
829
|
+
serializable_map[key] = [item._prepare_for_serialization() for item in value]
|
|
777
830
|
elif key == "_compile_cfgs":
|
|
778
831
|
serializable_map[key] = [cfg.asdict() for cfg in value]
|
|
779
832
|
else:
|
|
780
833
|
serializable_map[key] = value
|
|
834
|
+
|
|
781
835
|
return serializable_map
|
|
782
836
|
|
|
783
837
|
def __repr__(self):
|
|
@@ -825,18 +879,12 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
825
879
|
if not isinstance(submodule_config, RBLNModelConfig):
|
|
826
880
|
raise ValueError(f"`{submodule_name}` must be an instance of `RBLNModelConfig` before freezing.")
|
|
827
881
|
|
|
828
|
-
if not submodule_config.is_frozen():
|
|
829
|
-
raise ValueError(f"`{submodule_name}` config must be frozen before freezing super config.")
|
|
830
|
-
|
|
831
882
|
self._frozen = True
|
|
832
883
|
|
|
833
884
|
def is_frozen(self):
|
|
834
885
|
return self._frozen
|
|
835
886
|
|
|
836
887
|
def save(self, path: str):
|
|
837
|
-
if not self._frozen:
|
|
838
|
-
raise RuntimeError("`RBLNModelConfig` is not frozen. Please call `set_compile_cfgs` first.")
|
|
839
|
-
|
|
840
888
|
# save as json file without runtime attributes
|
|
841
889
|
path = Path(path)
|
|
842
890
|
if path.is_dir():
|
|
@@ -847,15 +895,23 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
847
895
|
json.dump(serializable_data, jsonf, indent=2)
|
|
848
896
|
|
|
849
897
|
@classmethod
|
|
850
|
-
def
|
|
898
|
+
def from_pretrained(
|
|
899
|
+
cls,
|
|
900
|
+
path: str,
|
|
901
|
+
rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
|
|
902
|
+
return_unused_kwargs: bool = False,
|
|
903
|
+
**kwargs: Optional[Dict[str, Any]],
|
|
904
|
+
) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
|
|
851
905
|
"""
|
|
852
906
|
Load a RBLNModelConfig from a path.
|
|
853
907
|
|
|
854
908
|
Args:
|
|
855
909
|
path (str): Path to the RBLNModelConfig file or directory containing the config file.
|
|
910
|
+
rbln_config (Optional[Dict[str, Any]]): Additional configuration to override.
|
|
911
|
+
return_unused_kwargs (bool): Whether to return unused kwargs.
|
|
856
912
|
kwargs: Additional keyword arguments to override configuration values.
|
|
857
|
-
|
|
858
|
-
|
|
913
|
+
Keys starting with 'rbln_' will have the prefix removed and be used
|
|
914
|
+
to update the configuration.
|
|
859
915
|
|
|
860
916
|
Returns:
|
|
861
917
|
RBLNModelConfig: The loaded configuration instance.
|
|
@@ -864,17 +920,109 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
864
920
|
This method loads the configuration from the specified path and applies any
|
|
865
921
|
provided overrides. If the loaded configuration class doesn't match the expected
|
|
866
922
|
class, a warning will be logged.
|
|
923
|
+
|
|
924
|
+
Examples:
|
|
925
|
+
```python
|
|
926
|
+
config = RBLNResNetForImageClassificationConfig.from_pretrained("/path/to/model")
|
|
927
|
+
```
|
|
867
928
|
"""
|
|
868
929
|
cls_reserved, config_file = load_config(path)
|
|
869
|
-
|
|
870
930
|
if cls_reserved != cls:
|
|
871
931
|
logger.warning(f"Expected {cls.__name__}, but got {cls_reserved.__name__}.")
|
|
872
932
|
|
|
933
|
+
if isinstance(rbln_config, dict):
|
|
934
|
+
for key, value in rbln_config.items():
|
|
935
|
+
if key not in kwargs:
|
|
936
|
+
kwargs[f"rbln_{key}"] = value
|
|
937
|
+
|
|
873
938
|
rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
|
|
874
|
-
|
|
875
|
-
|
|
939
|
+
rbln_runtime_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in RUNTIME_KEYWORDS}
|
|
940
|
+
rbln_submodule_kwargs = {key[5:]: kwargs.pop(key) for key in rbln_keys if key[5:] in cls.submodules}
|
|
941
|
+
|
|
942
|
+
rbln_kwargs = {
|
|
943
|
+
key[5:]: kwargs.pop(key)
|
|
944
|
+
for key in rbln_keys
|
|
945
|
+
if key[5:] not in RUNTIME_KEYWORDS and key[5:] not in cls.submodules
|
|
946
|
+
}
|
|
947
|
+
|
|
948
|
+
# Process submodule's rbln_config
|
|
949
|
+
for submodule in cls.submodules:
|
|
950
|
+
if submodule not in config_file:
|
|
951
|
+
raise ValueError(f"Submodule {submodule} not found in rbln_config.json.")
|
|
952
|
+
submodule_config = config_file[submodule]
|
|
953
|
+
submodule_config.update(rbln_runtime_kwargs)
|
|
954
|
+
|
|
955
|
+
update_dict = rbln_submodule_kwargs.pop(submodule, {})
|
|
956
|
+
if update_dict:
|
|
957
|
+
nested_update(submodule_config, update_dict)
|
|
958
|
+
config_file[submodule] = RBLNAutoConfig.load_from_dict(submodule_config)
|
|
959
|
+
|
|
960
|
+
if isinstance(rbln_config, RBLNModelConfig):
|
|
961
|
+
config_file.update(rbln_config._runtime_options)
|
|
962
|
+
|
|
963
|
+
# update submodule runtime
|
|
964
|
+
for submodule in rbln_config.submodules:
|
|
965
|
+
if str(config_file[submodule]) != str(getattr(rbln_config, submodule)):
|
|
966
|
+
raise ValueError(
|
|
967
|
+
f"Passed rbln_config has different attributes for submodule {submodule} than the config_file"
|
|
968
|
+
)
|
|
969
|
+
config_file[submodule] = getattr(rbln_config, submodule)
|
|
970
|
+
|
|
971
|
+
config_file.update(rbln_runtime_kwargs)
|
|
972
|
+
rbln_config = cls(**config_file)
|
|
973
|
+
if len(rbln_kwargs) > 0:
|
|
974
|
+
for key, value in rbln_kwargs.items():
|
|
975
|
+
if getattr(rbln_config, key) != value:
|
|
976
|
+
raise ValueError(
|
|
977
|
+
f"Cannot set the following arguments: {list(rbln_kwargs.keys())} "
|
|
978
|
+
f"Since the value is already set to {getattr(rbln_config, key)}"
|
|
979
|
+
)
|
|
980
|
+
if return_unused_kwargs:
|
|
981
|
+
return rbln_config, kwargs
|
|
982
|
+
else:
|
|
983
|
+
return rbln_config
|
|
876
984
|
|
|
877
|
-
|
|
985
|
+
@classmethod
@deprecate_method(version="0.11.0", new_method="from_pretrained")
def load(
    cls,
    path: str,
    rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
    return_unused_kwargs: bool = False,
    **kwargs: Optional[Dict[str, Any]],
) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
    """
    Deprecated alias of `from_pretrained`.

    Deprecated:
        This method is deprecated and will be removed in version 0.11.0.
        Use `from_pretrained` instead.

    Every argument is forwarded unchanged to `cls.from_pretrained`, and the
    value it returns is handed back as-is; this method adds no logic of its own.

    Args:
        path (str): Path to the RBLNModelConfig file or directory containing the config file.
        rbln_config (Optional[Union[Dict[str, Any], RBLNModelConfig]]): Configuration
            overrides, given either as a plain dictionary or as a config instance.
        return_unused_kwargs (bool): Whether to also return the keyword arguments
            that were not consumed while loading.
        kwargs: Additional keyword arguments passed through to `from_pretrained`.
            Keys starting with 'rbln_' are interpreted there with the prefix removed.

    Returns:
        RBLNModelConfig: The loaded configuration instance, or a
        `(config, unused_kwargs)` tuple when `return_unused_kwargs` is True.

    Examples:
        ```python
        # Deprecated usage:
        config = RBLNResNetForImageClassificationConfig.load("/path/to/model")
        # Recommended usage:
        config = RBLNResNetForImageClassificationConfig.from_pretrained("/path/to/model")
        ```
    """
    # Pure delegation: keep the call explicit so the forwarded names stay greppable.
    loaded = cls.from_pretrained(
        path,
        rbln_config=rbln_config,
        return_unused_kwargs=return_unused_kwargs,
        **kwargs,
    )
    return loaded
|
|
878
1026
|
|
|
879
1027
|
@classmethod
|
|
880
1028
|
def initialize_from_kwargs(
|
|
@@ -974,3 +1122,18 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
974
1122
|
@timeout.setter
|
|
975
1123
|
def timeout(self, timeout: int):
|
|
976
1124
|
self._runtime_options["timeout"] = timeout
|
|
1125
|
+
|
|
1126
|
+
|
|
1127
|
+
def convert_rbln_config_dict(
    rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None, **kwargs
) -> Tuple[Optional[Union[Dict[str, Any], "RBLNModelConfig"]], Dict[str, Any]]:
    """
    Split `rbln_`-prefixed keyword arguments out of `kwargs` and merge them into `rbln_config`.

    Args:
        rbln_config: Existing configuration, either a dictionary or a config
            instance. `None` is treated as an empty dictionary.
        **kwargs: Arbitrary keyword arguments. Every key starting with `rbln_`
            is removed from the returned kwargs and, with the prefix stripped,
            becomes a configuration override.

    Returns:
        A `(rbln_config, remaining_kwargs)` tuple. When `rbln_config` is a
        dictionary and overrides are present, a new merged dictionary is
        returned and the caller's dictionary is left untouched. Otherwise the
        object that was passed in (or a fresh empty dictionary for `None`) is
        returned unchanged.

    Note:
        When `rbln_config` is not a dictionary, the `rbln_`-prefixed kwargs are
        still removed from the returned kwargs but are not applied to it.
    """
    prefix = "rbln_"

    # Snapshot the keys first: entries are popped from `kwargs` while iterating.
    rbln_kwargs = {key[len(prefix):]: kwargs.pop(key) for key in list(kwargs) if key.startswith(prefix)}

    if rbln_config is None:
        rbln_config = {}

    if isinstance(rbln_config, dict) and rbln_kwargs:
        # Merge into a new dict instead of updating in place, so a dictionary
        # the caller reuses for another model does not silently pick up these overrides.
        rbln_config = {**rbln_config, **rbln_kwargs}

    return rbln_config, kwargs
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import TYPE_CHECKING, Dict, Optional, Union
|
|
15
|
+
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
from diffusers import ControlNetModel
|
|
@@ -218,7 +218,7 @@ class RBLNControlNetModel(RBLNModel):
|
|
|
218
218
|
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
|
219
219
|
return_dict: bool = True,
|
|
220
220
|
**kwargs,
|
|
221
|
-
):
|
|
221
|
+
) -> Union[ControlNetOutput, Tuple]:
|
|
222
222
|
"""
|
|
223
223
|
Forward pass for the RBLN-optimized ControlNetModel.
|
|
224
224
|
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
from pathlib import Path
|
|
16
|
-
from typing import TYPE_CHECKING, Optional, Union
|
|
16
|
+
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
19
|
from diffusers.models.transformers.prior_transformer import PriorTransformer, PriorTransformerOutput
|
|
@@ -134,7 +134,7 @@ class RBLNPriorTransformer(RBLNModel):
|
|
|
134
134
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
135
135
|
attention_mask: Optional[torch.Tensor] = None,
|
|
136
136
|
return_dict: bool = True,
|
|
137
|
-
):
|
|
137
|
+
) -> Union[PriorTransformerOutput, Tuple]:
|
|
138
138
|
"""
|
|
139
139
|
Forward pass for the RBLN-optimized PriorTransformer.
|
|
140
140
|
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
from pathlib import Path
|
|
16
|
-
from typing import TYPE_CHECKING, List, Optional, Union
|
|
16
|
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
|
17
17
|
|
|
18
18
|
import rebel
|
|
19
19
|
import torch
|
|
@@ -302,7 +302,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
|
|
|
302
302
|
condition_mask: Optional[torch.Tensor] = None,
|
|
303
303
|
padding_mask: Optional[torch.Tensor] = None,
|
|
304
304
|
return_dict: bool = True,
|
|
305
|
-
):
|
|
305
|
+
) -> Union[Transformer2DModelOutput, Tuple]:
|
|
306
306
|
"""
|
|
307
307
|
Forward pass for the RBLN-optimized CosmosTransformer3DModel.
|
|
308
308
|
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
|
15
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
|
@@ -160,7 +160,7 @@ class RBLNSD3Transformer2DModel(RBLNModel):
|
|
|
160
160
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
|
161
161
|
return_dict: bool = True,
|
|
162
162
|
**kwargs,
|
|
163
|
-
):
|
|
163
|
+
) -> Union[Transformer2DModelOutput, Tuple]:
|
|
164
164
|
"""
|
|
165
165
|
Forward pass for the RBLN-optimized SD3Transformer2DModel.
|
|
166
166
|
|