optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +21 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +282 -100
- optimum/rbln/modeling_seq2seq.py +58 -132
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +3 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/dpt/__init__.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
- optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- optimum/rbln/utils/__init__.py +1 -1
- optimum/rbln/utils/import_utils.py +46 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
- optimum_rbln-0.1.8.dist-info/RECORD +73 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
- optimum_rbln-0.1.4.dist-info/RECORD +0 -63
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
@@ -22,18 +22,18 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
"""RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""
|
24
24
|
|
25
|
-
from pathlib import Path
|
26
|
-
from tempfile import TemporaryDirectory
|
27
25
|
from typing import Any, Callable, Dict, List, Optional, Union
|
28
26
|
|
29
27
|
import torch
|
30
28
|
import torch.nn.functional as F
|
31
|
-
from diffusers import StableDiffusionControlNetPipeline
|
29
|
+
from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline
|
32
30
|
from diffusers.image_processor import PipelineImageInput
|
31
|
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
33
32
|
from diffusers.pipelines.controlnet.pipeline_controlnet import retrieve_timesteps
|
34
33
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
35
34
|
from diffusers.utils import deprecate, logging
|
36
35
|
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
|
36
|
+
from transformers import CLIPTextModel
|
37
37
|
|
38
38
|
from ....modeling_base import RBLNBaseModel
|
39
39
|
from ....transformers import RBLNCLIPTextModel
|
@@ -64,18 +64,40 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
|
|
64
64
|
- A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
|
65
65
|
"""
|
66
66
|
export = kwargs.pop("export", None)
|
67
|
+
vae = kwargs.pop("vae", None)
|
68
|
+
unet = kwargs.pop("unet", None)
|
67
69
|
text_encoder = kwargs.pop("text_encoder", None)
|
68
|
-
|
70
|
+
controlnet = kwargs.pop("controlnet", None)
|
71
|
+
model_save_dir = kwargs.pop("model_save_dir", None)
|
69
72
|
|
70
73
|
rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
|
71
74
|
|
72
75
|
kwargs_dict = {
|
73
76
|
"pretrained_model_name_or_path": model_id,
|
74
|
-
"text_encoder": text_encoder,
|
75
|
-
"controlnet": controlnets,
|
76
77
|
**kwargs,
|
77
78
|
}
|
78
79
|
|
80
|
+
kwargs_dict.update(
|
81
|
+
{
|
82
|
+
**({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
|
83
|
+
**({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
|
84
|
+
**(
|
85
|
+
{"text_encoder": text_encoder}
|
86
|
+
if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
|
87
|
+
else {}
|
88
|
+
),
|
89
|
+
**(
|
90
|
+
{"controlnet": controlnet}
|
91
|
+
if controlnet is not None
|
92
|
+
and (
|
93
|
+
isinstance(controlnet, ControlNetModel)
|
94
|
+
or all(isinstance(c, ControlNetModel) for c in controlnet)
|
95
|
+
)
|
96
|
+
else {}
|
97
|
+
),
|
98
|
+
}
|
99
|
+
)
|
100
|
+
|
79
101
|
model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
|
80
102
|
|
81
103
|
if export is None or export is False:
|
@@ -85,64 +107,87 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
|
|
85
107
|
rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
|
86
108
|
)
|
87
109
|
|
88
|
-
save_dir = TemporaryDirectory()
|
89
|
-
save_dir_path = Path(save_dir.name)
|
90
|
-
|
91
|
-
model.save_pretrained(save_directory=save_dir_path, **kwargs)
|
92
|
-
|
93
110
|
# compile model, create runtime
|
94
|
-
vae
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
rbln_use_encode=False,
|
99
|
-
rbln_vae_scale_factor=model.vae_scale_factor,
|
100
|
-
**rbln_config_kwargs,
|
101
|
-
**rbln_constructor_kwargs,
|
102
|
-
)
|
103
|
-
|
104
|
-
text_encoder = RBLNCLIPTextModel.from_pretrained(
|
105
|
-
model_id=save_dir_path / "text_encoder",
|
106
|
-
export=True,
|
107
|
-
**rbln_config_kwargs,
|
108
|
-
**rbln_constructor_kwargs,
|
109
|
-
)
|
110
|
-
|
111
|
-
batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
|
112
|
-
unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
|
113
|
-
|
114
|
-
unet = RBLNUNet2DConditionModel.from_pretrained(
|
115
|
-
model_id=save_dir_path / "unet",
|
116
|
-
export=True,
|
117
|
-
rbln_max_seq_len=text_encoder.config.max_position_embeddings,
|
118
|
-
rbln_batch_size=unet_batch_size,
|
119
|
-
rbln_use_encode=False,
|
120
|
-
rbln_vae_scale_factor=model.vae_scale_factor,
|
121
|
-
rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
|
122
|
-
**rbln_config_kwargs,
|
123
|
-
**rbln_constructor_kwargs,
|
124
|
-
)
|
125
|
-
|
126
|
-
if isinstance(controlnets, (list, tuple)):
|
127
|
-
controlnet = RBLNMultiControlNetModel.from_pretrained(
|
128
|
-
model_id=str(save_dir_path / "controlnet"),
|
111
|
+
if not isinstance(vae, RBLNAutoencoderKL):
|
112
|
+
vae = RBLNAutoencoderKL.from_pretrained(
|
113
|
+
model_id=model_id,
|
114
|
+
subfolder="vae",
|
129
115
|
export=True,
|
130
|
-
|
116
|
+
model_save_dir=model_save_dir,
|
117
|
+
rbln_unet_sample_size=model.unet.config.sample_size,
|
118
|
+
rbln_use_encode=False,
|
131
119
|
rbln_vae_scale_factor=model.vae_scale_factor,
|
132
120
|
**rbln_config_kwargs,
|
133
121
|
**rbln_constructor_kwargs,
|
134
122
|
)
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
model_id=
|
123
|
+
|
124
|
+
if not isinstance(text_encoder, RBLNCLIPTextModel):
|
125
|
+
text_encoder = RBLNCLIPTextModel.from_pretrained(
|
126
|
+
model_id=model_id,
|
127
|
+
subfolder="text_encoder",
|
139
128
|
export=True,
|
129
|
+
model_save_dir=model_save_dir,
|
130
|
+
**rbln_config_kwargs,
|
131
|
+
**rbln_constructor_kwargs,
|
132
|
+
)
|
133
|
+
|
134
|
+
batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
|
135
|
+
unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
|
136
|
+
|
137
|
+
if not isinstance(unet, RBLNUNet2DConditionModel):
|
138
|
+
unet = RBLNUNet2DConditionModel.from_pretrained(
|
139
|
+
model_id=model_id,
|
140
|
+
subfolder="unet",
|
141
|
+
export=True,
|
142
|
+
model_save_dir=model_save_dir,
|
143
|
+
rbln_max_seq_len=text_encoder.config.max_position_embeddings,
|
140
144
|
rbln_batch_size=unet_batch_size,
|
145
|
+
rbln_use_encode=False,
|
141
146
|
rbln_vae_scale_factor=model.vae_scale_factor,
|
147
|
+
rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
|
142
148
|
**rbln_config_kwargs,
|
143
149
|
**rbln_constructor_kwargs,
|
144
150
|
)
|
145
|
-
|
151
|
+
|
152
|
+
if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
|
153
|
+
if isinstance(controlnet, (list, tuple)):
|
154
|
+
multicontrolnet = []
|
155
|
+
for i, cid in enumerate(controlnet):
|
156
|
+
subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
|
157
|
+
multicontrolnet.append(
|
158
|
+
RBLNControlNetModel.from_pretrained(
|
159
|
+
model_id=cid.config._name_or_path,
|
160
|
+
subfolder=subfolder_name,
|
161
|
+
export=True,
|
162
|
+
model_save_dir=model_save_dir,
|
163
|
+
rbln_batch_size=unet_batch_size,
|
164
|
+
rbln_vae_scale_factor=model.vae_scale_factor,
|
165
|
+
**rbln_config_kwargs,
|
166
|
+
**rbln_constructor_kwargs,
|
167
|
+
)
|
168
|
+
)
|
169
|
+
controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
|
170
|
+
controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
|
171
|
+
else:
|
172
|
+
controlnet = RBLNControlNetModel.from_pretrained(
|
173
|
+
model_id=controlnet.config._name_or_path,
|
174
|
+
subfolder="controlnet",
|
175
|
+
export=True,
|
176
|
+
model_save_dir=model_save_dir,
|
177
|
+
rbln_batch_size=unet_batch_size,
|
178
|
+
rbln_vae_scale_factor=model.vae_scale_factor,
|
179
|
+
**rbln_config_kwargs,
|
180
|
+
**rbln_constructor_kwargs,
|
181
|
+
)
|
182
|
+
controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
|
183
|
+
|
184
|
+
if model_save_dir is not None:
|
185
|
+
# To skip saving original pytorch modules
|
186
|
+
del (model.vae, model.text_encoder, model.unet, model.controlnet)
|
187
|
+
|
188
|
+
# Direct calling of `save_pretrained` causes config.unet = (None, None).
|
189
|
+
# So config must be saved again, later.
|
190
|
+
model.save_pretrained(model_save_dir)
|
146
191
|
|
147
192
|
# replace modules
|
148
193
|
model.vae = vae
|
@@ -159,7 +204,18 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
|
|
159
204
|
}
|
160
205
|
model.register_to_config(**update_dict)
|
161
206
|
|
162
|
-
|
207
|
+
if model_save_dir is not None:
|
208
|
+
# overwrite to replace incorrect config
|
209
|
+
model.save_config(model_save_dir)
|
210
|
+
|
211
|
+
# use for CI to access each compiled model
|
212
|
+
if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
|
213
|
+
model.compiled_models = [vae.compiled_models[0], text_encoder.compiled_models[0], unet.compiled_models[0]]
|
214
|
+
if isinstance(controlnet, RBLNMultiControlNetModel):
|
215
|
+
for c_model in controlnet.nets:
|
216
|
+
model.compiled_models.append(c_model.compiled_models[0])
|
217
|
+
else:
|
218
|
+
model.compiled_models.append(controlnet.compiled_models[0])
|
163
219
|
|
164
220
|
return model
|
165
221
|
|
@@ -22,17 +22,17 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
"""RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""
|
24
24
|
|
25
|
-
from pathlib import Path
|
26
|
-
from tempfile import TemporaryDirectory
|
27
25
|
from typing import Any, Callable, Dict, List, Optional, Union
|
28
26
|
|
29
27
|
import torch
|
30
28
|
import torch.nn.functional as F
|
31
|
-
from diffusers import StableDiffusionControlNetImg2ImgPipeline
|
29
|
+
from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
|
32
30
|
from diffusers.image_processor import PipelineImageInput
|
31
|
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
33
32
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
34
33
|
from diffusers.utils import deprecate, logging
|
35
34
|
from diffusers.utils.torch_utils import is_compiled_module
|
35
|
+
from transformers import CLIPTextModel
|
36
36
|
|
37
37
|
from ....modeling_base import RBLNBaseModel
|
38
38
|
from ....transformers import RBLNCLIPTextModel
|
@@ -63,18 +63,40 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2
|
|
63
63
|
- A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
|
64
64
|
"""
|
65
65
|
export = kwargs.pop("export", None)
|
66
|
+
vae = kwargs.pop("vae", None)
|
67
|
+
unet = kwargs.pop("unet", None)
|
66
68
|
text_encoder = kwargs.pop("text_encoder", None)
|
67
|
-
|
69
|
+
controlnet = kwargs.pop("controlnet", None)
|
70
|
+
model_save_dir = kwargs.pop("model_save_dir", None)
|
68
71
|
|
69
72
|
rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
|
70
73
|
|
71
74
|
kwargs_dict = {
|
72
75
|
"pretrained_model_name_or_path": model_id,
|
73
|
-
"text_encoder": text_encoder,
|
74
|
-
"controlnet": controlnets,
|
75
76
|
**kwargs,
|
76
77
|
}
|
77
78
|
|
79
|
+
kwargs_dict.update(
|
80
|
+
{
|
81
|
+
**({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
|
82
|
+
**({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
|
83
|
+
**(
|
84
|
+
{"text_encoder": text_encoder}
|
85
|
+
if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
|
86
|
+
else {}
|
87
|
+
),
|
88
|
+
**(
|
89
|
+
{"controlnet": controlnet}
|
90
|
+
if controlnet is not None
|
91
|
+
and (
|
92
|
+
isinstance(controlnet, ControlNetModel)
|
93
|
+
or all(isinstance(c, ControlNetModel) for c in controlnet)
|
94
|
+
)
|
95
|
+
else {}
|
96
|
+
),
|
97
|
+
}
|
98
|
+
)
|
99
|
+
|
78
100
|
model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
|
79
101
|
|
80
102
|
if export is None or export is False:
|
@@ -84,64 +106,87 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2
|
|
84
106
|
rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
|
85
107
|
)
|
86
108
|
|
87
|
-
save_dir = TemporaryDirectory()
|
88
|
-
save_dir_path = Path(save_dir.name)
|
89
|
-
|
90
|
-
model.save_pretrained(save_directory=save_dir_path, **kwargs)
|
91
|
-
|
92
109
|
# compile model, create runtime
|
93
|
-
vae
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
rbln_use_encode=True,
|
98
|
-
rbln_vae_scale_factor=model.vae_scale_factor,
|
99
|
-
**rbln_config_kwargs,
|
100
|
-
**rbln_constructor_kwargs,
|
101
|
-
)
|
102
|
-
|
103
|
-
text_encoder = RBLNCLIPTextModel.from_pretrained(
|
104
|
-
model_id=save_dir_path / "text_encoder",
|
105
|
-
export=True,
|
106
|
-
**rbln_config_kwargs,
|
107
|
-
**rbln_constructor_kwargs,
|
108
|
-
)
|
109
|
-
|
110
|
-
batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
|
111
|
-
unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
|
112
|
-
|
113
|
-
unet = RBLNUNet2DConditionModel.from_pretrained(
|
114
|
-
model_id=save_dir_path / "unet",
|
115
|
-
export=True,
|
116
|
-
rbln_max_seq_len=text_encoder.config.max_position_embeddings,
|
117
|
-
rbln_batch_size=unet_batch_size,
|
118
|
-
rbln_use_encode=True,
|
119
|
-
rbln_vae_scale_factor=model.vae_scale_factor,
|
120
|
-
rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
|
121
|
-
**rbln_config_kwargs,
|
122
|
-
**rbln_constructor_kwargs,
|
123
|
-
)
|
124
|
-
|
125
|
-
if isinstance(controlnets, (list, tuple)):
|
126
|
-
controlnet = RBLNMultiControlNetModel.from_pretrained(
|
127
|
-
model_id=str(save_dir_path / "controlnet"),
|
110
|
+
if not isinstance(vae, RBLNAutoencoderKL):
|
111
|
+
vae = RBLNAutoencoderKL.from_pretrained(
|
112
|
+
model_id=model_id,
|
113
|
+
subfolder="vae",
|
128
114
|
export=True,
|
129
|
-
|
115
|
+
model_save_dir=model_save_dir,
|
116
|
+
rbln_unet_sample_size=model.unet.config.sample_size,
|
117
|
+
rbln_use_encode=True,
|
130
118
|
rbln_vae_scale_factor=model.vae_scale_factor,
|
131
119
|
**rbln_config_kwargs,
|
132
120
|
**rbln_constructor_kwargs,
|
133
121
|
)
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
model_id=
|
122
|
+
|
123
|
+
if not isinstance(text_encoder, RBLNCLIPTextModel):
|
124
|
+
text_encoder = RBLNCLIPTextModel.from_pretrained(
|
125
|
+
model_id=model_id,
|
126
|
+
subfolder="text_encoder",
|
127
|
+
export=True,
|
128
|
+
model_save_dir=model_save_dir,
|
129
|
+
**rbln_config_kwargs,
|
130
|
+
**rbln_constructor_kwargs,
|
131
|
+
)
|
132
|
+
|
133
|
+
batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
|
134
|
+
unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
|
135
|
+
|
136
|
+
if not isinstance(unet, RBLNUNet2DConditionModel):
|
137
|
+
unet = RBLNUNet2DConditionModel.from_pretrained(
|
138
|
+
model_id=model_id,
|
139
|
+
subfolder="unet",
|
138
140
|
export=True,
|
141
|
+
model_save_dir=model_save_dir,
|
142
|
+
rbln_max_seq_len=text_encoder.config.max_position_embeddings,
|
139
143
|
rbln_batch_size=unet_batch_size,
|
144
|
+
rbln_use_encode=True,
|
140
145
|
rbln_vae_scale_factor=model.vae_scale_factor,
|
146
|
+
rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
|
141
147
|
**rbln_config_kwargs,
|
142
148
|
**rbln_constructor_kwargs,
|
143
149
|
)
|
144
|
-
|
150
|
+
|
151
|
+
if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
|
152
|
+
if isinstance(controlnet, (list, tuple)):
|
153
|
+
multicontrolnet = []
|
154
|
+
for i, cid in enumerate(controlnet):
|
155
|
+
subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
|
156
|
+
multicontrolnet.append(
|
157
|
+
RBLNControlNetModel.from_pretrained(
|
158
|
+
model_id=cid.config._name_or_path,
|
159
|
+
subfolder=subfolder_name,
|
160
|
+
export=True,
|
161
|
+
model_save_dir=model_save_dir,
|
162
|
+
rbln_batch_size=unet_batch_size,
|
163
|
+
rbln_vae_scale_factor=model.vae_scale_factor,
|
164
|
+
**rbln_config_kwargs,
|
165
|
+
**rbln_constructor_kwargs,
|
166
|
+
)
|
167
|
+
)
|
168
|
+
controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
|
169
|
+
controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
|
170
|
+
else:
|
171
|
+
controlnet = RBLNControlNetModel.from_pretrained(
|
172
|
+
model_id=controlnet.config._name_or_path,
|
173
|
+
subfolder="controlnet",
|
174
|
+
export=True,
|
175
|
+
model_save_dir=model_save_dir,
|
176
|
+
rbln_batch_size=unet_batch_size,
|
177
|
+
rbln_vae_scale_factor=model.vae_scale_factor,
|
178
|
+
**rbln_config_kwargs,
|
179
|
+
**rbln_constructor_kwargs,
|
180
|
+
)
|
181
|
+
controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
|
182
|
+
|
183
|
+
if model_save_dir is not None:
|
184
|
+
# To skip saving original pytorch modules
|
185
|
+
del (model.vae, model.text_encoder, model.unet, model.controlnet)
|
186
|
+
|
187
|
+
# Direct calling of `save_pretrained` causes config.unet = (None, None).
|
188
|
+
# So config must be saved again, later.
|
189
|
+
model.save_pretrained(model_save_dir)
|
145
190
|
|
146
191
|
# replace modules
|
147
192
|
model.vae = vae
|
@@ -158,7 +203,23 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2
|
|
158
203
|
}
|
159
204
|
model.register_to_config(**update_dict)
|
160
205
|
|
161
|
-
|
206
|
+
if model_save_dir is not None:
|
207
|
+
# overwrite to replace incorrect config
|
208
|
+
model.save_config(model_save_dir)
|
209
|
+
|
210
|
+
# use for CI to access each compiled model
|
211
|
+
if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
|
212
|
+
model.compiled_models = [
|
213
|
+
vae.compiled_models[0],
|
214
|
+
vae.compiled_models[1],
|
215
|
+
text_encoder.compiled_models[0],
|
216
|
+
unet.compiled_models[0],
|
217
|
+
]
|
218
|
+
if isinstance(controlnet, RBLNMultiControlNetModel):
|
219
|
+
for c_model in controlnet.nets:
|
220
|
+
model.compiled_models.append(c_model.compiled_models[0])
|
221
|
+
else:
|
222
|
+
model.compiled_models.append(controlnet.compiled_models[0])
|
162
223
|
|
163
224
|
return model
|
164
225
|
|