optimum-rbln 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their public registries.
- optimum/rbln/__init__.py +17 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/models/controlnet.py +7 -3
- optimum/rbln/diffusers/models/unet_2d_condition.py +5 -5
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +23 -146
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/modeling_alias.py +19 -1
- optimum/rbln/modeling_base.py +162 -18
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +3 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +516 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +464 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +123 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +67 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +10 -257
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
- optimum/rbln/transformers/models/llama/modeling_llama.py +12 -440
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +10 -325
- optimum/rbln/transformers/models/mistral/__init__.py +24 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +131 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
- optimum/rbln/utils/import_utils.py +1 -4
- optimum/rbln/utils/runtime_utils.py +2 -1
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +11 -5
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +48 -35
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
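The headline change for the ControlNet pipelines, visible in the hunks below, is a reworked `from_pretrained`: instead of round-tripping every submodule through a `TemporaryDirectory`, it now accepts a persistent `model_save_dir` and lets callers pass `vae`, `unet`, `text_encoder`, and `controlnet` objects directly. A minimal sketch of the new compile flow, assuming the top-level `optimum.rbln` re-export of the pipeline class and illustrative model ids; the `export`, `model_save_dir`, `controlnet`, and `rbln_*` kwargs are the ones popped in the hunks below:

```python
# Sketch of the 0.1.9 compile flow (model ids and paths are illustrative).
from diffusers import ControlNetModel
from optimum.rbln import RBLNStableDiffusionControlNetImg2ImgPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pipe = RBLNStableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    export=True,                            # compile submodules for the RBLN device
    model_save_dir="rbln-sd15-controlnet",  # persist compiled artifacts for reuse
    rbln_batch_size=1,
    rbln_guidance_scale=7.5,
)
```

The hunks for the two ControlNet pipeline files follow.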
```diff
--- a/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -22,17 +22,17 @@
 # from Rebellions Inc.
 """RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""
 
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import StableDiffusionControlNetImg2ImgPipeline
+from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
 from diffusers.image_processor import PipelineImageInput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module
+from transformers import CLIPTextModel
 
 from ....modeling_base import RBLNBaseModel
 from ....transformers import RBLNCLIPTextModel
```
```diff
@@ -63,18 +63,40 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2ImgPipeline):
                 - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
         """
         export = kwargs.pop("export", None)
+        vae = kwargs.pop("vae", None)
+        unet = kwargs.pop("unet", None)
         text_encoder = kwargs.pop("text_encoder", None)
-
+        controlnet = kwargs.pop("controlnet", None)
+        model_save_dir = kwargs.pop("model_save_dir", None)
 
         rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
 
         kwargs_dict = {
             "pretrained_model_name_or_path": model_id,
-            "text_encoder": text_encoder,
-            "controlnet": controlnets,
             **kwargs,
         }
 
+        kwargs_dict.update(
+            {
+                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
+                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
+                **(
+                    {"text_encoder": text_encoder}
+                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
+                    else {}
+                ),
+                **(
+                    {"controlnet": controlnet}
+                    if controlnet is not None
+                    and (
+                        isinstance(controlnet, ControlNetModel)
+                        or all(isinstance(c, ControlNetModel) for c in controlnet)
+                    )
+                    else {}
+                ),
+            }
+        )
+
         model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
 
         if export is None or export is False:
```
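The `isinstance` filtering above forwards a submodule to the underlying diffusers loader only when it is a genuine diffusers/transformers module; the `if not isinstance(...)` guards in the next hunk then skip compilation for anything that is already an RBLN class. A hypothetical reuse sketch under that reading (the load-without-export behavior of the submodule classes is inferred, and the path and ids are illustrative):

```python
# Hypothetical reuse sketch: pass an already-compiled RBLNAutoencoderKL so the
# VAE compile step in the hunk below is skipped by its `if not isinstance` guard.
from diffusers import ControlNetModel
from optimum.rbln import RBLNAutoencoderKL, RBLNStableDiffusionControlNetImg2ImgPipeline

vae = RBLNAutoencoderKL.from_pretrained(
    model_id="rbln-sd15-controlnet/vae",  # illustrative: a previously compiled VAE
    export=False,
)
pipe = RBLNStableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    vae=vae,  # already an RBLN module, so it is not re-exported
    controlnet=ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny"),
    export=True,
)
```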
```diff
@@ -84,64 +106,87 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2ImgPipeline):
             rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
         )
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
-
-        model.save_pretrained(save_directory=save_dir_path, **kwargs)
-
         # compile model, create runtime
-        vae
-
-
-
-            rbln_use_encode=True,
-            rbln_vae_scale_factor=model.vae_scale_factor,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-
-        text_encoder = RBLNCLIPTextModel.from_pretrained(
-            model_id=save_dir_path / "text_encoder",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-
-        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
-        unet = RBLNUNet2DConditionModel.from_pretrained(
-            model_id=save_dir_path / "unet",
-            export=True,
-            rbln_max_seq_len=text_encoder.config.max_position_embeddings,
-            rbln_batch_size=unet_batch_size,
-            rbln_use_encode=True,
-            rbln_vae_scale_factor=model.vae_scale_factor,
-            rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-
-        if isinstance(controlnets, (list, tuple)):
-            controlnet = RBLNMultiControlNetModel.from_pretrained(
-                model_id=str(save_dir_path / "controlnet"),
+        if not isinstance(vae, RBLNAutoencoderKL):
+            vae = RBLNAutoencoderKL.from_pretrained(
+                model_id=model_id,
+                subfolder="vae",
                 export=True,
-
+                model_save_dir=model_save_dir,
+                rbln_unet_sample_size=model.unet.config.sample_size,
+                rbln_use_encode=True,
                 rbln_vae_scale_factor=model.vae_scale_factor,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-
-
-
-                model_id=
+
+        if not isinstance(text_encoder, RBLNCLIPTextModel):
+            text_encoder = RBLNCLIPTextModel.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder",
+                export=True,
+                model_save_dir=model_save_dir,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
+
+        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
+        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+
+        if not isinstance(unet, RBLNUNet2DConditionModel):
+            unet = RBLNUNet2DConditionModel.from_pretrained(
+                model_id=model_id,
+                subfolder="unet",
                 export=True,
+                model_save_dir=model_save_dir,
+                rbln_max_seq_len=text_encoder.config.max_position_embeddings,
                 rbln_batch_size=unet_batch_size,
+                rbln_use_encode=True,
                 rbln_vae_scale_factor=model.vae_scale_factor,
+                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-
+
+        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
+            if isinstance(controlnet, (list, tuple)):
+                multicontrolnet = []
+                for i, cid in enumerate(controlnet):
+                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
+                    multicontrolnet.append(
+                        RBLNControlNetModel.from_pretrained(
+                            model_id=cid.config._name_or_path,
+                            subfolder=subfolder_name,
+                            export=True,
+                            model_save_dir=model_save_dir,
+                            rbln_batch_size=unet_batch_size,
+                            rbln_vae_scale_factor=model.vae_scale_factor,
+                            **rbln_config_kwargs,
+                            **rbln_constructor_kwargs,
+                        )
+                    )
+                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
+                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
+            else:
+                controlnet = RBLNControlNetModel.from_pretrained(
+                    model_id=controlnet.config._name_or_path,
+                    subfolder="controlnet",
+                    export=True,
+                    model_save_dir=model_save_dir,
+                    rbln_batch_size=unet_batch_size,
+                    rbln_vae_scale_factor=model.vae_scale_factor,
+                    **rbln_config_kwargs,
+                    **rbln_constructor_kwargs,
+                )
+                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
+
+        if model_save_dir is not None:
+            # To skip saving original pytorch modules
+            del (model.vae, model.text_encoder, model.unet, model.controlnet)
+
+            # Direct calling of `save_pretrained` causes config.unet = (None, None).
+            # So config must be saved again, later.
+            model.save_pretrained(model_save_dir)
 
         # replace modules
         model.vae = vae
@@ -158,16 +203,23 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2ImgPipeline):
         }
         model.register_to_config(**update_dict)
 
-
+        if model_save_dir is not None:
+            # overwrite to replace incorrect config
+            model.save_config(model_save_dir)
 
+        # use for CI to access each compiled model
         if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
             model.compiled_models = [
                 vae.compiled_models[0],
                 vae.compiled_models[1],
                 text_encoder.compiled_models[0],
                 unet.compiled_models[0],
-                controlnet.compiled_models[0],
             ]
+            if isinstance(controlnet, RBLNMultiControlNetModel):
+                for c_model in controlnet.nets:
+                    model.compiled_models.append(c_model.compiled_models[0])
+            else:
+                model.compiled_models.append(controlnet.compiled_models[0])
 
         return model
 
```
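The list branch above is the new multi-ControlNet path: each `ControlNetModel` in the list is compiled separately into the `controlnet`, `controlnet_1`, ... subfolders and the results are wrapped in an `RBLNMultiControlNetModel`. A sketch of how that path might be exercised, assuming the top-level re-export and illustrative model ids:

```python
# Sketch of the multi-ControlNet path: a list of ControlNetModel instances is
# compiled net-by-net and wrapped in RBLNMultiControlNetModel (ids illustrative).
from diffusers import ControlNetModel
from optimum.rbln import RBLNStableDiffusionControlNetImg2ImgPipeline

controlnets = [
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny"),
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth"),
]
pipe = RBLNStableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnets,  # compiled into "controlnet", "controlnet_1", ...
    export=True,
    model_save_dir="rbln-sd15-multicontrolnet",
)
```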
```diff
--- a/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -22,17 +22,17 @@
 # from Rebellions Inc.
 """RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
 
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import StableDiffusionXLControlNetPipeline
+from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetPipeline
 from diffusers.image_processor import PipelineImageInput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
+from transformers import CLIPTextModel
 
 from ....modeling_base import RBLNBaseModel
 from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
@@ -63,103 +63,152 @@ class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipeline):
                 - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
         """
         export = kwargs.pop("export", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        controlnets = kwargs.pop("controlnet", None)
         vae = kwargs.pop("vae", None)
+        unet = kwargs.pop("unet", None)
+        text_encoder = kwargs.pop("text_encoder", None)
+        text_encoder_2 = kwargs.pop("text_encoder_2", None)
+        controlnet = kwargs.pop("controlnet", None)
+        model_save_dir = kwargs.pop("model_save_dir", None)
 
         rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
+
         kwargs_dict = {
             "pretrained_model_name_or_path": model_id,
-            "vae": vae,
-            "controlnet": controlnets,
-            "text_encoder": text_encoder,
             **kwargs,
         }
 
+        kwargs_dict.update(
+            {
+                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
+                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
+                **(
+                    {"text_encoder": text_encoder}
+                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
+                    else {}
+                ),
+                **(
+                    {"controlnet": controlnet}
+                    if controlnet is not None
+                    and (
+                        isinstance(controlnet, ControlNetModel)
+                        or all(isinstance(c, ControlNetModel) for c in controlnet)
+                    )
+                    else {}
+                ),
+            }
+        )
+
         model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
 
         if export is None or export is False:
             return model
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
-
-        model.save_pretrained(save_directory=save_dir_path, **kwargs)
-
         do_classifier_free_guidance = (
             rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
         )
 
-        vae
-
-
-
-
-
-
-
-
-
-
-
-            subfolder="text_encoder",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-        text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
-            model_id=model_id,
-            subfolder="text_encoder_2",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-
-        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+        if not isinstance(vae, RBLNAutoencoderKL):
+            vae = RBLNAutoencoderKL.from_pretrained(
+                model_id=model_id,
+                subfolder="vae",
+                export=True,
+                model_save_dir=model_save_dir,
+                rbln_unet_sample_size=model.unet.config.sample_size,
+                rbln_use_encode=False,
+                rbln_vae_scale_factor=model.vae_scale_factor,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
 
-
-
-
-
-
-
-
-
-
-            rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
+        if not isinstance(text_encoder, RBLNCLIPTextModel):
+            text_encoder = RBLNCLIPTextModel.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder",
+                export=True,
+                model_save_dir=model_save_dir,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
 
-        if isinstance(
-
-                model_id=
+        if not isinstance(text_encoder_2, RBLNCLIPTextModel):
+            text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder_2",
                 export=True,
-
-                rbln_vae_scale_factor=model.vae_scale_factor,
+                model_save_dir=model_save_dir,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-
-
-
-
+
+        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
+        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+
+        if not isinstance(unet, RBLNUNet2DConditionModel):
+            unet = RBLNUNet2DConditionModel.from_pretrained(
+                model_id=model_id,
+                subfolder="unet",
                 export=True,
-
+                model_save_dir=model_save_dir,
+                rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
                 rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                rbln_batch_size=unet_batch_size,
+                rbln_use_encode=False,
                 rbln_vae_scale_factor=model.vae_scale_factor,
+                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-        controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
 
+        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
+            if isinstance(controlnet, (list, tuple)):
+                multicontrolnet = []
+                for i, cid in enumerate(controlnet):
+                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
+                    multicontrolnet.append(
+                        RBLNControlNetModel.from_pretrained(
+                            model_id=cid.config._name_or_path,
+                            subfolder=subfolder_name,
+                            export=True,
+                            model_save_dir=model_save_dir,
+                            rbln_batch_size=unet_batch_size,
+                            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                            rbln_vae_scale_factor=model.vae_scale_factor,
+                            **rbln_config_kwargs,
+                            **rbln_constructor_kwargs,
+                        )
+                    )
+                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
+                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
+            else:
+                controlnet = RBLNControlNetModel.from_pretrained(
+                    model_id=controlnet.config._name_or_path,
+                    subfolder="controlnet",
+                    export=True,
+                    model_save_dir=model_save_dir,
+                    rbln_batch_size=unet_batch_size,
+                    rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                    rbln_vae_scale_factor=model.vae_scale_factor,
+                    **rbln_config_kwargs,
+                    **rbln_constructor_kwargs,
+                )
+                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
+
+        if model_save_dir is not None:
+            # To skip saving original pytorch modules
+            del (model.vae, model.text_encoder, model.unet, model.controlnet)
+
+            # Direct calling of `save_pretrained` causes config.unet = (None, None).
+            # So config must be saved again, later.
+            model.save_pretrained(model_save_dir)
+
+        # replace modules
         model.vae = vae
         model.text_encoder = text_encoder
         model.unet = unet
         model.text_encoder_2 = text_encoder_2
         model.controlnet = controlnet
 
+        # update config to be able to load from file
         update_dict = {
             "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
             "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
@@ -169,13 +218,23 @@ class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipeline):
         }
         model.register_to_config(**update_dict)
 
-
-
-
-
-
-
-
+        if model_save_dir is not None:
+            # overwrite to replace incorrect config
+            model.save_config(model_save_dir)
+
+        # use for CI to access each compiled model
+        if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
+            model.compiled_models = [
+                vae.compiled_models[0],
+                text_encoder.compiled_models[0],
+                text_encoder_2.compiled_models[0],
+                unet.compiled_models[0],
+            ]
+            if isinstance(controlnet, RBLNMultiControlNetModel):
+                for c_model in controlnet.nets:
+                    model.compiled_models.append(c_model.compiled_models[0])
+            else:
+                model.compiled_models.append(controlnet.compiled_models[0])
 
         return model
 
```
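Both pipelines now rewrite their config to point at the RBLN classes (`register_to_config(**update_dict)` plus the trailing `save_config(model_save_dir)`, which the added comment explains is there "to be able to load from file"). Under that reading, a directory produced via `model_save_dir` should be reloadable without re-exporting; a sketch, assuming the top-level re-export and an illustrative path:

```python
# Reload sketch (path illustrative): load the previously compiled SDXL
# ControlNet pipeline straight from model_save_dir, skipping compilation.
from optimum.rbln import RBLNStableDiffusionXLControlNetPipeline

pipe = RBLNStableDiffusionXLControlNetPipeline.from_pretrained(
    "rbln-sdxl-controlnet",  # a directory produced earlier via model_save_dir
    export=False,            # from_pretrained returns early without compiling
)
```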