optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- optimum/rbln/__init__.py +27 -13
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +85 -65
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
- optimum/rbln/diffusers/pipelines/__init__.py +60 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +176 -763
- optimum/rbln/modeling_diffusers.py +329 -0
- optimum/rbln/transformers/__init__.py +2 -2
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
- optimum/rbln/utils/decorator_utils.py +59 -0
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +5 -5
- optimum/rbln/utils/submodule.py +114 -0
- optimum/rbln/utils/timer_utils.py +2 -2
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- optimum_rbln-0.1.15.dist-info/RECORD +110 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum_rbln-0.1.12.dist-info/METADATA +0 -119
- optimum_rbln-0.1.12.dist-info/RECORD +0 -103
- optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/models/controlnet.py

```diff
@@ -21,17 +21,17 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
+import importlib
 import logging
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 import torch
 from diffusers import ControlNetModel
-from
-from transformers import AutoConfig, AutoModel, PretrainedConfig
+from transformers import PretrainedConfig
 
-from ...
+from ...modeling import RBLNModel
 from ...modeling_config import RBLNCompileConfig, RBLNConfig
+from ...modeling_diffusers import RBLNDiffusionMixin
 
 
 if TYPE_CHECKING:
@@ -105,33 +105,15 @@ class _ControlNetModel_Cross_Attention(torch.nn.Module):
 
 
 class RBLNControlNetModel(RBLNModel):
+    hf_library_name = "diffusers"
+    auto_model_class = ControlNetModel
+
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
         self.use_encoder_hidden_states = any(
             item[0] == "encoder_hidden_states" for item in self.rbln_config.compile_cfgs[0].input_info
         )
 
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        def get_model_from_task(
-            task: str,
-            model_name_or_path: Union[str, Path],
-            **kwargs,
-        ):
-            return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
-
-        tasktmp = TasksManager.get_model_from_task
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        TasksManager.get_model_from_task = get_model_from_task
-        AutoConfig.from_pretrained = ControlNetModel.load_config
-        AutoModel.from_pretrained = ControlNetModel.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        TasksManager.get_model_from_task = tasktmp
-        return rt
-
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         use_encoder_hidden_states = False
@@ -144,6 +126,38 @@ class RBLNControlNetModel(RBLNModel):
         else:
             return _ControlNetModel(model).eval()
 
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        rbln_vae_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.vae.__class__.__name__}")
+        rbln_unet_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.unet.__class__.__name__}")
+        text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
+
+        batch_size = rbln_config.get("batch_size")
+        if not batch_size:
+            do_classifier_free_guidance = (
+                rbln_config.get("guidance_scale", 5.0) > 1.0 and pipe.unet.config.time_cond_proj_dim is None
+            )
+            batch_size = 2 if do_classifier_free_guidance else 1
+        else:
+            if rbln_config.get("guidance_scale"):
+                logger.warning(
+                    "guidance_scale is ignored because batch size is explicitly specified. "
+                    "To ensure consistent behavior, consider removing the guidance scale or "
+                    "adjusting the batch size configuration as needed."
+                )
+
+        rbln_config.update(
+            {
+                "max_seq_len": pipe.text_encoder.config.max_position_embeddings,
+                "text_model_hidden_size": text_model_hidden_size,
+                "vae_sample_size": rbln_vae_cls.get_vae_sample_size(pipe, rbln_config),
+                "unet_sample_size": rbln_unet_cls.get_unet_sample_size(pipe, rbln_config),
+                "batch_size": batch_size,
+            }
+        )
+
+        return rbln_config
+
     @classmethod
     def _get_rbln_config(
         cls,
```
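The new `update_rbln_config_using_pipe` hook above resolves the compile-time batch size from the pipeline configuration. A minimal standalone sketch of that rule follows; the helper name and the assertions are illustrative only, not part of the package API:

```python
def resolve_compile_batch_size(rbln_config: dict, time_cond_proj_dim=None) -> int:
    # Mirrors the logic added above: an explicit batch_size always wins;
    # otherwise classifier-free guidance (guidance_scale > 1.0 and no
    # time_cond_proj_dim on the UNet) doubles the compiled batch to 2.
    batch_size = rbln_config.get("batch_size")
    if not batch_size:
        do_cfg = rbln_config.get("guidance_scale", 5.0) > 1.0 and time_cond_proj_dim is None
        return 2 if do_cfg else 1
    return batch_size


assert resolve_compile_batch_size({}) == 2                       # default guidance_scale of 5.0 enables CFG
assert resolve_compile_batch_size({"guidance_scale": 1.0}) == 1  # guidance disabled
assert resolve_compile_batch_size({"batch_size": 4}) == 4        # explicit batch size wins, guidance_scale ignored
```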
optimum/rbln/diffusers/models/controlnet.py (continued)

```diff
@@ -151,33 +165,35 @@ class RBLNControlNetModel(RBLNModel):
         model_config: "PretrainedConfig",
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
-
-
-
-        rbln_img_height = rbln_kwargs.get("img_height", None)
-        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+        batch_size = rbln_kwargs.get("batch_size")
+        max_seq_len = rbln_kwargs.get("max_seq_len")
+        unet_sample_size = rbln_kwargs.get("unet_sample_size")
+        vae_sample_size = rbln_kwargs.get("vae_sample_size")
 
-        if
-
+        if batch_size is None:
+            batch_size = 1
 
-        if
-
+        if unet_sample_size is None:
+            raise ValueError(
+                "`rbln_unet_sample_size` (latent height, widht) must be specified (ex. unet's sample_size)"
+            )
 
-        if
-            raise ValueError(
+        if vae_sample_size is None:
+            raise ValueError(
+                "`rbln_vae_sample_size` (input image height, width) must be specified (ex. vae's sample_size)"
+            )
 
-
-
+        if max_seq_len is None:
+            raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings )must be specified")
 
         input_info = [
             (
                 "sample",
                 [
-
+                    batch_size,
                     model_config.in_channels,
-
-
+                    unet_sample_size[0],
+                    unet_sample_size[1],
                 ],
                 "float32",
             ),
@@ -189,23 +205,24 @@ class RBLNControlNetModel(RBLNModel):
         input_info.append(
             (
                 "encoder_hidden_states",
-                [
-                    rbln_batch_size,
-                    rbln_max_seq_len,
-                    model_config.cross_attention_dim,
-                ],
+                [batch_size, max_seq_len, model_config.cross_attention_dim],
                 "float32",
             )
         )
 
-        input_info.append(
+        input_info.append(
+            (
+                "controlnet_cond",
+                [batch_size, 3, vae_sample_size[0], vae_sample_size[1]],
+                "float32",
+            )
+        )
         input_info.append(("conditioning_scale", [], "float32"))
 
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-
-
-            input_info.append(("
-            input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+            rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
+            input_info.append(("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32"))
+            input_info.append(("time_ids", [batch_size, 6], "float32"))
 
         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
 
@@ -215,18 +232,12 @@ class RBLNControlNetModel(RBLNModel):
             rbln_kwargs=rbln_kwargs,
         )
 
-        rbln_config.model_cfg.update(
-            {
-                "max_seq_len": rbln_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "img_width": rbln_img_width,
-                "img_height": rbln_img_height,
-                "vae_scale_factor": rbln_vae_scale_factor,
-            }
-        )
-
         return rbln_config
 
+    @property
+    def compiled_batch_size(self):
+        return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
+
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -237,9 +248,18 @@ class RBLNControlNetModel(RBLNModel):
         added_cond_kwargs: Dict[str, torch.Tensor] = {},
         **kwargs,
     ):
-
-
-
+        sample_batch_size = sample.size()[0]
+        compiled_batch_size = self.compiled_batch_size
+        if sample_batch_size != compiled_batch_size and (
+            sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
+        ):
+            raise ValueError(
+                f"Mismatch between ControlNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
+                "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
+                "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
+                "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
+            )
+
         added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
         if self.use_encoder_hidden_states:
             output = super().forward(
```
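The guard added to `forward` only flags the factor-of-two case, since classifier-free guidance is what doubles the runtime batch. A small illustrative predicate capturing the same condition; the function name is ours, not part of the package:

```python
def is_guidance_batch_mismatch(runtime_bs: int, compiled_bs: int) -> bool:
    # Same condition as the ValueError above: report a mismatch only when the
    # runtime and compiled batch sizes differ by exactly a factor of two,
    # which is the signature of a guidance-scale/batch-size disagreement.
    return runtime_bs != compiled_bs and (
        runtime_bs * 2 == compiled_bs or runtime_bs == compiled_bs * 2
    )


assert is_guidance_batch_mismatch(2, 1)      # compiled without guidance, run with it
assert is_guidance_batch_mismatch(1, 2)      # compiled for guidance, run without it
assert not is_guidance_batch_mismatch(2, 2)  # matching batch sizes pass through
assert not is_guidance_batch_mismatch(3, 1)  # other mismatches are not caught by this guard
```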
optimum/rbln/diffusers/models/transformers/__init__.py (new file)

```diff
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .transformer_sd3 import RBLNSD3Transformer2DModel
```
optimum/rbln/diffusers/models/transformers/transformer_sd3.py (new file)

```diff
@@ -0,0 +1,203 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import torch
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
+from transformers import PretrainedConfig
+
+from ....modeling import RBLNModel
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....modeling_diffusers import RBLNDiffusionMixin
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class SD3Transformer2DModelWrapper(torch.nn.Module):
+    def __init__(self, model: "SD3Transformer2DModel") -> None:
+        super().__init__()
+        self.model = model
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        pooled_projections: torch.FloatTensor = None,
+        timestep: torch.LongTensor = None,
+        # need controlnet support?
+        block_controlnet_hidden_states: List = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ):
+        return self.model(
+            hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            pooled_projections=pooled_projections,
+            timestep=timestep,
+            return_dict=False,
+        )
+
+
+class RBLNSD3Transformer2DModel(RBLNModel):
+    hf_library_name = "diffusers"
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+        return SD3Transformer2DModelWrapper(model).eval()
+
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        sample_size = rbln_config.get("sample_size", pipe.default_sample_size)
+        img_width = rbln_config.get("img_width")
+        img_height = rbln_config.get("img_height")
+
+        if (img_width is None) ^ (img_height is None):
+            raise RuntimeError
+
+        elif img_width and img_height:
+            sample_size = img_height // pipe.vae_scale_factor, img_width // pipe.vae_scale_factor
+
+        prompt_max_length = rbln_config.get("max_sequence_length", 256)
+        prompt_embed_length = pipe.tokenizer_max_length + prompt_max_length
+
+        batch_size = rbln_config.get("batch_size")
+        if not batch_size:
+            do_classifier_free_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0
+            batch_size = 2 if do_classifier_free_guidance else 1
+        else:
+            if rbln_config.get("guidance_scale"):
+                logger.warning(
+                    "guidance_scale is ignored because batch size is explicitly specified. "
+                    "To ensure consistent behavior, consider removing the guidance scale or "
+                    "adjusting the batch size configuration as needed."
+                )
+
+        rbln_config.update(
+            {
+                "batch_size": batch_size,
+                "prompt_embed_length": prompt_embed_length,
+                "sample_size": sample_size,
+            }
+        )
+
+        return rbln_config
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model_config: "PretrainedConfig",
+        rbln_kwargs: Dict[str, Any] = {},
+    ) -> RBLNConfig:
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
+        sample_size = rbln_kwargs.get("sample_size", model_config.sample_size)
+        if isinstance(sample_size, int):
+            sample_size = (sample_size, sample_size)
+
+        rbln_prompt_embed_length = rbln_kwargs.get("prompt_embed_length")
+        if rbln_prompt_embed_length is None:
+            raise ValueError("rbln_prompt_embed_length should be specified.")
+
+        input_info = [
+            (
+                "hidden_states",
+                [
+                    rbln_batch_size,
+                    model_config.in_channels,
+                    sample_size[0],
+                    sample_size[1],
+                ],
+                "float32",
+            ),
+            (
+                "encoder_hidden_states",
+                [
+                    rbln_batch_size,
+                    rbln_prompt_embed_length,
+                    model_config.joint_attention_dim,
+                ],
+                "float32",
+            ),
+            (
+                "pooled_projections",
+                [
+                    rbln_batch_size,
+                    model_config.pooled_projection_dim,
+                ],
+                "float32",
+            ),
+            ("timestep", [rbln_batch_size], "float32"),
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update({"batch_size": rbln_batch_size})
+
+        return rbln_config
+
+    @property
+    def compiled_batch_size(self):
+        return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        pooled_projections: torch.FloatTensor = None,
+        timestep: torch.LongTensor = None,
+        block_controlnet_hidden_states: List = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+        **kwargs,
+    ):
+        sample_batch_size = hidden_states.size()[0]
+        compiled_batch_size = self.compiled_batch_size
+        if sample_batch_size != compiled_batch_size and (
+            sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
+        ):
+            raise ValueError(
+                f"Mismatch between Transformers' runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
+                "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
+                "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
+                "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
+            )
+
+        sample = super().forward(hidden_states, encoder_hidden_states, pooled_projections, timestep)
+        return Transformer2DModelOutput(sample=sample)
```
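In the new SD3 transformer wrapper, `update_rbln_config_using_pipe` derives the latent sample size and prompt-embedding length before compilation. A minimal sketch of that arithmetic, assuming the usual SD3 defaults of a 128 latent sample size, a VAE scale factor of 8, and a CLIP tokenizer length of 77; those defaults and the helper name are illustrative, the real values come from the pipeline:

```python
def resolve_sd3_shapes(rbln_config: dict, default_sample_size=128, vae_scale_factor=8, tokenizer_max_length=77):
    # Mirrors update_rbln_config_using_pipe in transformer_sd3.py above:
    # img_width/img_height must be given together and override sample_size;
    # the prompt embedding length is the CLIP length plus max_sequence_length.
    sample_size = rbln_config.get("sample_size", default_sample_size)
    img_width, img_height = rbln_config.get("img_width"), rbln_config.get("img_height")
    if (img_width is None) ^ (img_height is None):
        raise RuntimeError("img_width and img_height must be specified together")
    if img_width and img_height:
        sample_size = (img_height // vae_scale_factor, img_width // vae_scale_factor)
    prompt_embed_length = tokenizer_max_length + rbln_config.get("max_sequence_length", 256)
    return sample_size, prompt_embed_length


# A 1024x1024 request with those defaults compiles for a 128x128 latent
# and a 77 + 256 = 333 token prompt embedding.
assert resolve_sd3_shapes({"img_width": 1024, "img_height": 1024}) == ((128, 128), 333)
```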
optimum/rbln/diffusers/models/unets/__init__.py (new file)

```diff
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .unet_2d_condition import RBLNUNet2DConditionModel
```