optimum-rbln 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +6 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -9
- optimum/rbln/diffusers/models/controlnet.py +93 -23
- optimum/rbln/diffusers/models/unet_2d_condition.py +78 -61
- optimum/rbln/diffusers/pipelines/__init__.py +7 -2
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +768 -0
- optimum/rbln/diffusers/pipelines/{stable_diffusion → controlnet}/pipeline_controlnet_img2img.py +25 -16
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +942 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +955 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -9
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +19 -3
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +19 -3
- optimum/rbln/modeling_base.py +36 -3
- optimum/rbln/modeling_seq2seq.py +19 -4
- optimum/rbln/transformers/generation/__init__.py +1 -0
- optimum/rbln/transformers/generation/streamers.py +17 -0
- optimum/rbln/transformers/generation/utils.py +399 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +24 -333
- optimum/rbln/transformers/models/llama/modeling_llama.py +63 -45
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +13 -3
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/RECORD +29 -25
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -63,6 +63,9 @@ _import_structure = {
|
|
63
63
|
"RBLNStableDiffusionControlNetImg2ImgPipeline",
|
64
64
|
"RBLNMultiControlNetModel",
|
65
65
|
"RBLNStableDiffusionXLImg2ImgPipeline",
|
66
|
+
"RBLNStableDiffusionControlNetPipeline",
|
67
|
+
"RBLNStableDiffusionXLControlNetPipeline",
|
68
|
+
"RBLNStableDiffusionXLControlNetImg2ImgPipeline",
|
66
69
|
],
|
67
70
|
"modeling_config": ["RBLNRuntimeConfig", "RBLNConfig"],
|
68
71
|
}
|
@@ -73,8 +76,11 @@ if TYPE_CHECKING:
|
|
73
76
|
RBLNControlNetModel,
|
74
77
|
RBLNMultiControlNetModel,
|
75
78
|
RBLNStableDiffusionControlNetImg2ImgPipeline,
|
79
|
+
RBLNStableDiffusionControlNetPipeline,
|
76
80
|
RBLNStableDiffusionImg2ImgPipeline,
|
77
81
|
RBLNStableDiffusionPipeline,
|
82
|
+
RBLNStableDiffusionXLControlNetImg2ImgPipeline,
|
83
|
+
RBLNStableDiffusionXLControlNetPipeline,
|
78
84
|
RBLNStableDiffusionXLImg2ImgPipeline,
|
79
85
|
RBLNStableDiffusionXLPipeline,
|
80
86
|
RBLNUNet2DConditionModel,
|
optimum/rbln/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = '0.1.
|
1
|
+
__version__ = '0.1.1'
|
@@ -39,17 +39,24 @@ _import_structure = {
|
|
39
39
|
"RBLNStableDiffusionControlNetImg2ImgPipeline",
|
40
40
|
"RBLNMultiControlNetModel",
|
41
41
|
"RBLNStableDiffusionXLImg2ImgPipeline",
|
42
|
+
"RBLNStableDiffusionControlNetPipeline",
|
43
|
+
"RBLNStableDiffusionXLControlNetPipeline",
|
44
|
+
"RBLNStableDiffusionXLControlNetImg2ImgPipeline",
|
42
45
|
],
|
43
46
|
"models": ["RBLNAutoencoderKL", "RBLNUNet2DConditionModel", "RBLNControlNetModel"],
|
44
47
|
}
|
45
48
|
|
46
49
|
if TYPE_CHECKING:
|
50
|
+
|
47
51
|
from .models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
|
48
52
|
from .pipelines import (
|
49
53
|
RBLNMultiControlNetModel,
|
50
54
|
RBLNStableDiffusionControlNetImg2ImgPipeline,
|
55
|
+
RBLNStableDiffusionControlNetPipeline,
|
51
56
|
RBLNStableDiffusionImg2ImgPipeline,
|
52
57
|
RBLNStableDiffusionPipeline,
|
58
|
+
RBLNStableDiffusionXLControlNetImg2ImgPipeline,
|
59
|
+
RBLNStableDiffusionXLControlNetPipeline,
|
53
60
|
RBLNStableDiffusionXLImg2ImgPipeline,
|
54
61
|
RBLNStableDiffusionXLPipeline,
|
55
62
|
)
|
@@ -88,14 +88,23 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
88
88
|
subfolder: str = "",
|
89
89
|
local_files_only: bool = False,
|
90
90
|
trust_remote_code: bool = False,
|
91
|
+
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
|
91
92
|
**kwargs,
|
92
93
|
) -> "RBLNAutoencoderKL":
|
93
94
|
task = kwargs.pop("task", None)
|
94
95
|
if task is None:
|
95
96
|
task = TasksManager.infer_task_from_model(cls.auto_model_class)
|
96
97
|
|
97
|
-
|
98
|
-
|
98
|
+
if model_save_dir is None:
|
99
|
+
save_dir = TemporaryDirectory()
|
100
|
+
save_dir_path = Path(save_dir.name)
|
101
|
+
else:
|
102
|
+
save_dir = model_save_dir
|
103
|
+
if isinstance(save_dir, TemporaryDirectory):
|
104
|
+
save_dir_path = Path(model_save_dir.name)
|
105
|
+
else:
|
106
|
+
save_dir_path = Path(model_save_dir)
|
107
|
+
save_dir_path.mkdir(exist_ok=True)
|
99
108
|
|
100
109
|
rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
|
101
110
|
|
@@ -119,7 +128,7 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
119
128
|
if not isinstance(config, PretrainedConfig): # diffusers config
|
120
129
|
config = PretrainedConfig(**config)
|
121
130
|
|
122
|
-
config.save_pretrained(save_dir_path)
|
131
|
+
config.save_pretrained(save_dir_path / subfolder)
|
123
132
|
preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
|
124
133
|
|
125
134
|
# Get compilation arguments
|
@@ -137,8 +146,12 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
137
146
|
enc_compiled_model = cls.compile(encoder_model, rbln_runtime_config=rbln_config["encoder"][0])
|
138
147
|
dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["decoder"][0])
|
139
148
|
|
140
|
-
enc_compiled_model.save(
|
141
|
-
|
149
|
+
enc_compiled_model.save(
|
150
|
+
save_dir_path / subfolder / f"{rbln_config['encoder'][0].compiled_model_name}.rbln"
|
151
|
+
)
|
152
|
+
dec_compiled_model.save(
|
153
|
+
save_dir_path / subfolder / f"{rbln_config['decoder'][0].compiled_model_name}.rbln"
|
154
|
+
)
|
142
155
|
|
143
156
|
def compile_text2img():
|
144
157
|
decoder_model = _VAEDecoder(model)
|
@@ -146,19 +159,27 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
146
159
|
|
147
160
|
dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["compiled_model"][0])
|
148
161
|
|
149
|
-
dec_compiled_model.save(
|
162
|
+
dec_compiled_model.save(
|
163
|
+
save_dir_path / subfolder / f"{rbln_config['compiled_model'][0].compiled_model_name}.rbln"
|
164
|
+
)
|
150
165
|
|
151
166
|
if rbln_config_kwargs.get("rbln_use_encode"):
|
152
167
|
compile_img2img()
|
153
168
|
else:
|
154
169
|
compile_text2img()
|
155
170
|
|
156
|
-
rbln_config.save(save_dir_path)
|
171
|
+
rbln_config.save(save_dir_path / subfolder)
|
157
172
|
|
158
173
|
return cls._from_pretrained(
|
159
174
|
model_id=save_dir_path,
|
160
175
|
config=config,
|
161
176
|
model_save_dir=save_dir,
|
177
|
+
use_auth_token=use_auth_token,
|
178
|
+
revision=revision,
|
179
|
+
force_download=force_download,
|
180
|
+
cache_dir=cache_dir,
|
181
|
+
subfolder=subfolder,
|
182
|
+
local_files_only=local_files_only,
|
162
183
|
**rbln_constructor_kwargs,
|
163
184
|
**kwargs,
|
164
185
|
)
|
@@ -216,7 +237,7 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
216
237
|
meta["rbln_img_height"] = rbln_img_height
|
217
238
|
|
218
239
|
vae_enc_input_info = [
|
219
|
-
("x", [rbln_batch_size, model_config.in_channels,
|
240
|
+
("x", [rbln_batch_size, model_config.in_channels, rbln_img_height, rbln_img_width], "float32")
|
220
241
|
]
|
221
242
|
vae_dec_input_info = [
|
222
243
|
(
|
@@ -224,8 +245,8 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
224
245
|
[
|
225
246
|
rbln_batch_size,
|
226
247
|
model_config.latent_channels,
|
227
|
-
rbln_img_width // rbln_vae_scale_factor,
|
228
248
|
rbln_img_height // rbln_vae_scale_factor,
|
249
|
+
rbln_img_width // rbln_vae_scale_factor,
|
229
250
|
],
|
230
251
|
"float32",
|
231
252
|
)
|
@@ -23,9 +23,8 @@
|
|
23
23
|
|
24
24
|
import logging
|
25
25
|
from pathlib import Path
|
26
|
-
from typing import TYPE_CHECKING, Optional, Union
|
26
|
+
from typing import TYPE_CHECKING, Dict, Optional, Union
|
27
27
|
|
28
|
-
import rebel
|
29
28
|
import torch
|
30
29
|
from diffusers import ControlNetModel
|
31
30
|
from optimum.exporters import TasksManager
|
@@ -46,6 +45,37 @@ class _ControlNetModel(torch.nn.Module):
|
|
46
45
|
super().__init__()
|
47
46
|
self.controlnet = controlnet
|
48
47
|
|
48
|
+
def forward(
|
49
|
+
self,
|
50
|
+
sample: torch.Tensor,
|
51
|
+
timestep: torch.Tensor,
|
52
|
+
controlnet_cond: torch.Tensor,
|
53
|
+
conditioning_scale,
|
54
|
+
text_embeds: Optional[torch.Tensor] = None,
|
55
|
+
time_ids: Optional[torch.Tensor] = None,
|
56
|
+
):
|
57
|
+
if text_embeds is not None and time_ids is not None:
|
58
|
+
added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
|
59
|
+
else:
|
60
|
+
added_cond_kwargs = {}
|
61
|
+
|
62
|
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
63
|
+
sample=sample,
|
64
|
+
timestep=timestep,
|
65
|
+
encoder_hidden_states=None,
|
66
|
+
controlnet_cond=controlnet_cond,
|
67
|
+
conditioning_scale=conditioning_scale,
|
68
|
+
added_cond_kwargs=added_cond_kwargs,
|
69
|
+
return_dict=False,
|
70
|
+
)
|
71
|
+
return down_block_res_samples, mid_block_res_sample
|
72
|
+
|
73
|
+
|
74
|
+
class _ControlNetModel_Cross_Attention(torch.nn.Module):
|
75
|
+
def __init__(self, controlnet: "ControlNetModel"):
|
76
|
+
super().__init__()
|
77
|
+
self.controlnet = controlnet
|
78
|
+
|
49
79
|
def forward(
|
50
80
|
self,
|
51
81
|
sample: torch.Tensor,
|
@@ -53,13 +83,21 @@ class _ControlNetModel(torch.nn.Module):
|
|
53
83
|
encoder_hidden_states: torch.Tensor,
|
54
84
|
controlnet_cond: torch.Tensor,
|
55
85
|
conditioning_scale,
|
86
|
+
text_embeds: Optional[torch.Tensor] = None,
|
87
|
+
time_ids: Optional[torch.Tensor] = None,
|
56
88
|
):
|
89
|
+
if text_embeds is not None and time_ids is not None:
|
90
|
+
added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
|
91
|
+
else:
|
92
|
+
added_cond_kwargs = {}
|
93
|
+
|
57
94
|
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
58
95
|
sample=sample,
|
59
96
|
timestep=timestep,
|
60
97
|
encoder_hidden_states=encoder_hidden_states,
|
61
98
|
controlnet_cond=controlnet_cond,
|
62
99
|
conditioning_scale=conditioning_scale,
|
100
|
+
added_cond_kwargs=added_cond_kwargs,
|
63
101
|
return_dict=False,
|
64
102
|
)
|
65
103
|
return down_block_res_samples, mid_block_res_sample
|
@@ -71,6 +109,9 @@ class RBLNControlNetModel(RBLNModel):
|
|
71
109
|
|
72
110
|
def __post_init__(self, **kwargs):
|
73
111
|
self.dtype = torch.float32
|
112
|
+
self.use_encoder_hidden_states = any(
|
113
|
+
item[0] == "encoder_hidden_states" for item in self.rbln_config["compiled_model"][0].input_info
|
114
|
+
)
|
74
115
|
|
75
116
|
@classmethod
|
76
117
|
def from_pretrained(cls, *args, **kwargs):
|
@@ -94,14 +135,16 @@ class RBLNControlNetModel(RBLNModel):
|
|
94
135
|
return rt
|
95
136
|
|
96
137
|
@classmethod
|
97
|
-
def
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
138
|
+
def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
|
139
|
+
use_encoder_hidden_states = False
|
140
|
+
for down_block in model.down_blocks:
|
141
|
+
if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
|
142
|
+
break
|
143
|
+
|
144
|
+
if use_encoder_hidden_states:
|
145
|
+
return _ControlNetModel_Cross_Attention(model).eval()
|
146
|
+
else:
|
147
|
+
return _ControlNetModel(model).eval()
|
105
148
|
|
106
149
|
@classmethod
|
107
150
|
def _get_rbln_config(
|
@@ -109,6 +152,7 @@ class RBLNControlNetModel(RBLNModel):
|
|
109
152
|
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
110
153
|
model_config: "PretrainedConfig",
|
111
154
|
rbln_max_seq_len: Optional[int] = None,
|
155
|
+
rbln_text_model_hidden_size: Optional[int] = None,
|
112
156
|
rbln_batch_size: Optional[int] = None,
|
113
157
|
rbln_img_width: Optional[int] = None,
|
114
158
|
rbln_img_height: Optional[int] = None,
|
@@ -132,12 +176,18 @@ class RBLNControlNetModel(RBLNModel):
|
|
132
176
|
[
|
133
177
|
rbln_batch_size,
|
134
178
|
model_config.in_channels,
|
135
|
-
input_width,
|
136
179
|
input_height,
|
180
|
+
input_width,
|
137
181
|
],
|
138
182
|
"float32",
|
139
183
|
),
|
140
184
|
("timestep", [], "float32"),
|
185
|
+
],
|
186
|
+
batch_size=rbln_batch_size,
|
187
|
+
)
|
188
|
+
use_encoder_hidden_states = any(element != "DownBlock2D" for element in model_config.down_block_types)
|
189
|
+
if use_encoder_hidden_states:
|
190
|
+
rbln_runtime_config.input_info.append(
|
141
191
|
(
|
142
192
|
"encoder_hidden_states",
|
143
193
|
[
|
@@ -146,12 +196,20 @@ class RBLNControlNetModel(RBLNModel):
|
|
146
196
|
model_config.cross_attention_dim,
|
147
197
|
],
|
148
198
|
"float32",
|
149
|
-
)
|
150
|
-
|
151
|
-
|
152
|
-
],
|
153
|
-
batch_size=rbln_batch_size,
|
199
|
+
)
|
200
|
+
)
|
201
|
+
rbln_runtime_config.input_info.append(
|
202
|
+
("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32")
|
154
203
|
)
|
204
|
+
rbln_runtime_config.input_info.append(("conditioning_scale", [], "float32"))
|
205
|
+
if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
|
206
|
+
if rbln_text_model_hidden_size is None:
|
207
|
+
rbln_text_model_hidden_size = 768
|
208
|
+
rbln_runtime_config.input_info.append(
|
209
|
+
("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
|
210
|
+
)
|
211
|
+
rbln_runtime_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
|
212
|
+
|
155
213
|
rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
|
156
214
|
return rbln_config
|
157
215
|
|
@@ -162,18 +220,30 @@ class RBLNControlNetModel(RBLNModel):
|
|
162
220
|
encoder_hidden_states: torch.Tensor,
|
163
221
|
controlnet_cond: torch.FloatTensor,
|
164
222
|
conditioning_scale: torch.Tensor = 1.0,
|
223
|
+
added_cond_kwargs: Dict[str, torch.Tensor] = {},
|
165
224
|
**kwargs,
|
166
225
|
):
|
167
226
|
"""
|
168
227
|
The [`ControlNetModel`] forward method.
|
169
228
|
"""
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
229
|
+
added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
|
230
|
+
if self.use_encoder_hidden_states:
|
231
|
+
output = super().forward(
|
232
|
+
sample.contiguous(),
|
233
|
+
timestep.float(),
|
234
|
+
encoder_hidden_states,
|
235
|
+
controlnet_cond,
|
236
|
+
torch.tensor(conditioning_scale),
|
237
|
+
**added_cond_kwargs,
|
238
|
+
)
|
239
|
+
else:
|
240
|
+
output = super().forward(
|
241
|
+
sample.contiguous(),
|
242
|
+
timestep.float(),
|
243
|
+
controlnet_cond,
|
244
|
+
torch.tensor(conditioning_scale),
|
245
|
+
**added_cond_kwargs,
|
246
|
+
)
|
177
247
|
down_block_res_samples = output[:-1]
|
178
248
|
mid_block_res_sample = output[-1]
|
179
249
|
|
@@ -27,7 +27,7 @@ from pathlib import Path
|
|
27
27
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
28
28
|
|
29
29
|
import torch
|
30
|
-
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
30
|
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
31
31
|
from optimum.exporters import TasksManager
|
32
32
|
from transformers import AutoConfig, AutoModel, PretrainedConfig
|
33
33
|
|
@@ -90,22 +90,28 @@ class _UNet_SDXL(torch.nn.Module):
|
|
90
90
|
sample: torch.Tensor,
|
91
91
|
timestep: Union[torch.Tensor, float, int],
|
92
92
|
encoder_hidden_states: torch.Tensor,
|
93
|
-
text_embeds: Optional[torch.Tensor] = None,
|
94
|
-
time_ids: Optional[torch.Tensor] = None,
|
95
93
|
*down_and_mid_block_additional_residuals: Optional[Tuple[torch.Tensor]],
|
96
94
|
) -> torch.Tensor:
|
97
|
-
if
|
98
|
-
added_cond_kwargs = {
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
95
|
+
if len(down_and_mid_block_additional_residuals) == 2:
|
96
|
+
added_cond_kwargs = {
|
97
|
+
"text_embeds": down_and_mid_block_additional_residuals[0],
|
98
|
+
"time_ids": down_and_mid_block_additional_residuals[1],
|
99
|
+
}
|
100
|
+
down_block_additional_residuals = None
|
101
|
+
mid_block_additional_residual = None
|
102
|
+
elif len(down_and_mid_block_additional_residuals) > 2:
|
103
|
+
added_cond_kwargs = {
|
104
|
+
"text_embeds": down_and_mid_block_additional_residuals[-2],
|
105
|
+
"time_ids": down_and_mid_block_additional_residuals[-1],
|
106
|
+
}
|
103
107
|
down_block_additional_residuals, mid_block_additional_residual = (
|
104
|
-
down_and_mid_block_additional_residuals[:-
|
105
|
-
down_and_mid_block_additional_residuals[-
|
108
|
+
down_and_mid_block_additional_residuals[:-3],
|
109
|
+
down_and_mid_block_additional_residuals[-3],
|
106
110
|
)
|
107
111
|
else:
|
108
|
-
|
112
|
+
added_cond_kwargs = {}
|
113
|
+
down_block_additional_residuals = None
|
114
|
+
mid_block_additional_residual = None
|
109
115
|
|
110
116
|
unet_out = self.unet(
|
111
117
|
sample=sample,
|
@@ -197,9 +203,11 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
197
203
|
meta["rbln_use_encode"] = rbln_use_encode
|
198
204
|
|
199
205
|
if rbln_use_encode:
|
206
|
+
# FIXME :: robust img shape getter
|
200
207
|
input_width = rbln_img_width // rbln_vae_scale_factor
|
201
208
|
input_height = rbln_img_height // rbln_vae_scale_factor
|
202
209
|
else:
|
210
|
+
# FIXME :: model_config.sample_size can be tuple or list
|
203
211
|
input_width, input_height = model_config.sample_size, model_config.sample_size
|
204
212
|
|
205
213
|
input_info = [
|
@@ -208,8 +216,8 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
208
216
|
[
|
209
217
|
rbln_batch_size,
|
210
218
|
model_config.in_channels,
|
211
|
-
input_width,
|
212
219
|
input_height,
|
220
|
+
input_width,
|
213
221
|
],
|
214
222
|
"float32",
|
215
223
|
),
|
@@ -225,64 +233,73 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
225
233
|
),
|
226
234
|
]
|
227
235
|
if rbln_is_controlnet:
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
input_info.append(
|
239
|
-
(
|
240
|
-
f"down_block_additional_residuals_{3}",
|
241
|
-
[rbln_batch_size, model_config.block_out_channels[0], input_width // 2, input_height // 2],
|
242
|
-
"float32",
|
236
|
+
if len(model_config.block_out_channels) > 0:
|
237
|
+
input_info.extend(
|
238
|
+
[
|
239
|
+
(
|
240
|
+
f"down_block_additional_residuals_{i}",
|
241
|
+
[rbln_batch_size, model_config.block_out_channels[0], input_height, input_width],
|
242
|
+
"float32",
|
243
|
+
)
|
244
|
+
for i in range(3)
|
245
|
+
]
|
243
246
|
)
|
244
|
-
|
245
|
-
input_info.extend(
|
246
|
-
[
|
247
|
+
input_info.append(
|
247
248
|
(
|
248
|
-
|
249
|
-
[rbln_batch_size, model_config.block_out_channels[
|
249
|
+
"down_block_additional_residuals_3",
|
250
|
+
[rbln_batch_size, model_config.block_out_channels[0], input_height // 2, input_width // 2],
|
250
251
|
"float32",
|
251
252
|
)
|
252
|
-
for i in range(4, 6)
|
253
|
-
]
|
254
|
-
)
|
255
|
-
input_info.append(
|
256
|
-
(
|
257
|
-
f"down_block_additional_residuals_{6}",
|
258
|
-
[rbln_batch_size, model_config.block_out_channels[1], input_width // 4, input_height // 4],
|
259
|
-
"float32",
|
260
253
|
)
|
261
|
-
)
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
[
|
254
|
+
if len(model_config.block_out_channels) > 1:
|
255
|
+
input_info.extend(
|
256
|
+
[
|
257
|
+
(
|
258
|
+
f"down_block_additional_residuals_{i}",
|
259
|
+
[rbln_batch_size, model_config.block_out_channels[1], input_height // 2, input_width // 2],
|
260
|
+
"float32",
|
261
|
+
)
|
262
|
+
for i in range(4, 6)
|
263
|
+
]
|
264
|
+
)
|
265
|
+
input_info.append(
|
274
266
|
(
|
275
|
-
f"down_block_additional_residuals_{
|
276
|
-
[rbln_batch_size, model_config.block_out_channels[
|
267
|
+
f"down_block_additional_residuals_{6}",
|
268
|
+
[rbln_batch_size, model_config.block_out_channels[1], input_height // 4, input_width // 4],
|
277
269
|
"float32",
|
278
270
|
)
|
279
|
-
|
280
|
-
|
281
|
-
|
271
|
+
)
|
272
|
+
if len(model_config.block_out_channels) > 2:
|
273
|
+
input_info.extend(
|
274
|
+
[
|
275
|
+
(
|
276
|
+
f"down_block_additional_residuals_{i}",
|
277
|
+
[rbln_batch_size, model_config.block_out_channels[2], input_height // 4, input_width // 4],
|
278
|
+
"float32",
|
279
|
+
)
|
280
|
+
for i in range(7, 9)
|
281
|
+
]
|
282
|
+
)
|
283
|
+
if len(model_config.block_out_channels) > 3:
|
284
|
+
input_info.extend(
|
285
|
+
[
|
286
|
+
(
|
287
|
+
f"down_block_additional_residuals_{i}",
|
288
|
+
[rbln_batch_size, model_config.block_out_channels[3], input_height // 8, input_width // 8],
|
289
|
+
"float32",
|
290
|
+
)
|
291
|
+
for i in range(9, 12)
|
292
|
+
]
|
293
|
+
)
|
282
294
|
input_info.append(
|
283
295
|
(
|
284
296
|
"mid_block_additional_residual",
|
285
|
-
[
|
297
|
+
[
|
298
|
+
rbln_batch_size,
|
299
|
+
model_config.block_out_channels[-1],
|
300
|
+
input_height // 2 ** (len(model_config.block_out_channels) - 1),
|
301
|
+
input_width // 2 ** (len(model_config.block_out_channels) - 1),
|
302
|
+
],
|
286
303
|
"float32",
|
287
304
|
)
|
288
305
|
)
|
@@ -344,7 +361,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
344
361
|
|
345
362
|
return (
|
346
363
|
super().forward(
|
347
|
-
sample,
|
364
|
+
sample.contiguous(),
|
348
365
|
timestep.float(),
|
349
366
|
encoder_hidden_states,
|
350
367
|
**added_cond_kwargs,
|
@@ -21,9 +21,14 @@
|
|
21
21
|
# copied, modified, or distributed without prior written permission
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
|
-
from .controlnet import
|
25
|
-
|
24
|
+
from .controlnet import (
|
25
|
+
RBLNMultiControlNetModel,
|
26
26
|
RBLNStableDiffusionControlNetImg2ImgPipeline,
|
27
|
+
RBLNStableDiffusionControlNetPipeline,
|
28
|
+
RBLNStableDiffusionXLControlNetImg2ImgPipeline,
|
29
|
+
RBLNStableDiffusionXLControlNetPipeline,
|
30
|
+
)
|
31
|
+
from .stable_diffusion import (
|
27
32
|
RBLNStableDiffusionImg2ImgPipeline,
|
28
33
|
RBLNStableDiffusionPipeline,
|
29
34
|
)
|
@@ -22,3 +22,7 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
24
|
from .multicontrolnet import RBLNMultiControlNetModel
|
25
|
+
from .pipeline_controlnet import RBLNStableDiffusionControlNetPipeline
|
26
|
+
from .pipeline_controlnet_img2img import RBLNStableDiffusionControlNetImg2ImgPipeline
|
27
|
+
from .pipeline_controlnet_sd_xl import RBLNStableDiffusionXLControlNetPipeline
|
28
|
+
from .pipeline_controlnet_sd_xl_img2img import RBLNStableDiffusionXLControlNetImg2ImgPipeline
|