optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +110 -0
- optimum/rbln/transformers/__init__.py +6 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
- optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -35,6 +35,9 @@ _import_structure = {
|
|
35
35
|
"RBLNResNetForImageClassification",
|
36
36
|
"RBLNT5ForConditionalGeneration",
|
37
37
|
"RBLNBartForConditionalGeneration",
|
38
|
+
"RBLNXLMRobertaForSequenceClassification",
|
39
|
+
"RBLNRobertaForSequenceClassification",
|
40
|
+
"RBLNRobertaForMaskedLM",
|
38
41
|
],
|
39
42
|
"modeling_base": [
|
40
43
|
"RBLNBaseModel",
|
@@ -42,6 +45,8 @@ _import_structure = {
|
|
42
45
|
"RBLNModelForQuestionAnswering",
|
43
46
|
"RBLNModelForAudioClassification",
|
44
47
|
"RBLNModelForImageClassification",
|
48
|
+
"RBLNModelForSequenceClassification",
|
49
|
+
"RBLNModelForMaskedLM",
|
45
50
|
],
|
46
51
|
"modeling_seq2seq": [
|
47
52
|
"RBLNModelForSeq2SeqLM",
|
@@ -51,11 +56,13 @@ _import_structure = {
|
|
51
56
|
"RBLNCLIPTextModel",
|
52
57
|
"RBLNCLIPTextModelWithProjection",
|
53
58
|
"RBLNDPTForDepthEstimation",
|
59
|
+
"RBLNGemmaForCausalLM",
|
54
60
|
"RBLNGPT2LMHeadModel",
|
55
61
|
"RBLNWav2Vec2ForCTC",
|
56
62
|
"RBLNLlamaForCausalLM",
|
57
63
|
"RBLNMidmLMHeadModel",
|
58
64
|
"RBLNWhisperForConditionalGeneration",
|
65
|
+
"RBLNXLMRobertaModel",
|
59
66
|
],
|
60
67
|
"diffusers": [
|
61
68
|
"RBLNStableDiffusionPipeline",
|
@@ -94,14 +101,19 @@ if TYPE_CHECKING:
|
|
94
101
|
RBLNBartForConditionalGeneration,
|
95
102
|
RBLNBertForQuestionAnswering,
|
96
103
|
RBLNResNetForImageClassification,
|
104
|
+
RBLNRobertaForMaskedLM,
|
105
|
+
RBLNRobertaForSequenceClassification,
|
97
106
|
RBLNT5ForConditionalGeneration,
|
107
|
+
RBLNXLMRobertaForSequenceClassification,
|
98
108
|
)
|
99
109
|
from .modeling_base import (
|
100
110
|
RBLNBaseModel,
|
101
111
|
RBLNModel,
|
102
112
|
RBLNModelForAudioClassification,
|
103
113
|
RBLNModelForImageClassification,
|
114
|
+
RBLNModelForMaskedLM,
|
104
115
|
RBLNModelForQuestionAnswering,
|
116
|
+
RBLNModelForSequenceClassification,
|
105
117
|
)
|
106
118
|
from .modeling_config import RBLNConfig, RBLNRuntimeConfig
|
107
119
|
from .modeling_seq2seq import RBLNModelForSeq2SeqLM
|
@@ -110,11 +122,13 @@ if TYPE_CHECKING:
|
|
110
122
|
RBLNCLIPTextModel,
|
111
123
|
RBLNCLIPTextModelWithProjection,
|
112
124
|
RBLNDPTForDepthEstimation,
|
125
|
+
RBLNGemmaForCausalLM,
|
113
126
|
RBLNGPT2LMHeadModel,
|
114
127
|
RBLNLlamaForCausalLM,
|
115
128
|
RBLNMidmLMHeadModel,
|
116
129
|
RBLNWav2Vec2ForCTC,
|
117
130
|
RBLNWhisperForConditionalGeneration,
|
131
|
+
RBLNXLMRobertaModel,
|
118
132
|
)
|
119
133
|
else:
|
120
134
|
import sys
|
optimum/rbln/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = '0.1.
|
1
|
+
__version__ = '0.1.8'
|
@@ -120,6 +120,9 @@ class RBLNControlNetModel(RBLNModel):
|
|
120
120
|
model_name_or_path: Union[str, Path],
|
121
121
|
**kwargs,
|
122
122
|
):
|
123
|
+
if "subfolder" in kwargs:
|
124
|
+
del kwargs["subfolder"]
|
125
|
+
|
123
126
|
return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
|
124
127
|
|
125
128
|
tasktmp = TasksManager.get_model_from_task
|
@@ -244,6 +244,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
244
244
|
for i in range(3)
|
245
245
|
]
|
246
246
|
)
|
247
|
+
if len(model_config.block_out_channels) > 1:
|
247
248
|
input_info.append(
|
248
249
|
(
|
249
250
|
"down_block_additional_residuals_3",
|
@@ -251,7 +252,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
251
252
|
"float32",
|
252
253
|
)
|
253
254
|
)
|
254
|
-
if len(model_config.block_out_channels) > 1:
|
255
255
|
input_info.extend(
|
256
256
|
[
|
257
257
|
(
|
@@ -262,6 +262,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
262
262
|
for i in range(4, 6)
|
263
263
|
]
|
264
264
|
)
|
265
|
+
if len(model_config.block_out_channels) > 2:
|
265
266
|
input_info.append(
|
266
267
|
(
|
267
268
|
f"down_block_additional_residuals_{6}",
|
@@ -269,7 +270,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
269
270
|
"float32",
|
270
271
|
)
|
271
272
|
)
|
272
|
-
if len(model_config.block_out_channels) > 2:
|
273
273
|
input_info.extend(
|
274
274
|
[
|
275
275
|
(
|
@@ -24,56 +24,32 @@
|
|
24
24
|
import logging
|
25
25
|
import os
|
26
26
|
from pathlib import Path
|
27
|
-
from tempfile import TemporaryDirectory
|
28
27
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
29
28
|
|
30
|
-
import rebel
|
31
29
|
import torch
|
32
30
|
from diffusers import ControlNetModel
|
33
31
|
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
34
32
|
from optimum.exporters import TasksManager
|
35
|
-
from transformers import AutoConfig, AutoModel
|
33
|
+
from transformers import AutoConfig, AutoModel
|
36
34
|
|
37
|
-
from ....modeling_base import
|
38
|
-
from ....modeling_config import
|
35
|
+
from ....modeling_base import RBLNModel
|
36
|
+
from ....modeling_config import RBLNConfig
|
39
37
|
from ...models.controlnet import RBLNControlNetModel
|
40
38
|
|
41
39
|
|
42
40
|
logger = logging.getLogger(__name__)
|
43
41
|
|
44
42
|
if TYPE_CHECKING:
|
45
|
-
|
46
|
-
PretrainedConfig,
|
47
|
-
PreTrainedModel,
|
48
|
-
)
|
43
|
+
pass
|
49
44
|
|
50
45
|
|
51
|
-
class RBLNMultiControlNetModel(
|
52
|
-
model_type = "rbln_model"
|
53
|
-
auto_model_class = AutoModel
|
54
|
-
|
46
|
+
class RBLNMultiControlNetModel(RBLNModel):
|
55
47
|
def __init__(
|
56
48
|
self,
|
57
|
-
models: List[
|
58
|
-
config: PretrainedConfig = None,
|
59
|
-
preprocessors: Optional[List] = None,
|
60
|
-
rbln_config: Optional[RBLNConfig] = None,
|
49
|
+
models: List[RBLNControlNetModel],
|
61
50
|
**kwargs,
|
62
51
|
):
|
63
|
-
|
64
|
-
models,
|
65
|
-
config,
|
66
|
-
preprocessors,
|
67
|
-
rbln_config,
|
68
|
-
**kwargs,
|
69
|
-
)
|
70
|
-
|
71
|
-
if not isinstance(config, PretrainedConfig):
|
72
|
-
config = PretrainedConfig(**config)
|
73
|
-
|
74
|
-
for i in range(len(models)):
|
75
|
-
self.runtimes[i].config = config
|
76
|
-
self.nets = self.runtimes
|
52
|
+
self.nets = models
|
77
53
|
self.dtype = torch.float32
|
78
54
|
|
79
55
|
@classmethod
|
@@ -83,7 +59,7 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
|
|
83
59
|
model_name_or_path: Union[str, Path],
|
84
60
|
**kwargs,
|
85
61
|
):
|
86
|
-
return MultiControlNetModel.from_pretrained(
|
62
|
+
return MultiControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
|
87
63
|
|
88
64
|
tasktmp = TasksManager.get_model_from_task
|
89
65
|
configtmp = AutoConfig.from_pretrained
|
@@ -101,129 +77,31 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
|
|
101
77
|
def _from_pretrained(
|
102
78
|
cls,
|
103
79
|
model_id: Union[str, Path],
|
104
|
-
config: "PretrainedConfig",
|
105
|
-
use_auth_token: Optional[Union[bool, str]] = None,
|
106
|
-
revision: Optional[str] = None,
|
107
|
-
force_download: bool = False,
|
108
|
-
cache_dir: Optional[str] = None,
|
109
|
-
file_name: Optional[str] = None,
|
110
|
-
subfolder: str = "",
|
111
|
-
local_files_only: bool = False,
|
112
80
|
**kwargs,
|
113
|
-
) ->
|
114
|
-
if isinstance(model_id, str):
|
115
|
-
model_path = Path(model_id)
|
116
|
-
else:
|
117
|
-
model_path = model_id / "controlnet"
|
81
|
+
) -> RBLNModel:
|
118
82
|
|
119
|
-
rbln_files = []
|
120
|
-
rbln_config_filenames = []
|
121
83
|
idx = 0
|
122
|
-
|
84
|
+
controlnets = []
|
85
|
+
model_path_to_load = model_id
|
123
86
|
|
124
|
-
while
|
125
|
-
|
126
|
-
|
87
|
+
while os.path.isdir(model_path_to_load):
|
88
|
+
controlnet = RBLNControlNetModel.from_pretrained(model_path_to_load, export=False, **kwargs)
|
89
|
+
controlnets.append(controlnet)
|
90
|
+
rbln_config = RBLNConfig.load(model_path_to_load)
|
127
91
|
idx += 1
|
128
|
-
|
129
|
-
|
130
|
-
if len(rbln_files) == 0:
|
131
|
-
raise FileNotFoundError(f"Could not find any rbln model file in {model_path}")
|
132
|
-
|
133
|
-
if len(rbln_config_filenames) == 0:
|
134
|
-
raise FileNotFoundError(f"Could not find `rbln_config.json` file in {model_path}")
|
135
|
-
|
136
|
-
models = []
|
137
|
-
for rconf, rfiles in zip(rbln_config_filenames, rbln_files):
|
138
|
-
rbln_config = RBLNConfig.load(str(rconf))
|
139
|
-
models.append(rebel.RBLNCompiledModel(rfiles))
|
140
|
-
|
141
|
-
preprocessors = []
|
92
|
+
model_path_to_load = model_id + f"_{idx}"
|
142
93
|
|
143
94
|
return cls(
|
144
|
-
|
145
|
-
config,
|
146
|
-
preprocessors,
|
95
|
+
controlnets,
|
147
96
|
rbln_config=rbln_config,
|
148
97
|
**kwargs,
|
149
98
|
)
|
150
99
|
|
151
|
-
def
|
152
|
-
|
153
|
-
idx = 0
|
154
|
-
real_save_dir_path = save_directory
|
155
|
-
for compiled_model in self.compiled_models:
|
156
|
-
dst_path = Path(real_save_dir_path) / "compiled_model.rbln"
|
157
|
-
if not os.path.exists(real_save_dir_path):
|
158
|
-
os.makedirs(real_save_dir_path)
|
159
|
-
compiled_model.save(dst_path)
|
160
|
-
self.rbln_config.save(real_save_dir_path)
|
161
|
-
idx += 1
|
162
|
-
real_save_dir_path = save_directory + f"_{idx}"
|
163
|
-
|
164
|
-
@classmethod
|
165
|
-
@torch.no_grad()
|
166
|
-
def _export(
|
167
|
-
cls,
|
168
|
-
model_id: str,
|
169
|
-
config: "PretrainedConfig",
|
170
|
-
use_auth_token: Optional[Union[bool, str]] = None,
|
171
|
-
revision: Optional[str] = None,
|
172
|
-
force_download: bool = False,
|
173
|
-
cache_dir: Optional[str] = None,
|
174
|
-
subfolder: str = "",
|
175
|
-
local_files_only: bool = False,
|
176
|
-
trust_remote_code: bool = False,
|
177
|
-
**kwargs,
|
178
|
-
) -> "RBLNMultiControlNetModel":
|
179
|
-
task = kwargs.pop("task", None)
|
180
|
-
if task is None:
|
181
|
-
task = TasksManager.infer_task_from_model(cls.auto_model_class)
|
182
|
-
|
183
|
-
save_dir = TemporaryDirectory()
|
184
|
-
save_dir_path = Path(save_dir.name)
|
185
|
-
|
186
|
-
rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
|
187
|
-
img_width = rbln_config_kwargs.pop("rbln_img_width", None)
|
188
|
-
img_height = rbln_config_kwargs.pop("rbln_img_height", None)
|
189
|
-
vae_scale_factor = rbln_config_kwargs.pop("rbln_vae_scale_factor", None)
|
190
|
-
batch_size = rbln_config_kwargs.pop("rbln_batch_size", None)
|
191
|
-
|
192
|
-
model: MultiControlNetModel = TasksManager.get_model_from_task(
|
193
|
-
task=task,
|
194
|
-
model_name_or_path=model_id,
|
195
|
-
)
|
196
|
-
|
197
|
-
model_path_to_load = model_id
|
198
|
-
real_save_dir_path = save_dir_path / "controlnet"
|
199
|
-
|
200
|
-
for idx in range(len(model.nets)):
|
100
|
+
def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
|
101
|
+
for idx, model in enumerate(self.nets):
|
201
102
|
suffix = "" if idx == 0 else f"_{idx}"
|
202
|
-
|
203
|
-
|
204
|
-
export=True,
|
205
|
-
rbln_batch_size=batch_size,
|
206
|
-
rbln_img_width=img_width,
|
207
|
-
rbln_img_height=img_height,
|
208
|
-
rbln_vae_scale_factor=vae_scale_factor,
|
209
|
-
)
|
210
|
-
controlnet.save_pretrained(real_save_dir_path)
|
211
|
-
real_save_dir_path = save_dir_path / f"controlnet_{idx+1}"
|
212
|
-
|
213
|
-
return cls._from_pretrained(
|
214
|
-
model_id=save_dir_path,
|
215
|
-
config=config,
|
216
|
-
model_save_dir=save_dir,
|
217
|
-
**rbln_constructor_kwargs,
|
218
|
-
**kwargs,
|
219
|
-
)
|
220
|
-
|
221
|
-
@classmethod
|
222
|
-
def _create_runtimes(
|
223
|
-
cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
|
224
|
-
) -> List[rebel.Runtime]:
|
225
|
-
device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
|
226
|
-
return [compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in compiled_models]
|
103
|
+
real_save_path = save_directory + suffix
|
104
|
+
model.save_pretrained(real_save_path)
|
227
105
|
|
228
106
|
def forward(
|
229
107
|
self,
|
@@ -241,7 +119,7 @@ class RBLNMultiControlNetModel(RBLNBaseModel):
|
|
241
119
|
return_dict: bool = True,
|
242
120
|
):
|
243
121
|
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
244
|
-
output = controlnet(
|
122
|
+
output = controlnet.model[0](
|
245
123
|
sample=sample.contiguous(),
|
246
124
|
timestep=timestep.float(),
|
247
125
|
encoder_hidden_states=encoder_hidden_states,
|
@@ -22,18 +22,18 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
"""RBLNStableDiffusionPipeline class for inference of diffusion models on rbln devices."""
|
24
24
|
|
25
|
-
from pathlib import Path
|
26
|
-
from tempfile import TemporaryDirectory
|
27
25
|
from typing import Any, Callable, Dict, List, Optional, Union
|
28
26
|
|
29
27
|
import torch
|
30
28
|
import torch.nn.functional as F
|
31
|
-
from diffusers import StableDiffusionControlNetPipeline
|
29
|
+
from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline
|
32
30
|
from diffusers.image_processor import PipelineImageInput
|
31
|
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
33
32
|
from diffusers.pipelines.controlnet.pipeline_controlnet import retrieve_timesteps
|
34
33
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
35
34
|
from diffusers.utils import deprecate, logging
|
36
35
|
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
|
36
|
+
from transformers import CLIPTextModel
|
37
37
|
|
38
38
|
from ....modeling_base import RBLNBaseModel
|
39
39
|
from ....transformers import RBLNCLIPTextModel
|
@@ -64,18 +64,40 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
|
|
64
64
|
- A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
|
65
65
|
"""
|
66
66
|
export = kwargs.pop("export", None)
|
67
|
+
vae = kwargs.pop("vae", None)
|
68
|
+
unet = kwargs.pop("unet", None)
|
67
69
|
text_encoder = kwargs.pop("text_encoder", None)
|
68
|
-
|
70
|
+
controlnet = kwargs.pop("controlnet", None)
|
71
|
+
model_save_dir = kwargs.pop("model_save_dir", None)
|
69
72
|
|
70
73
|
rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
|
71
74
|
|
72
75
|
kwargs_dict = {
|
73
76
|
"pretrained_model_name_or_path": model_id,
|
74
|
-
"text_encoder": text_encoder,
|
75
|
-
"controlnet": controlnets,
|
76
77
|
**kwargs,
|
77
78
|
}
|
78
79
|
|
80
|
+
kwargs_dict.update(
|
81
|
+
{
|
82
|
+
**({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
|
83
|
+
**({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
|
84
|
+
**(
|
85
|
+
{"text_encoder": text_encoder}
|
86
|
+
if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
|
87
|
+
else {}
|
88
|
+
),
|
89
|
+
**(
|
90
|
+
{"controlnet": controlnet}
|
91
|
+
if controlnet is not None
|
92
|
+
and (
|
93
|
+
isinstance(controlnet, ControlNetModel)
|
94
|
+
or all(isinstance(c, ControlNetModel) for c in controlnet)
|
95
|
+
)
|
96
|
+
else {}
|
97
|
+
),
|
98
|
+
}
|
99
|
+
)
|
100
|
+
|
79
101
|
model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
|
80
102
|
|
81
103
|
if export is None or export is False:
|
@@ -85,64 +107,87 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
|
|
85
107
|
rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
|
86
108
|
)
|
87
109
|
|
88
|
-
save_dir = TemporaryDirectory()
|
89
|
-
save_dir_path = Path(save_dir.name)
|
90
|
-
|
91
|
-
model.save_pretrained(save_directory=save_dir_path, **kwargs)
|
92
|
-
|
93
110
|
# compile model, create runtime
|
94
|
-
vae
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
rbln_use_encode=False,
|
99
|
-
rbln_vae_scale_factor=model.vae_scale_factor,
|
100
|
-
**rbln_config_kwargs,
|
101
|
-
**rbln_constructor_kwargs,
|
102
|
-
)
|
103
|
-
|
104
|
-
text_encoder = RBLNCLIPTextModel.from_pretrained(
|
105
|
-
model_id=save_dir_path / "text_encoder",
|
106
|
-
export=True,
|
107
|
-
**rbln_config_kwargs,
|
108
|
-
**rbln_constructor_kwargs,
|
109
|
-
)
|
110
|
-
|
111
|
-
batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
|
112
|
-
unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
|
113
|
-
|
114
|
-
unet = RBLNUNet2DConditionModel.from_pretrained(
|
115
|
-
model_id=save_dir_path / "unet",
|
116
|
-
export=True,
|
117
|
-
rbln_max_seq_len=text_encoder.config.max_position_embeddings,
|
118
|
-
rbln_batch_size=unet_batch_size,
|
119
|
-
rbln_use_encode=False,
|
120
|
-
rbln_vae_scale_factor=model.vae_scale_factor,
|
121
|
-
rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
|
122
|
-
**rbln_config_kwargs,
|
123
|
-
**rbln_constructor_kwargs,
|
124
|
-
)
|
125
|
-
|
126
|
-
if isinstance(controlnets, (list, tuple)):
|
127
|
-
controlnet = RBLNMultiControlNetModel.from_pretrained(
|
128
|
-
model_id=str(save_dir_path / "controlnet"),
|
111
|
+
if not isinstance(vae, RBLNAutoencoderKL):
|
112
|
+
vae = RBLNAutoencoderKL.from_pretrained(
|
113
|
+
model_id=model_id,
|
114
|
+
subfolder="vae",
|
129
115
|
export=True,
|
130
|
-
|
116
|
+
model_save_dir=model_save_dir,
|
117
|
+
rbln_unet_sample_size=model.unet.config.sample_size,
|
118
|
+
rbln_use_encode=False,
|
131
119
|
rbln_vae_scale_factor=model.vae_scale_factor,
|
132
120
|
**rbln_config_kwargs,
|
133
121
|
**rbln_constructor_kwargs,
|
134
122
|
)
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
model_id=
|
123
|
+
|
124
|
+
if not isinstance(text_encoder, RBLNCLIPTextModel):
|
125
|
+
text_encoder = RBLNCLIPTextModel.from_pretrained(
|
126
|
+
model_id=model_id,
|
127
|
+
subfolder="text_encoder",
|
139
128
|
export=True,
|
129
|
+
model_save_dir=model_save_dir,
|
130
|
+
**rbln_config_kwargs,
|
131
|
+
**rbln_constructor_kwargs,
|
132
|
+
)
|
133
|
+
|
134
|
+
batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
|
135
|
+
unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
|
136
|
+
|
137
|
+
if not isinstance(unet, RBLNUNet2DConditionModel):
|
138
|
+
unet = RBLNUNet2DConditionModel.from_pretrained(
|
139
|
+
model_id=model_id,
|
140
|
+
subfolder="unet",
|
141
|
+
export=True,
|
142
|
+
model_save_dir=model_save_dir,
|
143
|
+
rbln_max_seq_len=text_encoder.config.max_position_embeddings,
|
140
144
|
rbln_batch_size=unet_batch_size,
|
145
|
+
rbln_use_encode=False,
|
141
146
|
rbln_vae_scale_factor=model.vae_scale_factor,
|
147
|
+
rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
|
142
148
|
**rbln_config_kwargs,
|
143
149
|
**rbln_constructor_kwargs,
|
144
150
|
)
|
145
|
-
|
151
|
+
|
152
|
+
if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
|
153
|
+
if isinstance(controlnet, (list, tuple)):
|
154
|
+
multicontrolnet = []
|
155
|
+
for i, cid in enumerate(controlnet):
|
156
|
+
subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
|
157
|
+
multicontrolnet.append(
|
158
|
+
RBLNControlNetModel.from_pretrained(
|
159
|
+
model_id=cid.config._name_or_path,
|
160
|
+
subfolder=subfolder_name,
|
161
|
+
export=True,
|
162
|
+
model_save_dir=model_save_dir,
|
163
|
+
rbln_batch_size=unet_batch_size,
|
164
|
+
rbln_vae_scale_factor=model.vae_scale_factor,
|
165
|
+
**rbln_config_kwargs,
|
166
|
+
**rbln_constructor_kwargs,
|
167
|
+
)
|
168
|
+
)
|
169
|
+
controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
|
170
|
+
controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
|
171
|
+
else:
|
172
|
+
controlnet = RBLNControlNetModel.from_pretrained(
|
173
|
+
model_id=controlnet.config._name_or_path,
|
174
|
+
subfolder="controlnet",
|
175
|
+
export=True,
|
176
|
+
model_save_dir=model_save_dir,
|
177
|
+
rbln_batch_size=unet_batch_size,
|
178
|
+
rbln_vae_scale_factor=model.vae_scale_factor,
|
179
|
+
**rbln_config_kwargs,
|
180
|
+
**rbln_constructor_kwargs,
|
181
|
+
)
|
182
|
+
controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
|
183
|
+
|
184
|
+
if model_save_dir is not None:
|
185
|
+
# To skip saving original pytorch modules
|
186
|
+
del (model.vae, model.text_encoder, model.unet, model.controlnet)
|
187
|
+
|
188
|
+
# Direct calling of `save_pretrained` causes config.unet = (None, None).
|
189
|
+
# So config must be saved again, later.
|
190
|
+
model.save_pretrained(model_save_dir)
|
146
191
|
|
147
192
|
# replace modules
|
148
193
|
model.vae = vae
|
@@ -159,15 +204,18 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
|
|
159
204
|
}
|
160
205
|
model.register_to_config(**update_dict)
|
161
206
|
|
162
|
-
|
207
|
+
if model_save_dir is not None:
|
208
|
+
# overwrite to replace incorrect config
|
209
|
+
model.save_config(model_save_dir)
|
163
210
|
|
211
|
+
# use for CI to access each compiled model
|
164
212
|
if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
|
165
|
-
model.compiled_models = [
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
213
|
+
model.compiled_models = [vae.compiled_models[0], text_encoder.compiled_models[0], unet.compiled_models[0]]
|
214
|
+
if isinstance(controlnet, RBLNMultiControlNetModel):
|
215
|
+
for c_model in controlnet.nets:
|
216
|
+
model.compiled_models.append(c_model.compiled_models[0])
|
217
|
+
else:
|
218
|
+
model.compiled_models.append(controlnet.compiled_models[0])
|
171
219
|
|
172
220
|
return model
|
173
221
|
|