optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +173 -35
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +816 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +111 -137
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +56 -71
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
- optimum/rbln/modeling.py +66 -40
- optimum/rbln/modeling_base.py +111 -86
- optimum/rbln/ops/__init__.py +4 -7
- optimum/rbln/ops/attn.py +271 -205
- optimum/rbln/ops/flash_attn.py +161 -67
- optimum/rbln/ops/kv_cache_update.py +4 -40
- optimum/rbln/ops/linear.py +25 -0
- optimum/rbln/transformers/__init__.py +97 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +120 -32
- optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
- optimum/rbln/transformers/models/bart/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
- optimum/rbln/transformers/models/t5/__init__.py +2 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
- optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +2 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/hub.py +2 -2
- optimum/rbln/utils/import_utils.py +23 -6
- optimum/rbln/utils/model_utils.py +4 -4
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +36 -44
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
- optimum_rbln-0.7.4.dist-info/RECORD +169 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py
@@ -0,0 +1,143 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+from ....configuration_utils import RBLNModelConfig
+from ....transformers import RBLNCLIPTextModelWithProjectionConfig, RBLNT5EncoderModelConfig
+from ..models import RBLNAutoencoderKLConfig, RBLNSD3Transformer2DModelConfig
+
+
+class _RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):
+    submodules = ["transformer", "text_encoder", "text_encoder_2", "text_encoder_3", "vae"]
+    _vae_uses_encoder = False
+
+    def __init__(
+        self,
+        transformer: Optional[RBLNSD3Transformer2DModelConfig] = None,
+        text_encoder: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
+        text_encoder_2: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
+        text_encoder_3: Optional[RBLNT5EncoderModelConfig] = None,
+        vae: Optional[RBLNAutoencoderKLConfig] = None,
+        *,
+        max_seq_len: Optional[int] = None,
+        sample_size: Optional[Tuple[int, int]] = None,
+        image_size: Optional[Tuple[int, int]] = None,
+        batch_size: Optional[int] = None,
+        img_height: Optional[int] = None,
+        img_width: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            transformer (Optional[RBLNSD3Transformer2DModelConfig]): Configuration for the transformer model component.
+                Initialized as RBLNSD3Transformer2DModelConfig if not provided.
+            text_encoder (Optional[RBLNCLIPTextModelWithProjectionConfig]): Configuration for the primary text encoder.
+                Initialized as RBLNCLIPTextModelWithProjectionConfig if not provided.
+            text_encoder_2 (Optional[RBLNCLIPTextModelWithProjectionConfig]): Configuration for the secondary text encoder.
+                Initialized as RBLNCLIPTextModelWithProjectionConfig if not provided.
+            text_encoder_3 (Optional[RBLNT5EncoderModelConfig]): Configuration for the tertiary text encoder.
+                Initialized as RBLNT5EncoderModelConfig if not provided.
+            vae (Optional[RBLNAutoencoderKLConfig]): Configuration for the VAE model component.
+                Initialized as RBLNAutoencoderKLConfig if not provided.
+            max_seq_len (Optional[int]): Maximum sequence length for text inputs. Defaults to 256.
+            sample_size (Optional[Tuple[int, int]]): Spatial dimensions for the transformer model.
+            image_size (Optional[Tuple[int, int]]): Dimensions for the generated images.
+                Cannot be used together with img_height/img_width.
+            batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+            img_height (Optional[int]): Height of the generated images.
+            img_width (Optional[int]): Width of the generated images.
+            guidance_scale (Optional[float]): Scale for classifier-free guidance.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If both image_size and img_height/img_width are provided.
+
+        Note:
+            When guidance_scale > 1.0, the transformer batch size is automatically doubled to
+            accommodate classifier-free guidance.
+        """
+        super().__init__(**kwargs)
+        if image_size is not None and (img_height is not None or img_width is not None):
+            raise ValueError("image_size and img_height/img_width cannot both be provided")
+
+        if img_height is not None and img_width is not None:
+            image_size = (img_height, img_width)
+
+        max_seq_len = max_seq_len or 256
+
+        self.text_encoder = self.init_submodule_config(
+            RBLNCLIPTextModelWithProjectionConfig, text_encoder, batch_size=batch_size
+        )
+        self.text_encoder_2 = self.init_submodule_config(
+            RBLNCLIPTextModelWithProjectionConfig, text_encoder_2, batch_size=batch_size
+        )
+        self.text_encoder_3 = self.init_submodule_config(
+            RBLNT5EncoderModelConfig,
+            text_encoder_3,
+            batch_size=batch_size,
+            max_seq_len=max_seq_len,
+        )
+        self.transformer = self.init_submodule_config(
+            RBLNSD3Transformer2DModelConfig,
+            transformer,
+            sample_size=sample_size,
+        )
+        self.vae = self.init_submodule_config(
+            RBLNAutoencoderKLConfig,
+            vae,
+            batch_size=batch_size,
+            uses_encoder=self.__class__._vae_uses_encoder,
+            sample_size=image_size,
+        )
+
+        # Get default guidance scale from original class to set Transformer batch size
+        if guidance_scale is None:
+            guidance_scale = self.get_default_values_for_original_cls("__call__", ["guidance_scale"])["guidance_scale"]
+
+        if not self.transformer.batch_size_is_specified:
+            do_classifier_free_guidance = guidance_scale > 1.0
+            if do_classifier_free_guidance:
+                self.transformer.batch_size = self.text_encoder.batch_size * 2
+            else:
+                self.transformer.batch_size = self.text_encoder.batch_size
+
+    @property
+    def max_seq_len(self):
+        return self.text_encoder_3.max_seq_len
+
+    @property
+    def batch_size(self):
+        return self.vae.batch_size
+
+    @property
+    def sample_size(self):
+        return self.transformer.sample_size
+
+    @property
+    def image_size(self):
+        return self.vae.sample_size
+
+
+class RBLNStableDiffusion3PipelineConfig(_RBLNStableDiffusion3PipelineBaseConfig):
+    _vae_uses_encoder = False
+
+
+class RBLNStableDiffusion3Img2ImgPipelineConfig(_RBLNStableDiffusion3PipelineBaseConfig):
+    _vae_uses_encoder = True
+
+
+class RBLNStableDiffusion3InpaintPipelineConfig(_RBLNStableDiffusion3PipelineBaseConfig):
+    _vae_uses_encoder = True
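The constructor above resolves image dimensions and the classifier-free-guidance (CFG) batch size at configuration time. A minimal usage sketch, assuming the new config classes are re-exported from `optimum.rbln` (as the `__init__.py` changes in this release suggest); the dimensions and guidance scale are illustrative only:

```python
# Hypothetical sketch of the constructor behavior shown above.
from optimum.rbln import RBLNStableDiffusion3PipelineConfig

config = RBLNStableDiffusion3PipelineConfig(
    img_height=1024,
    img_width=1024,
    guidance_scale=7.0,  # > 1.0, so CFG is active
)

assert config.image_size == (1024, 1024)  # property mirrors vae.sample_size
assert config.max_seq_len == 256          # default forwarded to text_encoder_3
# CFG doubles the transformer batch relative to the text encoder batch:
assert config.transformer.batch_size == 2 * config.text_encoder.batch_size
```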
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py
@@ -0,0 +1,124 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+from ....configuration_utils import RBLNModelConfig
+from ....transformers import RBLNCLIPTextModelConfig, RBLNCLIPTextModelWithProjectionConfig
+from ..models import RBLNAutoencoderKLConfig, RBLNUNet2DConditionModelConfig
+
+
+class _RBLNStableDiffusionXLPipelineBaseConfig(RBLNModelConfig):
+    submodules = ["text_encoder", "text_encoder_2", "unet", "vae"]
+    _vae_uses_encoder = False
+
+    def __init__(
+        self,
+        text_encoder: Optional[RBLNCLIPTextModelConfig] = None,
+        text_encoder_2: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
+        unet: Optional[RBLNUNet2DConditionModelConfig] = None,
+        vae: Optional[RBLNAutoencoderKLConfig] = None,
+        *,
+        batch_size: Optional[int] = None,
+        img_height: Optional[int] = None,
+        img_width: Optional[int] = None,
+        sample_size: Optional[Tuple[int, int]] = None,
+        image_size: Optional[Tuple[int, int]] = None,
+        guidance_scale: Optional[float] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            text_encoder (Optional[RBLNCLIPTextModelConfig]): Configuration for the primary text encoder component.
+                Initialized as RBLNCLIPTextModelConfig if not provided.
+            text_encoder_2 (Optional[RBLNCLIPTextModelWithProjectionConfig]): Configuration for the secondary text encoder component.
+                Initialized as RBLNCLIPTextModelWithProjectionConfig if not provided.
+            unet (Optional[RBLNUNet2DConditionModelConfig]): Configuration for the UNet model component.
+                Initialized as RBLNUNet2DConditionModelConfig if not provided.
+            vae (Optional[RBLNAutoencoderKLConfig]): Configuration for the VAE model component.
+                Initialized as RBLNAutoencoderKLConfig if not provided.
+            batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+            img_height (Optional[int]): Height of the generated images.
+            img_width (Optional[int]): Width of the generated images.
+            sample_size (Optional[Tuple[int, int]]): Spatial dimensions for the UNet model.
+            image_size (Optional[Tuple[int, int]]): Alternative way to specify image dimensions.
+                Cannot be used together with img_height/img_width.
+            guidance_scale (Optional[float]): Scale for classifier-free guidance.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If both image_size and img_height/img_width are provided.
+
+        Note:
+            When guidance_scale > 1.0, the UNet batch size is automatically doubled to
+            accommodate classifier-free guidance.
+        """
+        super().__init__(**kwargs)
+        if image_size is not None and (img_height is not None or img_width is not None):
+            raise ValueError("image_size and img_height/img_width cannot both be provided")
+
+        if img_height is not None and img_width is not None:
+            image_size = (img_height, img_width)
+
+        self.text_encoder = self.init_submodule_config(RBLNCLIPTextModelConfig, text_encoder, batch_size=batch_size)
+        self.text_encoder_2 = self.init_submodule_config(
+            RBLNCLIPTextModelWithProjectionConfig, text_encoder_2, batch_size=batch_size
+        )
+        self.unet = self.init_submodule_config(
+            RBLNUNet2DConditionModelConfig,
+            unet,
+            sample_size=sample_size,
+        )
+        self.vae = self.init_submodule_config(
+            RBLNAutoencoderKLConfig,
+            vae,
+            batch_size=batch_size,
+            uses_encoder=self.__class__._vae_uses_encoder,
+            sample_size=image_size,  # image size is equal to sample size in vae
+        )
+
+        # Get default guidance scale from original class to set UNet batch size
+        if guidance_scale is None:
+            guidance_scale = self.get_default_values_for_original_cls("__call__", ["guidance_scale"])["guidance_scale"]
+
+        if not self.unet.batch_size_is_specified:
+            do_classifier_free_guidance = guidance_scale > 1.0
+            if do_classifier_free_guidance:
+                self.unet.batch_size = self.text_encoder.batch_size * 2
+            else:
+                self.unet.batch_size = self.text_encoder.batch_size
+
+    @property
+    def batch_size(self):
+        return self.vae.batch_size
+
+    @property
+    def sample_size(self):
+        return self.unet.sample_size
+
+    @property
+    def image_size(self):
+        return self.vae.sample_size
+
+
+class RBLNStableDiffusionXLPipelineConfig(_RBLNStableDiffusionXLPipelineBaseConfig):
+    _vae_uses_encoder = False
+
+
+class RBLNStableDiffusionXLImg2ImgPipelineConfig(_RBLNStableDiffusionXLPipelineBaseConfig):
+    _vae_uses_encoder = True
+
+
+class RBLNStableDiffusionXLInpaintPipelineConfig(_RBLNStableDiffusionXLPipelineBaseConfig):
+    _vae_uses_encoder = True
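The SDXL variant applies the same CFG logic, but only when the UNet batch size was not set explicitly (`batch_size_is_specified`). A sketch of opting out of the doubling, assuming `RBLNUNet2DConditionModelConfig` accepts `batch_size` the way the pipeline-level config does:

```python
# Hypothetical sketch: an explicit UNet batch size suppresses the doubling.
from optimum.rbln import (
    RBLNStableDiffusionXLPipelineConfig,
    RBLNUNet2DConditionModelConfig,
)

config = RBLNStableDiffusionXLPipelineConfig(
    unet=RBLNUNet2DConditionModelConfig(batch_size=1),
    guidance_scale=5.0,  # would otherwise double the UNet batch size
)
assert config.unet.batch_size == 1
```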
optimum/rbln/diffusers/modeling_diffusers.py
@@ -15,12 +15,12 @@
 import copy
 import importlib
 from os import PathLike
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

 import torch

+from ..configuration_utils import ContextRblnConfig, RBLNModelConfig
 from ..modeling import RBLNModel
-from ..modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
 from ..utils.decorator_utils import remove_compile_time_kwargs
 from ..utils.logging import get_logger
@@ -31,6 +31,10 @@ if TYPE_CHECKING:
     from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel


+class RBLNDiffusionMixinConfig(RBLNModelConfig):
+    pass
+
+
 class RBLNDiffusionMixin:
     """
     RBLNDiffusionMixin provides essential functionalities for compiling Stable Diffusion pipeline components to run on RBLN NPUs.
@@ -41,17 +45,11 @@ class RBLNDiffusionMixin:

     1. Create a new pipeline class that inherits from both this mixin and the original StableDiffusionPipeline.
     2. Define the required _submodules class variable listing the components to be compiled.
-    3. If needed, implement get_default_rbln_config for custom configuration of submodules.

     Example:
         ```python
         class RBLNStableDiffusionPipeline(RBLNDiffusionMixin, StableDiffusionPipeline):
             _submodules = ["text_encoder", "unet", "vae"]
-
-            @classmethod
-            def get_default_rbln_config(cls, model, submodule_name, rbln_config):
-                # Configuration for other submodules...
-                pass
         ```

     Class Variables:
@@ -69,43 +67,8 @@ class RBLNDiffusionMixin:
     _connected_classes = {}
     _submodules = []
     _prefix = {}
-
-
-    def is_img2img_pipeline(cls):
-        return "Img2Img" in cls.__name__
-
-    @classmethod
-    def is_inpaint_pipeline(cls):
-        return "Inpaint" in cls.__name__
-
-    @classmethod
-    def get_submodule_rbln_config(
-        cls, model: torch.nn.Module, submodule_name: str, rbln_config: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        submodule = getattr(model, submodule_name)
-        submodule_class_name = submodule.__class__.__name__
-        if isinstance(submodule, torch.nn.Module):
-            if submodule_class_name == "MultiControlNetModel":
-                submodule_class_name = "ControlNetModel"
-
-            submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
-
-            submodule_config = rbln_config.get(submodule_name, {})
-            submodule_config = copy.deepcopy(submodule_config)
-
-            pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
-
-            submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
-            submodule_config.update(
-                {
-                    "img2img_pipeline": cls.is_img2img_pipeline(),
-                    "inpaint_pipeline": cls.is_inpaint_pipeline(),
-                }
-            )
-            submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
-        else:
-            raise ValueError(f"submodule {submodule_name} isn't supported")
-        return submodule_config
+    _rbln_config_class = None
+    _hf_class = None

     @staticmethod
     def _maybe_apply_and_fuse_lora(
@@ -146,7 +109,30 @@ class RBLNDiffusionMixin:
         return model

     @classmethod
-    @use_rbln_config
+    def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
+        """
+        Lazily loads and caches the corresponding RBLN model config class.
+        """
+        if cls._rbln_config_class is None:
+            rbln_config_class_name = cls.__name__ + "Config"
+            library = importlib.import_module("optimum.rbln")
+            cls._rbln_config_class = getattr(library, rbln_config_class_name, None)
+            if cls._rbln_config_class is None:
+                raise ValueError(
+                    f"RBLN config class {rbln_config_class_name} not found. This is an internal error. "
+                    "Please report it to the developers."
+                )
+        return cls._rbln_config_class
+
+    @classmethod
+    def get_hf_class(cls):
+        if cls._hf_class is None:
+            hf_cls_name = cls.__name__[4:]
+            library = importlib.import_module("diffusers")
+            cls._hf_class = getattr(library, hf_cls_name, None)
+        return cls._hf_class
+
+    @classmethod
     def from_pretrained(
         cls,
         model_id: str,
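Both new helpers rely purely on a class-name convention: `RBLN<Name>` resolves to `RBLN<Name>Config` in `optimum.rbln` and, with the `RBLN` prefix stripped, to `<Name>` in `diffusers`. The lookup in isolation (the pipeline name below is just an example):

```python
import importlib

cls_name = "RBLNStableDiffusionPipeline"  # example RBLN pipeline class name
config_name = cls_name + "Config"         # -> "RBLNStableDiffusionPipelineConfig"
hf_name = cls_name[4:]                    # strip "RBLN" -> "StableDiffusionPipeline"

rbln_config_cls = getattr(importlib.import_module("optimum.rbln"), config_name, None)
hf_cls = getattr(importlib.import_module("diffusers"), hf_name, None)
```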
@@ -158,7 +144,49 @@ class RBLNDiffusionMixin:
         lora_weights_names: Optional[Union[str, List[str]]] = None,
         lora_scales: Optional[Union[float, List[float]]] = None,
         **kwargs,
-    ) ->
+    ) -> "RBLNDiffusionMixin":
+        """
+        Load a pretrained diffusion pipeline from a model checkpoint, with optional compilation for RBLN NPUs.
+
+        This method has two distinct operating modes:
+        - When `export=True`: Takes a PyTorch-based diffusion model, compiles it for RBLN NPUs, and loads the compiled model
+        - When `export=False`: Loads an already compiled RBLN model from `model_id` without recompilation
+
+        It supports various diffusion pipelines including Stable Diffusion, Kandinsky, ControlNet, and other diffusers-based models.
+
+        Args:
+            model_id (`str`):
+                The model ID or path to the pretrained model to load. Can be either:
+                - A model ID from the HuggingFace Hub
+                - A local path to a saved model directory
+            export (`bool`, *optional*, defaults to `False`):
+                If True, takes a PyTorch model from `model_id` and compiles it for RBLN NPU execution.
+                If False, loads an already compiled RBLN model from `model_id` without recompilation.
+            model_save_dir (`os.PathLike`, *optional*):
+                Directory to save the compiled model artifacts. Only used when `export=True`.
+                If not provided and `export=True`, a temporary directory is used.
+            rbln_config (`Dict[str, Any]`, *optional*, defaults to `{}`):
+                Configuration options for RBLN compilation. Can include settings for specific submodules
+                such as `text_encoder`, `unet`, and `vae`. Configuration can be tailored to the specific
+                pipeline being compiled.
+            lora_ids (`str` or `List[str]`, *optional*):
+                LoRA adapter ID(s) to load and apply before compilation. LoRA weights are fused
+                into the model weights during compilation. Only used when `export=True`.
+            lora_weights_names (`str` or `List[str]`, *optional*):
+                Names of specific LoRA weight files to load, corresponding to lora_ids. Only used when `export=True`.
+            lora_scales (`float` or `List[float]`, *optional*):
+                Scaling factor(s) to apply to the LoRA adapter(s). Only used when `export=True`.
+            **kwargs:
+                Additional arguments to pass to the underlying diffusion pipeline constructor or the
+                RBLN compilation process. These may include parameters specific to individual submodules
+                or the particular diffusion pipeline being used.
+
+        Returns:
+            `RBLNDiffusionMixin`: A compiled or loaded diffusion pipeline that can be used for inference on RBLN NPU.
+            The returned object is an instance of the class that called this method, inheriting from RBLNDiffusionMixin.
+        """
+        rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
+
         if export:
             # keep submodules if user passed any of them.
             passed_submodules = {
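As a usage illustration of the two modes the new docstring describes (the model ID, save directory, and submodule options below are placeholders, not taken from the diff):

```python
# Hypothetical usage; identifiers are placeholders.
from optimum.rbln import RBLNStableDiffusionPipeline

# Mode 1: compile a PyTorch checkpoint for the RBLN NPU and save the artifacts.
pipe = RBLNStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    export=True,
    model_save_dir="./sd15-rbln",
    rbln_config={"unet": {"batch_size": 2}},
)

# Mode 2: later, load the precompiled artifacts without recompiling.
pipe = RBLNStableDiffusionPipeline.from_pretrained("./sd15-rbln", export=False)
```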
@@ -168,22 +196,12 @@ class RBLNDiffusionMixin:
         else:
             # raise error if any of submodules are torch module.
             model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
-            rbln_config = cls._flatten_rbln_config(rbln_config)
             for submodule_name in cls._submodules:
                 if isinstance(kwargs.get(submodule_name), torch.nn.Module):
                     raise AssertionError(
                         f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
                     )

-                submodule_config = rbln_config.get(submodule_name, {})
-
-                for key, value in rbln_config.items():
-                    if key in RUNTIME_KEYWORDS and key not in submodule_config:
-                        submodule_config[key] = value
-
-                if not any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
-                    continue
-
                 module_name, class_name = model_index_config[submodule_name]
                 if module_name != "optimum.rbln":
                     raise ValueError(
@@ -192,19 +210,19 @@ class RBLNDiffusionMixin:
                         "Expected 'optimum.rbln'. Please check the model_index.json configuration."
                     )

-                submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), class_name)
-
+                submodule_cls: Type[RBLNModel] = getattr(importlib.import_module("optimum.rbln"), class_name)
+                submodule_config = getattr(rbln_config, submodule_name)
                 submodule = submodule_cls.from_pretrained(
                     model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
                 )
                 kwargs[submodule_name] = submodule

         with ContextRblnConfig(
-            device=rbln_config.
-            device_map=rbln_config.
-            create_runtimes=rbln_config.
-            optimize_host_mem=rbln_config.
-            activate_profiler=rbln_config.
+            device=rbln_config.device,
+            device_map=rbln_config.device_map,
+            create_runtimes=rbln_config.create_runtimes,
+            optimize_host_mem=rbln_config.optimize_host_memory,
+            activate_profiler=rbln_config.activate_profiler,
         ):
             model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)

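The runtime keywords that were previously pulled out of a plain dict via `RUNTIME_KEYWORDS` are now typed attributes on the config object. Passing them at load time might look like this; a sketch, assuming `initialize_from_kwargs` accepts these keys at the top level of `rbln_config`:

```python
# Hypothetical sketch: runtime options become attributes on rbln_config.
from optimum.rbln import RBLNStableDiffusionPipeline

pipe = RBLNStableDiffusionPipeline.from_pretrained(
    "./sd15-rbln",
    export=False,
    rbln_config={
        "device": 0,                   # read as rbln_config.device
        "create_runtimes": True,       # read as rbln_config.create_runtimes
        "optimize_host_memory": True,  # read as rbln_config.optimize_host_memory
        "activate_profiler": False,    # read as rbln_config.activate_profiler
    },
)
```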
@@ -224,78 +242,27 @@ class RBLNDiffusionMixin:
             compiled_submodules = cls._compile_submodules(model, passed_submodules, model_save_dir, rbln_config)
         return cls._construct_pipe(model, compiled_submodules, model_save_dir, rbln_config)

-    @classmethod
-    def _prepare_rbln_config(
-        cls,
-        rbln_config,
-    ) -> Dict[str, Any]:
-        prepared_config = {}
-        for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
-            connected_pipe_config = rbln_config.pop(connected_pipe_name, {})
-            prefix = cls._prefix.get(connected_pipe_name, "")
-            guidance_scale = rbln_config.pop(f"{prefix}guidance_scale", None)
-            if "guidance_scale" not in connected_pipe_config and guidance_scale is not None:
-                connected_pipe_config["guidance_scale"] = guidance_scale
-            for submodule_name in connected_pipe_cls._submodules:
-                submodule_config = rbln_config.pop(prefix + submodule_name, {})
-                if submodule_name not in connected_pipe_config:
-                    connected_pipe_config[submodule_name] = {}
-                connected_pipe_config[submodule_name].update(
-                    {k: v for k, v in submodule_config.items() if k not in connected_pipe_config[submodule_name]}
-                )
-            prepared_config[connected_pipe_name] = connected_pipe_config
-        prepared_config.update(rbln_config)
-        return prepared_config
-
-    @classmethod
-    def _flatten_rbln_config(
-        cls,
-        rbln_config,
-    ) -> Dict[str, Any]:
-        prepared_config = cls._prepare_rbln_config(rbln_config)
-        flattened_config = {}
-        pipe_global_config = {k: v for k, v in prepared_config.items() if k not in cls._connected_classes.keys()}
-        for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
-            connected_pipe_config = prepared_config.pop(connected_pipe_name)
-            prefix = cls._prefix.get(connected_pipe_name, "")
-            connected_pipe_global_config = {
-                k: v for k, v in connected_pipe_config.items() if k not in connected_pipe_cls._submodules
-            }
-            for submodule_name in connected_pipe_cls._submodules:
-                flattened_config[prefix + submodule_name] = connected_pipe_config[submodule_name]
-                flattened_config[prefix + submodule_name].update(
-                    {
-                        k: v
-                        for k, v in connected_pipe_global_config.items()
-                        if k not in flattened_config[prefix + submodule_name]
-                    }
-                )
-        flattened_config.update(pipe_global_config)
-        return flattened_config
-
     @classmethod
     def _compile_pipelines(
         cls,
         model: torch.nn.Module,
         passed_submodules: Dict[str, RBLNModel],
         model_save_dir: Optional[PathLike],
-        rbln_config:
+        rbln_config: "RBLNDiffusionMixinConfig",
     ) -> Dict[str, RBLNModel]:
         compiled_submodules = {}
-
-        rbln_config = cls._prepare_rbln_config(rbln_config)
-        pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._connected_classes.keys()}
         for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
             connected_pipe_submodules = {}
             prefix = cls._prefix.get(connected_pipe_name, "")
             for submodule_name in connected_pipe_cls._submodules:
                 connected_pipe_submodules[submodule_name] = passed_submodules.get(prefix + submodule_name, None)
             connected_pipe = getattr(model, connected_pipe_name)
-            connected_pipe_config = {}
-            connected_pipe_config.update(pipe_global_config)
-            connected_pipe_config.update(rbln_config[connected_pipe_name])
             connected_pipe_compiled_submodules = connected_pipe_cls._compile_submodules(
-                connected_pipe,
+                connected_pipe,
+                connected_pipe_submodules,
+                model_save_dir,
+                getattr(rbln_config, connected_pipe_name),
+                prefix,
             )
             for submodule_name, compiled_submodule in connected_pipe_compiled_submodules.items():
                 compiled_submodules[prefix + submodule_name] = compiled_submodule
@@ -307,14 +274,19 @@ class RBLNDiffusionMixin:
         model: torch.nn.Module,
         passed_submodules: Dict[str, RBLNModel],
         model_save_dir: Optional[PathLike],
-        rbln_config:
+        rbln_config: RBLNDiffusionMixinConfig,
         prefix: Optional[str] = "",
     ) -> Dict[str, RBLNModel]:
         compiled_submodules = {}

         for submodule_name in cls._submodules:
             submodule = passed_submodules.get(submodule_name) or getattr(model, submodule_name, None)
-
+
+            if getattr(rbln_config, submodule_name, None) is None:
+                raise ValueError(f"RBLN config for submodule {submodule_name} is not provided.")
+
+            submodule_rbln_cls: Type[RBLNModel] = getattr(rbln_config, submodule_name).rbln_model_cls
+            rbln_config = submodule_rbln_cls.update_rbln_config_using_pipe(model, rbln_config, submodule_name)

             if submodule is None:
                 raise ValueError(f"submodule ({submodule_name}) cannot be accessed since it is not provided.")
@@ -325,7 +297,7 @@ class RBLNDiffusionMixin:
                 submodule = cls._compile_multicontrolnet(
                     controlnets=submodule,
                     model_save_dir=model_save_dir,
-                    controlnet_rbln_config=
+                    controlnet_rbln_config=getattr(rbln_config, submodule_name),
                     prefix=prefix,
                 )
             elif isinstance(submodule, torch.nn.Module):
@@ -337,7 +309,7 @@ class RBLNDiffusionMixin:
                     model=submodule,
                     subfolder=subfolder,
                     model_save_dir=model_save_dir,
-                    rbln_config=
+                    rbln_config=getattr(rbln_config, submodule_name),
                 )
             else:
                 raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")
@@ -350,22 +322,24 @@ class RBLNDiffusionMixin:
         cls,
         controlnets: "MultiControlNetModel",
         model_save_dir: Optional[PathLike],
-        controlnet_rbln_config:
+        controlnet_rbln_config: RBLNModelConfig,
         prefix: Optional[str] = "",
    ):
         # Compile multiple ControlNet models for a MultiControlNet setup
         from .models.controlnet import RBLNControlNetModel
         from .pipelines.controlnet import RBLNMultiControlNetModel

-        compiled_controlnets = [
-            RBLNControlNetModel.from_model(
-                model=controlnet,
-                subfolder=f"{prefix}controlnet" if i == 0 else f"{prefix}controlnet_{i}",
-                model_save_dir=model_save_dir,
-                rbln_config=controlnet_rbln_config,
+        compiled_controlnets = []
+        for i, controlnet in enumerate(controlnets.nets):
+            _controlnet_rbln_config = copy.deepcopy(controlnet_rbln_config)
+            compiled_controlnets.append(
+                RBLNControlNetModel.from_model(
+                    model=controlnet,
+                    subfolder=f"{prefix}controlnet" if i == 0 else f"{prefix}controlnet_{i}",
+                    model_save_dir=model_save_dir,
+                    rbln_config=_controlnet_rbln_config,
+                )
             )
-            for i, controlnet in enumerate(controlnets.nets)
-        ]
         return RBLNMultiControlNetModel(compiled_controlnets)

     @classmethod
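The per-net `copy.deepcopy` in the rewritten loop keeps each compiled ControlNet from mutating a shared config object. The isolation concern in miniature, using plain Python stand-ins rather than RBLN types:

```python
import copy

class Cfg:
    """Illustrative stand-in for a mutable per-submodule config."""
    def __init__(self):
        self.compiled_model_dir = None

shared = Cfg()
per_net = [copy.deepcopy(shared) for _ in range(3)]  # one copy per ControlNet
per_net[0].compiled_model_dir = "controlnet"
per_net[1].compiled_model_dir = "controlnet_1"
assert shared.compiled_model_dir is None  # the template is untouched
```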
@@ -412,7 +386,7 @@ class RBLNDiffusionMixin:
             # overwrite to replace incorrect config
             model.save_config(model_save_dir)

-        if rbln_config.
+        if rbln_config.optimize_host_memory is False:
             # Keep compiled_model objs to further analysis. -> TODO: remove soon...
             model.compiled_models = []
             for name in cls._submodules:
@@ -441,9 +415,9 @@ class RBLNDiffusionMixin:
                 kwargs["height"] = compiled_image_size[0]
                 kwargs["width"] = compiled_image_size[1]

-            compiled_num_frames = self.unet.rbln_config.
+            compiled_num_frames = self.unet.rbln_config.num_frames
             if compiled_num_frames is not None:
-                kwargs["num_frames"] =
+                kwargs["num_frames"] = compiled_num_frames
             return kwargs
         ```
         """