optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a5__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- optimum/rbln/__init__.py +156 -36
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/configuration_utils.py +772 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +63 -122
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +55 -70
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- optimum/rbln/modeling.py +58 -39
- optimum/rbln/modeling_base.py +85 -75
- optimum/rbln/transformers/__init__.py +79 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +96 -34
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/submodule.py +26 -43
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
- optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
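The headline change in this release is the replacement of the dict-based `rbln_config` machinery (the deleted `optimum/rbln/modeling_config.py`) with typed configuration classes rooted in the new `optimum/rbln/configuration_utils.py` and the per-model `configuration_*.py` files. A minimal before/after sketch, assuming the new pipeline config classes are exported from `optimum.rbln` as the file list indicates; exact constructor arguments should be checked against the release:

```python
from optimum.rbln import (
    RBLNStableDiffusionXLPipeline,
    RBLNStableDiffusionXLPipelineConfig,  # added in 0.7.4a5
)

# 0.7.4a4 style: untyped nested dicts, merged and flattened at compile time
# pipe = RBLNStableDiffusionXLPipeline.from_pretrained(
#     "stabilityai/stable-diffusion-xl-base-1.0",
#     export=True,
#     rbln_config={"unet": {"batch_size": 2}},
# )

# 0.7.4a5 style: a structured config object with typed submodule configs
config = RBLNStableDiffusionXLPipelineConfig(batch_size=1, img_height=1024, img_width=1024)
pipe = RBLNStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    export=True,
    rbln_config=config,
)
```

The diff hunks for three of the changed files are reproduced below.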
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py (new file, +143)

@@ -0,0 +1,143 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+from ....configuration_utils import RBLNModelConfig
+from ....transformers import RBLNCLIPTextModelWithProjectionConfig, RBLNT5EncoderModelConfig
+from ....utils.logging import get_logger
+from ..models import RBLNAutoencoderKLConfig, RBLNSD3Transformer2DModelConfig
+
+
+logger = get_logger(__name__)
+
+
+class _RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):
+    submodules = ["transformer", "text_encoder", "text_encoder_2", "text_encoder_3", "vae"]
+    _vae_uses_encoder = False
+
+    def __init__(
+        self,
+        transformer: Optional[RBLNSD3Transformer2DModelConfig] = None,
+        text_encoder: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
+        text_encoder_2: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
+        text_encoder_3: Optional[RBLNT5EncoderModelConfig] = None,
+        vae: Optional[RBLNAutoencoderKLConfig] = None,
+        *,
+        max_seq_len: Optional[int] = None,
+        sample_size: Optional[Tuple[int, int]] = None,
+        image_size: Optional[Tuple[int, int]] = None,
+        batch_size: Optional[int] = None,
+        img_height: Optional[int] = None,
+        img_width: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            transformer (Optional[RBLNSD3Transformer2DModelConfig]): Configuration for the transformer model component.
+                Initialized as RBLNSD3Transformer2DModelConfig if not provided.
+            text_encoder (Optional[RBLNCLIPTextModelWithProjectionConfig]): Configuration for the primary text encoder.
+                Initialized as RBLNCLIPTextModelWithProjectionConfig if not provided.
+            text_encoder_2 (Optional[RBLNCLIPTextModelWithProjectionConfig]): Configuration for the secondary text encoder.
+                Initialized as RBLNCLIPTextModelWithProjectionConfig if not provided.
+            text_encoder_3 (Optional[RBLNT5EncoderModelConfig]): Configuration for the tertiary text encoder.
+                Initialized as RBLNT5EncoderModelConfig if not provided.
+            vae (Optional[RBLNAutoencoderKLConfig]): Configuration for the VAE model component.
+                Initialized as RBLNAutoencoderKLConfig if not provided.
+            max_seq_len (Optional[int]): Maximum sequence length for text inputs. Defaults to 256.
+            sample_size (Optional[Tuple[int, int]]): Spatial dimensions for the transformer model.
+            image_size (Optional[Tuple[int, int]]): Dimensions for the generated images.
+                Cannot be used together with img_height/img_width.
+            batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+            img_height (Optional[int]): Height of the generated images.
+            img_width (Optional[int]): Width of the generated images.
+            guidance_scale (Optional[float]): Scale for classifier-free guidance. Deprecated parameter.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If both image_size and img_height/img_width are provided.
+
+        Note:
+            When guidance_scale > 1.0, the transformer batch size is automatically doubled to
+            accommodate classifier-free guidance.
+        """
+        super().__init__(**kwargs)
+        if image_size is not None and (img_height is not None or img_width is not None):
+            raise ValueError("image_size and img_height/img_width cannot both be provided")
+
+        if img_height is not None and img_width is not None:
+            image_size = (img_height, img_width)
+
+        max_seq_len = max_seq_len or 256
+
+        self.text_encoder = self.init_submodule_config(
+            RBLNCLIPTextModelWithProjectionConfig, text_encoder, batch_size=batch_size
+        )
+        self.text_encoder_2 = self.init_submodule_config(
+            RBLNCLIPTextModelWithProjectionConfig, text_encoder_2, batch_size=batch_size
+        )
+        self.text_encoder_3 = self.init_submodule_config(
+            RBLNT5EncoderModelConfig,
+            text_encoder_3,
+            batch_size=batch_size,
+            max_seq_len=max_seq_len,
+        )
+        self.transformer = self.init_submodule_config(
+            RBLNSD3Transformer2DModelConfig,
+            transformer,
+            batch_size=batch_size,
+            sample_size=sample_size,
+        )
+        self.vae = self.init_submodule_config(
+            RBLNAutoencoderKLConfig,
+            vae,
+            batch_size=batch_size,
+            uses_encoder=self.__class__._vae_uses_encoder,
+            sample_size=image_size,
+        )
+
+        if guidance_scale is not None:
+            logger.warning("Specifying `guidance_scale` is deprecated. It will be removed in a future version.")
+            do_classifier_free_guidance = guidance_scale > 1.0
+            if do_classifier_free_guidance:
+                self.transformer.batch_size = self.text_encoder.batch_size * 2
+
+    @property
+    def max_seq_len(self):
+        return self.text_encoder_3.max_seq_len
+
+    @property
+    def batch_size(self):
+        return self.vae.batch_size
+
+    @property
+    def sample_size(self):
+        return self.transformer.sample_size
+
+    @property
+    def image_size(self):
+        return self.vae.sample_size
+
+
+class RBLNStableDiffusion3PipelineConfig(_RBLNStableDiffusion3PipelineBaseConfig):
+    _vae_uses_encoder = False
+
+
+class RBLNStableDiffusion3Img2ImgPipelineConfig(_RBLNStableDiffusion3PipelineBaseConfig):
+    _vae_uses_encoder = True
+
+
+class RBLNStableDiffusion3InpaintPipelineConfig(_RBLNStableDiffusion3PipelineBaseConfig):
+    _vae_uses_encoder = True
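A sketch of how this config composes, traced from the hunk above and assuming `init_submodule_config` (defined in the new `configuration_utils.py`, not shown here) stores the forwarded values verbatim on each submodule config:

```python
from optimum.rbln import RBLNStableDiffusion3PipelineConfig

# img_height/img_width are folded into image_size, which the VAE keeps as its sample_size
cfg = RBLNStableDiffusion3PipelineConfig(batch_size=1, img_height=512, img_width=512)
assert cfg.image_size == (512, 512)
assert cfg.max_seq_len == 256  # default forwarded to the T5-based text_encoder_3

# image_size and img_height/img_width are mutually exclusive:
# RBLNStableDiffusion3PipelineConfig(image_size=(512, 512), img_height=512)  # ValueError

# The deprecated guidance_scale path doubles the transformer batch for classifier-free guidance
cfg2 = RBLNStableDiffusion3PipelineConfig(batch_size=1, guidance_scale=7.0)
assert cfg2.transformer.batch_size == 2
```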
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py (new file, +124)

@@ -0,0 +1,124 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+from ....configuration_utils import RBLNModelConfig
+from ....transformers import RBLNCLIPTextModelConfig, RBLNCLIPTextModelWithProjectionConfig
+from ....utils.logging import get_logger
+from ..models import RBLNAutoencoderKLConfig, RBLNUNet2DConditionModelConfig
+
+
+logger = get_logger(__name__)
+
+
+class _RBLNStableDiffusionXLPipelineBaseConfig(RBLNModelConfig):
+    submodules = ["text_encoder", "text_encoder_2", "unet", "vae"]
+    _vae_uses_encoder = False
+
+    def __init__(
+        self,
+        text_encoder: Optional[RBLNCLIPTextModelConfig] = None,
+        text_encoder_2: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
+        unet: Optional[RBLNUNet2DConditionModelConfig] = None,
+        vae: Optional[RBLNAutoencoderKLConfig] = None,
+        *,
+        batch_size: Optional[int] = None,
+        img_height: Optional[int] = None,
+        img_width: Optional[int] = None,
+        sample_size: Optional[Tuple[int, int]] = None,
+        image_size: Optional[Tuple[int, int]] = None,
+        guidance_scale: Optional[float] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            text_encoder (Optional[RBLNCLIPTextModelConfig]): Configuration for the primary text encoder component.
+                Initialized as RBLNCLIPTextModelConfig if not provided.
+            text_encoder_2 (Optional[RBLNCLIPTextModelWithProjectionConfig]): Configuration for the secondary text encoder component.
+                Initialized as RBLNCLIPTextModelWithProjectionConfig if not provided.
+            unet (Optional[RBLNUNet2DConditionModelConfig]): Configuration for the UNet model component.
+                Initialized as RBLNUNet2DConditionModelConfig if not provided.
+            vae (Optional[RBLNAutoencoderKLConfig]): Configuration for the VAE model component.
+                Initialized as RBLNAutoencoderKLConfig if not provided.
+            batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+            img_height (Optional[int]): Height of the generated images.
+            img_width (Optional[int]): Width of the generated images.
+            sample_size (Optional[Tuple[int, int]]): Spatial dimensions for the UNet model.
+            image_size (Optional[Tuple[int, int]]): Alternative way to specify image dimensions.
+                Cannot be used together with img_height/img_width.
+            guidance_scale (Optional[float]): Scale for classifier-free guidance. Deprecated parameter.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If both image_size and img_height/img_width are provided.
+
+        Note:
+            When guidance_scale > 1.0, the UNet batch size is automatically doubled to
+            accommodate classifier-free guidance.
+        """
+        super().__init__(**kwargs)
+        if image_size is not None and (img_height is not None or img_width is not None):
+            raise ValueError("image_size and img_height/img_width cannot both be provided")
+
+        if img_height is not None and img_width is not None:
+            image_size = (img_height, img_width)
+
+        self.text_encoder = self.init_submodule_config(RBLNCLIPTextModelConfig, text_encoder, batch_size=batch_size)
+        self.text_encoder_2 = self.init_submodule_config(
+            RBLNCLIPTextModelWithProjectionConfig, text_encoder_2, batch_size=batch_size
+        )
+        self.unet = self.init_submodule_config(
+            RBLNUNet2DConditionModelConfig,
+            unet,
+            batch_size=batch_size,
+            sample_size=sample_size,
+        )
+        self.vae = self.init_submodule_config(
+            RBLNAutoencoderKLConfig,
+            vae,
+            batch_size=batch_size,
+            uses_encoder=self.__class__._vae_uses_encoder,
+            sample_size=image_size,  # image size is equal to sample size in vae
+        )
+
+        if guidance_scale is not None:
+            logger.warning("Specifying `guidance_scale` is deprecated. It will be removed in a future version.")
+            do_classifier_free_guidance = guidance_scale > 1.0
+            if do_classifier_free_guidance:
+                self.unet.batch_size = self.text_encoder.batch_size * 2
+
+    @property
+    def batch_size(self):
+        return self.vae.batch_size
+
+    @property
+    def sample_size(self):
+        return self.unet.sample_size
+
+    @property
+    def image_size(self):
+        return self.vae.sample_size
+
+
+class RBLNStableDiffusionXLPipelineConfig(_RBLNStableDiffusionXLPipelineBaseConfig):
+    _vae_uses_encoder = False
+
+
+class RBLNStableDiffusionXLImg2ImgPipelineConfig(_RBLNStableDiffusionXLPipelineBaseConfig):
+    _vae_uses_encoder = True
+
+
+class RBLNStableDiffusionXLInpaintPipelineConfig(_RBLNStableDiffusionXLPipelineBaseConfig):
+    _vae_uses_encoder = True
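Note the design shared by both files: the concrete pipeline configs differ only in `_vae_uses_encoder`. Text-to-image needs just the VAE decoder, while img2img and inpaint must also compile the encoder to latent-encode the input image. A sketch, again assuming `init_submodule_config` stores `uses_encoder` verbatim on the VAE config:

```python
from optimum.rbln import (
    RBLNStableDiffusionXLImg2ImgPipelineConfig,
    RBLNStableDiffusionXLPipelineConfig,
)

txt2img = RBLNStableDiffusionXLPipelineConfig()
img2img = RBLNStableDiffusionXLImg2ImgPipelineConfig()
assert txt2img.vae.uses_encoder is False  # decoder-only VAE compilation
assert img2img.vae.uses_encoder is True   # encoder is compiled as well
```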
optimum/rbln/diffusers/modeling_diffusers.py (+63 -122)

@@ -15,12 +15,14 @@
 import copy
 import importlib
 from os import PathLike
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

 import torch

+from ..configuration_utils import ContextRblnConfig, RBLNModelConfig
 from ..modeling import RBLNModel
-
+
+# from ..transformers import RBLNCLIPTextModelConfig
 from ..utils.decorator_utils import remove_compile_time_kwargs
 from ..utils.logging import get_logger

@@ -31,6 +33,10 @@ if TYPE_CHECKING:
     from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel


+class RBLNDiffusionMixinConfig(RBLNModelConfig):
+    pass
+
+
 class RBLNDiffusionMixin:
     """
     RBLNDiffusionMixin provides essential functionalities for compiling Stable Diffusion pipeline components to run on RBLN NPUs.
@@ -69,6 +75,7 @@ class RBLNDiffusionMixin:
     _connected_classes = {}
     _submodules = []
     _prefix = {}
+    _rbln_config_class = None

     @classmethod
     def is_img2img_pipeline(cls):
@@ -78,35 +85,6 @@ class RBLNDiffusionMixin:
     def is_inpaint_pipeline(cls):
         return "Inpaint" in cls.__name__

-    @classmethod
-    def get_submodule_rbln_config(
-        cls, model: torch.nn.Module, submodule_name: str, rbln_config: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        submodule = getattr(model, submodule_name)
-        submodule_class_name = submodule.__class__.__name__
-        if isinstance(submodule, torch.nn.Module):
-            if submodule_class_name == "MultiControlNetModel":
-                submodule_class_name = "ControlNetModel"
-
-            submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
-
-            submodule_config = rbln_config.get(submodule_name, {})
-            submodule_config = copy.deepcopy(submodule_config)
-
-            pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
-
-            submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
-            submodule_config.update(
-                {
-                    "img2img_pipeline": cls.is_img2img_pipeline(),
-                    "inpaint_pipeline": cls.is_inpaint_pipeline(),
-                }
-            )
-            submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
-        else:
-            raise ValueError(f"submodule {submodule_name} isn't supported")
-        return submodule_config
-
     @staticmethod
     def _maybe_apply_and_fuse_lora(
         model: torch.nn.Module,
@@ -146,7 +124,22 @@ class RBLNDiffusionMixin:
         return model

     @classmethod
-
+    def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
+        """
+        Lazily loads and caches the corresponding RBLN model config class.
+        """
+        if cls._rbln_config_class is None:
+            rbln_config_class_name = cls.__name__ + "Config"
+            library = importlib.import_module("optimum.rbln")
+            cls._rbln_config_class = getattr(library, rbln_config_class_name, None)
+            if cls._rbln_config_class is None:
+                raise ValueError(
+                    f"RBLN config class {rbln_config_class_name} not found. This is an internal error. "
+                    "Please report it to the developers."
+                )
+        return cls._rbln_config_class
+
+    @classmethod
     def from_pretrained(
         cls,
         model_id: str,
@@ -159,6 +152,8 @@
         lora_scales: Optional[Union[float, List[float]]] = None,
         **kwargs,
     ) -> RBLNModel:
+        rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
+
         if export:
             # keep submodules if user passed any of them.
             passed_submodules = {
@@ -168,22 +163,12 @@
         else:
             # raise error if any of submodules are torch module.
             model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
-            rbln_config = cls._flatten_rbln_config(rbln_config)
             for submodule_name in cls._submodules:
                 if isinstance(kwargs.get(submodule_name), torch.nn.Module):
                     raise AssertionError(
                         f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
                     )

-                submodule_config = rbln_config.get(submodule_name, {})
-
-                for key, value in rbln_config.items():
-                    if key in RUNTIME_KEYWORDS and key not in submodule_config:
-                        submodule_config[key] = value
-
-                if not any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
-                    continue
-
                 module_name, class_name = model_index_config[submodule_name]
                 if module_name != "optimum.rbln":
                     raise ValueError(
@@ -192,19 +177,19 @@
                         "Expected 'optimum.rbln'. Please check the model_index.json configuration."
                     )

-                submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), class_name)
-
+                submodule_cls: Type[RBLNModel] = getattr(importlib.import_module("optimum.rbln"), class_name)
+                submodule_config = getattr(rbln_config, submodule_name)
                 submodule = submodule_cls.from_pretrained(
                     model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
                 )
                 kwargs[submodule_name] = submodule

         with ContextRblnConfig(
-            device=rbln_config.
-            device_map=rbln_config.
-            create_runtimes=rbln_config.
-            optimize_host_mem=rbln_config.
-            activate_profiler=rbln_config.
+            device=rbln_config.device,
+            device_map=rbln_config.device_map,
+            create_runtimes=rbln_config.create_runtimes,
+            optimize_host_mem=rbln_config.optimize_host_memory,
+            activate_profiler=rbln_config.activate_profiler,
         ):
             model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)

@@ -224,78 +209,27 @@ class RBLNDiffusionMixin:
         compiled_submodules = cls._compile_submodules(model, passed_submodules, model_save_dir, rbln_config)
         return cls._construct_pipe(model, compiled_submodules, model_save_dir, rbln_config)

-    @classmethod
-    def _prepare_rbln_config(
-        cls,
-        rbln_config,
-    ) -> Dict[str, Any]:
-        prepared_config = {}
-        for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
-            connected_pipe_config = rbln_config.pop(connected_pipe_name, {})
-            prefix = cls._prefix.get(connected_pipe_name, "")
-            guidance_scale = rbln_config.pop(f"{prefix}guidance_scale", None)
-            if "guidance_scale" not in connected_pipe_config and guidance_scale is not None:
-                connected_pipe_config["guidance_scale"] = guidance_scale
-            for submodule_name in connected_pipe_cls._submodules:
-                submodule_config = rbln_config.pop(prefix + submodule_name, {})
-                if submodule_name not in connected_pipe_config:
-                    connected_pipe_config[submodule_name] = {}
-                connected_pipe_config[submodule_name].update(
-                    {k: v for k, v in submodule_config.items() if k not in connected_pipe_config[submodule_name]}
-                )
-            prepared_config[connected_pipe_name] = connected_pipe_config
-        prepared_config.update(rbln_config)
-        return prepared_config
-
-    @classmethod
-    def _flatten_rbln_config(
-        cls,
-        rbln_config,
-    ) -> Dict[str, Any]:
-        prepared_config = cls._prepare_rbln_config(rbln_config)
-        flattened_config = {}
-        pipe_global_config = {k: v for k, v in prepared_config.items() if k not in cls._connected_classes.keys()}
-        for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
-            connected_pipe_config = prepared_config.pop(connected_pipe_name)
-            prefix = cls._prefix.get(connected_pipe_name, "")
-            connected_pipe_global_config = {
-                k: v for k, v in connected_pipe_config.items() if k not in connected_pipe_cls._submodules
-            }
-            for submodule_name in connected_pipe_cls._submodules:
-                flattened_config[prefix + submodule_name] = connected_pipe_config[submodule_name]
-                flattened_config[prefix + submodule_name].update(
-                    {
-                        k: v
-                        for k, v in connected_pipe_global_config.items()
-                        if k not in flattened_config[prefix + submodule_name]
-                    }
-                )
-        flattened_config.update(pipe_global_config)
-        return flattened_config
-
     @classmethod
     def _compile_pipelines(
         cls,
         model: torch.nn.Module,
         passed_submodules: Dict[str, RBLNModel],
         model_save_dir: Optional[PathLike],
-        rbln_config:
+        rbln_config: "RBLNDiffusionMixinConfig",
     ) -> Dict[str, RBLNModel]:
         compiled_submodules = {}
-
-        rbln_config = cls._prepare_rbln_config(rbln_config)
-        pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._connected_classes.keys()}
         for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
             connected_pipe_submodules = {}
             prefix = cls._prefix.get(connected_pipe_name, "")
             for submodule_name in connected_pipe_cls._submodules:
                 connected_pipe_submodules[submodule_name] = passed_submodules.get(prefix + submodule_name, None)
             connected_pipe = getattr(model, connected_pipe_name)
-            connected_pipe_config = {}
-            connected_pipe_config.update(pipe_global_config)
-            connected_pipe_config.update(rbln_config[connected_pipe_name])
             connected_pipe_compiled_submodules = connected_pipe_cls._compile_submodules(
-                connected_pipe,
+                connected_pipe,
+                connected_pipe_submodules,
+                model_save_dir,
+                getattr(rbln_config, connected_pipe_name),
+                prefix,
             )
             for submodule_name, compiled_submodule in connected_pipe_compiled_submodules.items():
                 compiled_submodules[prefix + submodule_name] = compiled_submodule
@@ -307,14 +241,19 @@
         model: torch.nn.Module,
         passed_submodules: Dict[str, RBLNModel],
         model_save_dir: Optional[PathLike],
-        rbln_config:
+        rbln_config: RBLNDiffusionMixinConfig,
         prefix: Optional[str] = "",
     ) -> Dict[str, RBLNModel]:
         compiled_submodules = {}

         for submodule_name in cls._submodules:
             submodule = passed_submodules.get(submodule_name) or getattr(model, submodule_name, None)
-
+
+            if getattr(rbln_config, submodule_name, None) is None:
+                raise ValueError(f"RBLN config for submodule {submodule_name} is not provided.")
+
+            submodule_rbln_cls: Type[RBLNModel] = getattr(rbln_config, submodule_name).rbln_model_cls
+            rbln_config = submodule_rbln_cls.update_rbln_config_using_pipe(model, rbln_config, submodule_name)

             if submodule is None:
                 raise ValueError(f"submodule ({submodule_name}) cannot be accessed since it is not provided.")
@@ -325,7 +264,7 @@
                 submodule = cls._compile_multicontrolnet(
                     controlnets=submodule,
                     model_save_dir=model_save_dir,
-                    controlnet_rbln_config=
+                    controlnet_rbln_config=getattr(rbln_config, submodule_name),
                     prefix=prefix,
                 )
             elif isinstance(submodule, torch.nn.Module):
@@ -337,7 +276,7 @@
                     model=submodule,
                     subfolder=subfolder,
                     model_save_dir=model_save_dir,
-                    rbln_config=
+                    rbln_config=getattr(rbln_config, submodule_name),
                 )
             else:
                 raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")
@@ -350,22 +289,24 @@
         cls,
         controlnets: "MultiControlNetModel",
         model_save_dir: Optional[PathLike],
-        controlnet_rbln_config:
+        controlnet_rbln_config: RBLNModelConfig,
         prefix: Optional[str] = "",
     ):
         # Compile multiple ControlNet models for a MultiControlNet setup
         from .models.controlnet import RBLNControlNetModel
         from .pipelines.controlnet import RBLNMultiControlNetModel

-        compiled_controlnets = [
-
-
-
-
+        compiled_controlnets = []
+        for i, controlnet in enumerate(controlnets.nets):
+            _controlnet_rbln_config = copy.deepcopy(controlnet_rbln_config)
+            compiled_controlnets.append(
+                RBLNControlNetModel.from_model(
+                    model=controlnet,
+                    subfolder=f"{prefix}controlnet" if i == 0 else f"{prefix}controlnet_{i}",
+                    model_save_dir=model_save_dir,
+                    rbln_config=_controlnet_rbln_config,
+                )
             )
-            for i, controlnet in enumerate(controlnets.nets)
-        ]
         return RBLNMultiControlNetModel(compiled_controlnets)

     @classmethod
@@ -412,7 +353,7 @@
         # overwrite to replace incorrect config
         model.save_config(model_save_dir)

-        if rbln_config.
+        if rbln_config.optimize_host_memory is False:
             # Keep compiled_model objs to further analysis. -> TODO: remove soon...
             model.compiled_models = []
             for name in cls._submodules:
@@ -441,9 +382,9 @@
             kwargs["height"] = compiled_image_size[0]
             kwargs["width"] = compiled_image_size[1]

-        compiled_num_frames = self.unet.rbln_config.
+        compiled_num_frames = self.unet.rbln_config.num_frames
         if compiled_num_frames is not None:
-            kwargs["num_frames"] =
+            kwargs["num_frames"] = compiled_num_frames
         return kwargs
 ```
 """