optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +24 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +45 -33
- optimum/rbln/diffusers/__init__.py +21 -1
- optimum/rbln/diffusers/configurations/__init__.py +4 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
- optimum/rbln/diffusers/modeling_diffusers.py +72 -65
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
- optimum/rbln/diffusers/models/controlnet.py +14 -8
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
- optimum/rbln/diffusers/pipelines/__init__.py +10 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +71 -37
- optimum/rbln/modeling_base.py +63 -109
- optimum/rbln/transformers/__init__.py +41 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +21 -22
- optimum/rbln/transformers/modeling_rope_utils.py +5 -2
- optimum/rbln/transformers/models/__init__.py +54 -4
- optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
- optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
- optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
- optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
- optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
- optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
- optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
- optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
- optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/runtime_utils.py +49 -1
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
- optimum_rbln-0.8.1.dist-info/RECORD +211 -0
- optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
@@ -19,10 +19,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
|
|
19
19
|
|
20
20
|
import torch
|
21
21
|
|
22
|
-
from ..configuration_utils import ContextRblnConfig, RBLNModelConfig
|
22
|
+
from ..configuration_utils import ContextRblnConfig, RBLNModelConfig, get_rbln_config_class
|
23
23
|
from ..modeling import RBLNModel
|
24
24
|
from ..utils.decorator_utils import remove_compile_time_kwargs
|
25
25
|
from ..utils.logging import get_logger
|
26
|
+
from ..utils.model_utils import get_rbln_model_cls
|
26
27
|
|
27
28
|
|
28
29
|
logger = get_logger(__name__)
|
@@ -44,7 +45,7 @@ class RBLNDiffusionMixin:
|
|
44
45
|
To use this mixin:
|
45
46
|
|
46
47
|
1. Create a new pipeline class that inherits from both this mixin and the original StableDiffusionPipeline.
|
47
|
-
2. Define the required _submodules class variable listing the components to be compiled.
|
48
|
+
2. Define the required _submodules and _optional_submodules class variables listing the components to be compiled.
|
48
49
|
|
49
50
|
Example:
|
50
51
|
```python
|
@@ -54,6 +55,7 @@ class RBLNDiffusionMixin:
|
|
54
55
|
|
55
56
|
Class Variables:
|
56
57
|
_submodules: List of submodule names that should be compiled (typically ["text_encoder", "unet", "vae"])
|
58
|
+
_optional_submodules: List of submodule names compiled without inheriting RBLNModel (typically ["safety_checker"])
|
57
59
|
|
58
60
|
Methods:
|
59
61
|
from_pretrained: Creates and optionally compiles a model from a pretrained checkpoint
|
@@ -66,6 +68,7 @@ class RBLNDiffusionMixin:
|
|
66
68
|
|
67
69
|
_connected_classes = {}
|
68
70
|
_submodules = []
|
71
|
+
_optional_submodules = []
|
69
72
|
_prefix = {}
|
70
73
|
_rbln_config_class = None
|
71
74
|
_hf_class = None
|
@@ -110,18 +113,10 @@ class RBLNDiffusionMixin:
|
|
110
113
|
|
111
114
|
@classmethod
|
112
115
|
def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
|
113
|
-
|
114
|
-
Lazily loads and caches the corresponding RBLN model config class.
|
115
|
-
"""
|
116
|
+
# Lazily loads and caches the corresponding RBLN model config class.
|
116
117
|
if cls._rbln_config_class is None:
|
117
118
|
rbln_config_class_name = cls.__name__ + "Config"
|
118
|
-
|
119
|
-
cls._rbln_config_class = getattr(library, rbln_config_class_name, None)
|
120
|
-
if cls._rbln_config_class is None:
|
121
|
-
raise ValueError(
|
122
|
-
f"RBLN config class {rbln_config_class_name} not found. This is an internal error. "
|
123
|
-
"Please report it to the developers."
|
124
|
-
)
|
119
|
+
cls._rbln_config_class = get_rbln_config_class(rbln_config_class_name)
|
125
120
|
return cls._rbln_config_class
|
126
121
|
|
127
122
|
@classmethod
|
@@ -143,7 +138,7 @@ class RBLNDiffusionMixin:
|
|
143
138
|
lora_ids: Optional[Union[str, List[str]]] = None,
|
144
139
|
lora_weights_names: Optional[Union[str, List[str]]] = None,
|
145
140
|
lora_scales: Optional[Union[float, List[float]]] = None,
|
146
|
-
**kwargs,
|
141
|
+
**kwargs: Dict[str, Any],
|
147
142
|
) -> "RBLNDiffusionMixin":
|
148
143
|
"""
|
149
144
|
Load a pretrained diffusion pipeline from a model checkpoint, with optional compilation for RBLN NPUs.
|
@@ -157,24 +152,25 @@ class RBLNDiffusionMixin:
|
|
157
152
|
Args:
|
158
153
|
model_id (`str`):
|
159
154
|
The model ID or path to the pretrained model to load. Can be either:
|
155
|
+
|
160
156
|
- A model ID from the HuggingFace Hub
|
161
157
|
- A local path to a saved model directory
|
162
|
-
export
|
158
|
+
export:
|
163
159
|
If True, takes a PyTorch model from `model_id` and compiles it for RBLN NPU execution.
|
164
160
|
If False, loads an already compiled RBLN model from `model_id` without recompilation.
|
165
|
-
model_save_dir
|
161
|
+
model_save_dir:
|
166
162
|
Directory to save the compiled model artifacts. Only used when `export=True`.
|
167
163
|
If not provided and `export=True`, a temporary directory is used.
|
168
|
-
rbln_config
|
164
|
+
rbln_config:
|
169
165
|
Configuration options for RBLN compilation. Can include settings for specific submodules
|
170
166
|
such as `text_encoder`, `unet`, and `vae`. Configuration can be tailored to the specific
|
171
167
|
pipeline being compiled.
|
172
|
-
lora_ids
|
168
|
+
lora_ids:
|
173
169
|
LoRA adapter ID(s) to load and apply before compilation. LoRA weights are fused
|
174
170
|
into the model weights during compilation. Only used when `export=True`.
|
175
|
-
lora_weights_names
|
171
|
+
lora_weights_names:
|
176
172
|
Names of specific LoRA weight files to load, corresponding to lora_ids. Only used when `export=True`.
|
177
|
-
lora_scales
|
173
|
+
lora_scales:
|
178
174
|
Scaling factor(s) to apply to the LoRA adapter(s). Only used when `export=True`.
|
179
175
|
**kwargs:
|
180
176
|
Additional arguments to pass to the underlying diffusion pipeline constructor or the
|
@@ -182,39 +178,50 @@ class RBLNDiffusionMixin:
|
|
182
178
|
or the particular diffusion pipeline being used.
|
183
179
|
|
184
180
|
Returns:
|
185
|
-
|
186
|
-
|
181
|
+
A compiled or loaded diffusion pipeline that can be used for inference on RBLN NPU.
|
182
|
+
The returned object is an instance of the class that called this method, inheriting from RBLNDiffusionMixin.
|
187
183
|
"""
|
188
184
|
rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
|
189
185
|
|
190
186
|
if export:
|
191
187
|
# keep submodules if user passed any of them.
|
192
188
|
passed_submodules = {
|
193
|
-
name: kwargs.pop(name)
|
189
|
+
name: kwargs.pop(name)
|
190
|
+
for name in cls._submodules + cls._optional_submodules
|
191
|
+
if isinstance(kwargs.get(name), RBLNModel)
|
194
192
|
}
|
195
193
|
|
196
194
|
else:
|
197
195
|
# raise error if any of submodules are torch module.
|
198
196
|
model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
|
199
|
-
for submodule_name in cls._submodules:
|
200
|
-
|
201
|
-
|
202
|
-
|
197
|
+
for submodule_name in cls._submodules + cls._optional_submodules:
|
198
|
+
passed_submodule = kwargs.get(submodule_name, None)
|
199
|
+
|
200
|
+
if passed_submodule is None:
|
201
|
+
module_name, class_name = model_index_config[submodule_name]
|
202
|
+
if module_name != "optimum.rbln":
|
203
|
+
raise ValueError(
|
204
|
+
f"Invalid module_name '{module_name}' found in model_index.json for "
|
205
|
+
f"submodule '{submodule_name}'. "
|
206
|
+
"Expected 'optimum.rbln'. Please check the model_index.json configuration."
|
207
|
+
"If you want to compile, set `export=True`."
|
208
|
+
)
|
209
|
+
|
210
|
+
submodule_cls = get_rbln_model_cls(class_name)
|
211
|
+
submodule_config = getattr(rbln_config, submodule_name)
|
212
|
+
submodule = submodule_cls.from_pretrained(
|
213
|
+
model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
|
203
214
|
)
|
204
215
|
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
216
|
+
else:
|
217
|
+
if passed_submodule.__class__.__name__.startswith("RBLN"):
|
218
|
+
submodule = passed_submodule
|
219
|
+
|
220
|
+
elif isinstance(passed_submodule, torch.nn.Module):
|
221
|
+
raise AssertionError(
|
222
|
+
f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
|
223
|
+
)
|
212
224
|
|
213
|
-
submodule_cls: Type[RBLNModel] = getattr(importlib.import_module("optimum.rbln"), class_name)
|
214
|
-
submodule_config = getattr(rbln_config, submodule_name)
|
215
|
-
submodule = submodule_cls.from_pretrained(
|
216
|
-
model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
|
217
|
-
)
|
218
225
|
kwargs[submodule_name] = submodule
|
219
226
|
|
220
227
|
with ContextRblnConfig(
|
@@ -293,7 +300,6 @@ class RBLNDiffusionMixin:
|
|
293
300
|
elif isinstance(submodule, RBLNModel):
|
294
301
|
pass
|
295
302
|
elif submodule_name == "controlnet" and hasattr(submodule, "nets"):
|
296
|
-
# In case of multicontrolnet
|
297
303
|
submodule = cls._compile_multicontrolnet(
|
298
304
|
controlnets=submodule,
|
299
305
|
model_save_dir=model_save_dir,
|
@@ -301,11 +307,8 @@ class RBLNDiffusionMixin:
|
|
301
307
|
prefix=prefix,
|
302
308
|
)
|
303
309
|
elif isinstance(submodule, torch.nn.Module):
|
304
|
-
submodule_cls: RBLNModel = getattr(
|
305
|
-
importlib.import_module("optimum.rbln"), f"RBLN{submodule.__class__.__name__}"
|
306
|
-
)
|
307
310
|
subfolder = prefix + submodule_name
|
308
|
-
submodule =
|
311
|
+
submodule = submodule_rbln_cls.from_model(
|
309
312
|
model=submodule,
|
310
313
|
subfolder=subfolder,
|
311
314
|
model_save_dir=model_save_dir,
|
@@ -362,10 +365,16 @@ class RBLNDiffusionMixin:
|
|
362
365
|
# Causing warning messages.
|
363
366
|
|
364
367
|
update_dict = {}
|
365
|
-
for submodule_name in cls._submodules:
|
368
|
+
for submodule_name in cls._submodules + cls._optional_submodules:
|
366
369
|
# replace submodule
|
367
|
-
|
368
|
-
|
370
|
+
if submodule_name in submodules:
|
371
|
+
setattr(model, submodule_name, submodules[submodule_name])
|
372
|
+
update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
|
373
|
+
else:
|
374
|
+
# It assumes that the modules in _optional_submodules are compiled
|
375
|
+
# and already registered as an attribute of the model.
|
376
|
+
update_dict[submodule_name] = ("optimum.rbln", getattr(model, submodule_name).__class__.__name__)
|
377
|
+
|
369
378
|
if cls._load_connected_pipes:
|
370
379
|
for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
|
371
380
|
prefix = cls._prefix.get(connected_pipe_name, "")
|
@@ -396,31 +405,29 @@ class RBLNDiffusionMixin:
|
|
396
405
|
return model
|
397
406
|
|
398
407
|
def get_compiled_image_size(self):
|
399
|
-
if hasattr(self, "vae"):
|
408
|
+
if hasattr(self, "vae") and hasattr(self.vae, "image_size"):
|
400
409
|
compiled_image_size = self.vae.image_size
|
401
410
|
else:
|
402
411
|
compiled_image_size = None
|
403
412
|
return compiled_image_size
|
404
413
|
|
405
414
|
def handle_additional_kwargs(self, **kwargs):
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
```
|
423
|
-
"""
|
415
|
+
# Function to handle additional compile-time parameters during inference.
|
416
|
+
|
417
|
+
# If the additional variable is determined by another module, this method should be overridden.
|
418
|
+
|
419
|
+
# Example:
|
420
|
+
# ```python
|
421
|
+
# if hasattr(self, "movq"):
|
422
|
+
# compiled_image_size = self.movq.image_size
|
423
|
+
# kwargs["height"] = compiled_image_size[0]
|
424
|
+
# kwargs["width"] = compiled_image_size[1]
|
425
|
+
|
426
|
+
# compiled_num_frames = self.unet.rbln_config.num_frames
|
427
|
+
# if compiled_num_frames is not None:
|
428
|
+
# kwargs["num_frames"] = compiled_num_frames
|
429
|
+
# return kwargs
|
430
|
+
# ```
|
424
431
|
return kwargs
|
425
432
|
|
426
433
|
@remove_compile_time_kwargs
|
@@ -20,6 +20,7 @@ from transformers.utils import _LazyModule
|
|
20
20
|
_import_structure = {
|
21
21
|
"autoencoders": [
|
22
22
|
"RBLNAutoencoderKL",
|
23
|
+
"RBLNAutoencoderKLCosmos",
|
23
24
|
"RBLNVQModel",
|
24
25
|
],
|
25
26
|
"unets": [
|
@@ -28,6 +29,7 @@ _import_structure = {
|
|
28
29
|
"controlnet": ["RBLNControlNetModel"],
|
29
30
|
"transformers": [
|
30
31
|
"RBLNPriorTransformer",
|
32
|
+
"RBLNCosmosTransformer3DModel",
|
31
33
|
"RBLNSD3Transformer2DModel",
|
32
34
|
],
|
33
35
|
}
|
@@ -35,10 +37,12 @@ _import_structure = {
|
|
35
37
|
if TYPE_CHECKING:
|
36
38
|
from .autoencoders import (
|
37
39
|
RBLNAutoencoderKL,
|
40
|
+
RBLNAutoencoderKLCosmos,
|
38
41
|
RBLNVQModel,
|
39
42
|
)
|
40
43
|
from .controlnet import RBLNControlNetModel
|
41
44
|
from .transformers import (
|
45
|
+
RBLNCosmosTransformer3DModel,
|
42
46
|
RBLNPriorTransformer,
|
43
47
|
RBLNSD3Transformer2DModel,
|
44
48
|
)
|
@@ -38,6 +38,17 @@ logger = get_logger(__name__)
|
|
38
38
|
|
39
39
|
|
40
40
|
class RBLNAutoencoderKL(RBLNModel):
|
41
|
+
"""
|
42
|
+
RBLN implementation of AutoencoderKL (VAE) for diffusion models.
|
43
|
+
|
44
|
+
This model is used to accelerate AutoencoderKL (VAE) models from diffusers library on RBLN NPUs.
|
45
|
+
It can be configured to include both encoder and decoder, or just the decoder part for latent-to-image
|
46
|
+
conversion.
|
47
|
+
|
48
|
+
This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
|
49
|
+
the library implements for all its models.
|
50
|
+
"""
|
51
|
+
|
41
52
|
auto_model_class = AutoencoderKL
|
42
53
|
hf_library_name = "diffusers"
|
43
54
|
_rbln_config_class = RBLNAutoencoderKLConfig
|
@@ -69,7 +80,12 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
69
80
|
|
70
81
|
wrapped_model.eval()
|
71
82
|
|
72
|
-
compiled_models[model_name] = cls.compile(
|
83
|
+
compiled_models[model_name] = cls.compile(
|
84
|
+
wrapped_model,
|
85
|
+
rbln_compile_config=rbln_config.compile_cfgs[i],
|
86
|
+
create_runtimes=rbln_config.create_runtimes,
|
87
|
+
device=rbln_config.device_map[model_name],
|
88
|
+
)
|
73
89
|
|
74
90
|
return compiled_models
|
75
91
|
|
@@ -0,0 +1,219 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING, Dict, List, Union
|
16
|
+
|
17
|
+
import rebel
|
18
|
+
import torch
|
19
|
+
from diffusers.models.autoencoders.autoencoder_kl_cosmos import AutoencoderKLCosmos, CosmosCausalConv3d
|
20
|
+
from diffusers.models.autoencoders.vae import DecoderOutput
|
21
|
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
22
|
+
from torch.nn import functional as F
|
23
|
+
from transformers import PretrainedConfig
|
24
|
+
|
25
|
+
from ....configuration_utils import RBLNCompileConfig
|
26
|
+
from ....modeling import RBLNModel
|
27
|
+
from ....utils.logging import get_logger
|
28
|
+
from ...configurations import RBLNAutoencoderKLCosmosConfig
|
29
|
+
from .vae import RBLNRuntimeCosmosVAEDecoder, RBLNRuntimeCosmosVAEEncoder, _VAECosmosDecoder, _VAECosmosEncoder
|
30
|
+
|
31
|
+
|
32
|
+
if TYPE_CHECKING:
|
33
|
+
import torch
|
34
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
|
35
|
+
|
36
|
+
from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
37
|
+
|
38
|
+
logger = get_logger(__name__)
|
39
|
+
|
40
|
+
|
41
|
+
class RBLNAutoencoderKLCosmos(RBLNModel):
|
42
|
+
"""
|
43
|
+
RBLN implementation of AutoencoderKLCosmos for diffusion models.
|
44
|
+
|
45
|
+
This model is used to accelerate AutoencoderKLCosmos models from diffusers library on RBLN NPUs.
|
46
|
+
It can be configured to include both encoder and decoder, or just the decoder part for latent-to-video
|
47
|
+
conversion.
|
48
|
+
|
49
|
+
This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
|
50
|
+
the library implements for all its models.
|
51
|
+
"""
|
52
|
+
|
53
|
+
auto_model_class = AutoencoderKLCosmos
|
54
|
+
hf_library_name = "diffusers"
|
55
|
+
_rbln_config_class = RBLNAutoencoderKLCosmosConfig
|
56
|
+
|
57
|
+
def __post_init__(self, **kwargs):
|
58
|
+
super().__post_init__(**kwargs)
|
59
|
+
|
60
|
+
if self.rbln_config.uses_encoder:
|
61
|
+
self.encoder = RBLNRuntimeCosmosVAEEncoder(
|
62
|
+
runtime=self.model[0], main_input_name="x", use_slicing=self.rbln_config.use_slicing
|
63
|
+
)
|
64
|
+
|
65
|
+
self.decoder = RBLNRuntimeCosmosVAEDecoder(
|
66
|
+
runtime=self.model[-1], main_input_name="z", use_slicing=self.rbln_config.use_slicing
|
67
|
+
)
|
68
|
+
self.image_size = self.rbln_config.image_size
|
69
|
+
|
70
|
+
@classmethod
|
71
|
+
def wrap_model_if_needed(
|
72
|
+
cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLCosmosConfig
|
73
|
+
) -> torch.nn.Module:
|
74
|
+
decoder_model = _VAECosmosDecoder(model)
|
75
|
+
decoder_model.eval()
|
76
|
+
|
77
|
+
if rbln_config.uses_encoder:
|
78
|
+
encoder_model = _VAECosmosEncoder(model)
|
79
|
+
encoder_model.eval()
|
80
|
+
return encoder_model, decoder_model
|
81
|
+
else:
|
82
|
+
return decoder_model
|
83
|
+
|
84
|
+
@classmethod
|
85
|
+
def get_compiled_model(
|
86
|
+
cls, model, rbln_config: RBLNAutoencoderKLCosmosConfig
|
87
|
+
) -> Dict[str, rebel.RBLNCompiledModel]:
|
88
|
+
def replaced_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
89
|
+
if self.temporal_pad != 0:
|
90
|
+
hidden_states_prev = hidden_states[:, :, :1, ...].repeat(1, 1, self.temporal_pad, 1, 1)
|
91
|
+
hidden_states = torch.cat([hidden_states_prev, hidden_states], dim=2)
|
92
|
+
hidden_states = F.pad(hidden_states, (*self.spatial_pad, 0, 0), mode=self.pad_mode, value=0.0)
|
93
|
+
return super(CosmosCausalConv3d, self).forward(hidden_states)
|
94
|
+
|
95
|
+
try:
|
96
|
+
original_forward = CosmosCausalConv3d.forward
|
97
|
+
CosmosCausalConv3d.forward = replaced_forward
|
98
|
+
|
99
|
+
compiled_models = {}
|
100
|
+
if rbln_config.uses_encoder:
|
101
|
+
encoder_model, decoder_model = cls.wrap_model_if_needed(model, rbln_config)
|
102
|
+
enc_compiled_model = cls.compile(
|
103
|
+
encoder_model,
|
104
|
+
rbln_compile_config=rbln_config.compile_cfgs[0],
|
105
|
+
create_runtimes=rbln_config.create_runtimes,
|
106
|
+
device=rbln_config.device_map["encoder"],
|
107
|
+
)
|
108
|
+
compiled_models["encoder"] = enc_compiled_model
|
109
|
+
else:
|
110
|
+
decoder_model = cls.wrap_model_if_needed(model, rbln_config)
|
111
|
+
dec_compiled_model = cls.compile(
|
112
|
+
decoder_model,
|
113
|
+
rbln_compile_config=rbln_config.compile_cfgs[-1],
|
114
|
+
create_runtimes=rbln_config.create_runtimes,
|
115
|
+
device=rbln_config.device_map["decoder"],
|
116
|
+
)
|
117
|
+
compiled_models["decoder"] = dec_compiled_model
|
118
|
+
|
119
|
+
finally:
|
120
|
+
CosmosCausalConv3d.forward = original_forward
|
121
|
+
|
122
|
+
return compiled_models
|
123
|
+
|
124
|
+
@classmethod
|
125
|
+
def update_rbln_config_using_pipe(
|
126
|
+
cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
|
127
|
+
) -> "RBLNDiffusionMixinConfig":
|
128
|
+
rbln_config.vae.num_channels_latents = pipe.transformer.config.out_channels
|
129
|
+
rbln_config.vae.vae_scale_factor_temporal = pipe.vae_scale_factor_temporal
|
130
|
+
rbln_config.vae.vae_scale_factor_spatial = pipe.vae_scale_factor_spatial
|
131
|
+
return rbln_config
|
132
|
+
|
133
|
+
@classmethod
def _update_rbln_config(
    cls,
    preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
    model: "PreTrainedModel",
    model_config: "PretrainedConfig",
    rbln_config: RBLNAutoencoderKLCosmosConfig,
) -> RBLNAutoencoderKLCosmosConfig:
    """Populate ``rbln_config`` with the compile configurations of the VAE.

    A "decoder" compile config is always produced; an "encoder" compile
    config is placed before it when ``rbln_config.uses_encoder`` is truthy.

    Args:
        preprocessors: Not used by this implementation.
        model: Not used by this implementation.
        model_config: Model configuration; ``in_channels`` is read for the
            encoder input shape.
        rbln_config: Configuration supplying batch size, frame count,
            resolution, scale factors and latent channel count.

    Returns:
        The same ``rbln_config`` after ``set_compile_cfgs`` has been called.
    """
    if rbln_config.use_slicing:
        # With slicing enabled the input shapes are declared for batch size 1.
        batch_size = 1
    else:
        batch_size = rbln_config.batch_size

    compile_cfgs = []

    if rbln_config.uses_encoder:
        encoder_input_shape = [
            batch_size,
            model_config.in_channels,
            rbln_config.num_frames,
            rbln_config.height,
            rbln_config.width,
        ]
        encoder_cfg = RBLNCompileConfig(
            compiled_model_name="encoder",
            input_info=[("x", encoder_input_shape, "float32")],
        )
        compile_cfgs.append(encoder_cfg)

    # Latent extents derived from the input extents and the VAE scale factors.
    num_latent_frames = (rbln_config.num_frames - 1) // rbln_config.vae_scale_factor_temporal + 1
    latent_height = rbln_config.height // rbln_config.vae_scale_factor_spatial
    latent_width = rbln_config.width // rbln_config.vae_scale_factor_spatial

    decoder_input_shape = [
        batch_size,
        rbln_config.num_channels_latents,
        num_latent_frames,
        latent_height,
        latent_width,
    ]
    decoder_cfg = RBLNCompileConfig(
        compiled_model_name="decoder",
        input_info=[("z", decoder_input_shape, "float32")],
    )
    compile_cfgs.append(decoder_cfg)

    rbln_config.set_compile_cfgs(compile_cfgs)
    return rbln_config
|
180
|
+
|
181
|
+
@classmethod
def _create_runtimes(
    cls,
    compiled_models: List[rebel.RBLNCompiledModel],
    rbln_config: RBLNAutoencoderKLCosmosConfig,
) -> List[rebel.Runtime]:
    """Create one ``rebel.Runtime`` per compiled model.

    A single compiled model is treated as the decoder; otherwise the models
    are taken to be ``["encoder", "decoder"]`` in that order.

    Args:
        compiled_models: Compiled models to wrap in runtimes.
        rbln_config: Configuration providing ``device_map`` and
            ``activate_profiler``.

    Returns:
        The runtimes, in the same order as ``compiled_models``.
    """
    # A lone compiled model means only the decoder is present.
    expected_models = ["decoder"] if len(compiled_models) == 1 else ["encoder", "decoder"]

    device_map = rbln_config.device_map
    if not all(name in device_map for name in expected_models):
        # NOTE(review): presumably raises; verify against the base class.
        cls._raise_missing_compiled_file_error(expected_models)

    # Resolve every device before any runtime is created.
    device_vals = [device_map[name] for name in expected_models]

    runtimes = []
    for compiled_model, device_val in zip(compiled_models, device_vals):
        runtime = rebel.Runtime(
            compiled_model,
            tensor_type="pt",
            device=device_val,
            activate_profiler=rbln_config.activate_profiler,
        )
        runtimes.append(runtime)
    return runtimes
|
206
|
+
|
207
|
+
def encode(self, x: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
    """Encode ``x`` with ``self.encoder`` and return the resulting posterior.

    Args:
        x: Input passed unchanged to ``self.encoder.encode``.
        return_dict: When falsy, return a one-element tuple instead of an
            ``AutoencoderKLOutput``.
        **kwargs: Accepted but not used.

    Returns:
        ``AutoencoderKLOutput(latent_dist=...)`` or ``(posterior,)``.
    """
    latent_dist = self.encoder.encode(x)
    if return_dict:
        return AutoencoderKLOutput(latent_dist=latent_dist)
    return (latent_dist,)
|
212
|
+
|
213
|
+
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
    """Decode ``z`` with ``self.decoder`` and return the result.

    Args:
        z: Input passed unchanged to ``self.decoder.decode``.
        return_dict: When falsy, return a one-element tuple instead of a
            ``DecoderOutput``.

    Returns:
        ``DecoderOutput(sample=...)`` or ``(decoded,)``.
    """
    sample = self.decoder.decode(z)
    if return_dict:
        return DecoderOutput(sample=sample)
    return (sample,)
|
@@ -15,17 +15,13 @@
|
|
15
15
|
from typing import TYPE_CHECKING, List
|
16
16
|
|
17
17
|
import torch
|
18
|
-
from diffusers import
|
19
|
-
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
18
|
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution
|
20
19
|
|
21
|
-
from ....utils.logging import get_logger
|
22
20
|
from ....utils.runtime_utils import RBLNPytorchRuntime
|
23
21
|
|
24
22
|
|
25
23
|
if TYPE_CHECKING:
|
26
|
-
import
|
27
|
-
|
28
|
-
logger = get_logger(__name__)
|
24
|
+
from diffusers import AutoencoderKL, AutoencoderKLCosmos, VQModel
|
29
25
|
|
30
26
|
|
31
27
|
class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
|
@@ -40,6 +36,27 @@ class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
|
|
40
36
|
return self.forward(z)
|
41
37
|
|
42
38
|
|
39
|
+
class RBLNRuntimeCosmosVAEEncoder(RBLNPytorchRuntime):
    """Runtime wrapper exposing an ``encode`` method for the Cosmos VAE encoder."""

    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Run the encoder on ``x`` and wrap the output in ``IdentityDistribution``.

        When ``self.use_slicing`` is truthy and ``x`` holds more than one
        entry along dim 0, each entry is run separately and the outputs are
        concatenated. ``**kwargs`` is accepted but not used.
        """
        if not self.use_slicing or x.shape[0] <= 1:
            latents = self.forward(x)
        else:
            # One forward call per size-1 chunk along dim 0.
            latents = torch.cat([self.forward(chunk) for chunk in x.split(1)])
        return IdentityDistribution(latents)
|
48
|
+
|
49
|
+
|
50
|
+
class RBLNRuntimeCosmosVAEDecoder(RBLNPytorchRuntime):
    """Runtime wrapper exposing a ``decode`` method for the Cosmos VAE decoder."""

    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Run the decoder on ``z`` and return its output.

        When ``self.use_slicing`` is truthy and ``z`` holds more than one
        entry along dim 0, each entry is run separately and the outputs are
        concatenated. ``**kwargs`` is accepted but not used.
        """
        if not self.use_slicing or z.shape[0] <= 1:
            return self.forward(z)
        # One forward call per size-1 chunk along dim 0.
        return torch.cat([self.forward(chunk) for chunk in z.split(1)])
|
58
|
+
|
59
|
+
|
43
60
|
class _VAEDecoder(torch.nn.Module):
|
44
61
|
def __init__(self, vae: "AutoencoderKL"):
|
45
62
|
super().__init__()
|
@@ -73,6 +90,26 @@ class _VAEEncoder(torch.nn.Module):
|
|
73
90
|
return vae_out
|
74
91
|
|
75
92
|
|
93
|
+
class _VAECosmosEncoder(torch.nn.Module):
|
94
|
+
def __init__(self, vae: "AutoencoderKLCosmos"):
|
95
|
+
super().__init__()
|
96
|
+
self.vae = vae
|
97
|
+
|
98
|
+
def forward(self, x):
|
99
|
+
vae_out = self.vae._encode(x)
|
100
|
+
return vae_out
|
101
|
+
|
102
|
+
|
103
|
+
class _VAECosmosDecoder(torch.nn.Module):
|
104
|
+
def __init__(self, vae: "AutoencoderKLCosmos"):
|
105
|
+
super().__init__()
|
106
|
+
self.vae = vae
|
107
|
+
|
108
|
+
def forward(self, z):
|
109
|
+
vae_out = self.vae._decode(z, return_dict=False)
|
110
|
+
return vae_out
|
111
|
+
|
112
|
+
|
76
113
|
class RBLNRuntimeVQEncoder(RBLNPytorchRuntime):
|
77
114
|
def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
78
115
|
h = self.forward(x.contiguous())
|
@@ -91,7 +128,7 @@ class RBLNRuntimeVQDecoder(RBLNPytorchRuntime):
|
|
91
128
|
|
92
129
|
|
93
130
|
class _VQEncoder(torch.nn.Module):
|
94
|
-
def __init__(self, vq_model: VQModel):
|
131
|
+
def __init__(self, vq_model: "VQModel"):
|
95
132
|
super().__init__()
|
96
133
|
self.vq_model = vq_model
|
97
134
|
|
@@ -106,7 +143,7 @@ class _VQEncoder(torch.nn.Module):
|
|
106
143
|
|
107
144
|
|
108
145
|
class _VQDecoder(torch.nn.Module):
|
109
|
-
def __init__(self, vq_model: VQModel):
|
146
|
+
def __init__(self, vq_model: "VQModel"):
|
110
147
|
super().__init__()
|
111
148
|
self.vq_model = vq_model
|
112
149
|
|
@@ -35,6 +35,17 @@ logger = get_logger(__name__)
|
|
35
35
|
|
36
36
|
|
37
37
|
class RBLNVQModel(RBLNModel):
|
38
|
+
"""
|
39
|
+
RBLN implementation of VQModel for diffusion models.
|
40
|
+
|
41
|
+
This model is used to accelerate VQModel models from diffusers library on RBLN NPUs.
|
42
|
+
It can be configured to include both encoder and decoder, or just the decoder part for latent-to-image
|
43
|
+
conversion.
|
44
|
+
|
45
|
+
This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
|
46
|
+
the library implements for all its models.
|
47
|
+
"""
|
48
|
+
|
38
49
|
auto_model_class = VQModel
|
39
50
|
config_name = "config.json"
|
40
51
|
hf_library_name = "diffusers"
|
@@ -67,7 +78,12 @@ class RBLNVQModel(RBLNModel):
|
|
67
78
|
|
68
79
|
wrapped_model.eval()
|
69
80
|
|
70
|
-
compiled_models[model_name] = cls.compile(
|
81
|
+
compiled_models[model_name] = cls.compile(
|
82
|
+
wrapped_model,
|
83
|
+
rbln_compile_config=rbln_config.compile_cfgs[i],
|
84
|
+
create_runtimes=rbln_config.create_runtimes,
|
85
|
+
device=rbln_config.device_map[model_name],
|
86
|
+
)
|
71
87
|
|
72
88
|
return compiled_models
|
73
89
|
|