optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +41 -38
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +26 -2
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
- optimum/rbln/diffusers/models/__init__.py +36 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
- optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
- optimum/rbln/diffusers/pipelines/__init__.py +23 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/modeling.py +238 -0
- optimum/rbln/modeling_base.py +186 -760
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -2
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
- optimum/rbln/utils/decorator_utils.py +51 -11
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +22 -1
- optimum/rbln/utils/logging.py +37 -0
- optimum/rbln/utils/model_utils.py +52 -0
- optimum/rbln/utils/runtime_utils.py +10 -4
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +137 -0
- optimum_rbln-0.2.0.dist-info/METADATA +117 -0
- optimum_rbln-0.2.0.dist-info/RECORD +114 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
- optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/RECORD +0 -107
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
@@ -22,19 +22,18 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
24
|
import logging
|
25
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
25
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
26
26
|
|
27
27
|
import rebel
|
28
28
|
import torch # noqa: I001
|
29
29
|
from diffusers import AutoencoderKL
|
30
|
-
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
31
30
|
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
32
31
|
from transformers import PretrainedConfig
|
33
32
|
|
34
|
-
from
|
35
|
-
from
|
36
|
-
from ...
|
37
|
-
from
|
33
|
+
from ....modeling import RBLNModel
|
34
|
+
from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
|
35
|
+
from ...modeling_diffusers import RBLNDiffusionMixin
|
36
|
+
from .vae import RBLNRuntimeVAEDecoder, RBLNRuntimeVAEEncoder, _VAEDecoder, _VAEEncoder
|
38
37
|
|
39
38
|
|
40
39
|
if TYPE_CHECKING:
|
@@ -44,30 +43,22 @@ if TYPE_CHECKING:
|
|
44
43
|
logger = logging.getLogger(__name__)
|
45
44
|
|
46
45
|
|
47
|
-
class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
|
48
|
-
def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
49
|
-
moments = self.forward(x.contiguous())
|
50
|
-
posterior = DiagonalGaussianDistribution(moments)
|
51
|
-
return AutoencoderKLOutput(latent_dist=posterior)
|
52
|
-
|
53
|
-
|
54
|
-
class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
|
55
|
-
def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
56
|
-
return (self.forward(z),)
|
57
|
-
|
58
|
-
|
59
46
|
class RBLNAutoencoderKL(RBLNModel):
|
47
|
+
auto_model_class = AutoencoderKL
|
60
48
|
config_name = "config.json"
|
49
|
+
hf_library_name = "diffusers"
|
61
50
|
|
62
51
|
def __post_init__(self, **kwargs):
|
63
52
|
super().__post_init__(**kwargs)
|
64
53
|
|
65
|
-
if self.rbln_config.model_cfg.get("img2img_pipeline"):
|
54
|
+
if self.rbln_config.model_cfg.get("img2img_pipeline") or self.rbln_config.model_cfg.get("inpaint_pipeline"):
|
66
55
|
self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
|
67
56
|
self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
|
68
57
|
else:
|
69
58
|
self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[0], main_input_name="z")
|
70
59
|
|
60
|
+
self.image_size = self.rbln_config.model_cfg["sample_size"]
|
61
|
+
|
71
62
|
@classmethod
|
72
63
|
def get_compiled_model(cls, model, rbln_config: RBLNConfig):
|
73
64
|
def compile_img2img():
|
@@ -89,16 +80,53 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
89
80
|
|
90
81
|
return dec_compiled_model
|
91
82
|
|
92
|
-
if rbln_config.model_cfg.get("img2img_pipeline"):
|
83
|
+
if rbln_config.model_cfg.get("img2img_pipeline") or rbln_config.model_cfg.get("inpaint_pipeline"):
|
93
84
|
return compile_img2img()
|
94
85
|
else:
|
95
86
|
return compile_text2img()
|
96
87
|
|
97
88
|
@classmethod
|
98
|
-
def
|
99
|
-
|
100
|
-
|
101
|
-
|
89
|
+
def get_vae_sample_size(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
|
90
|
+
image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
|
91
|
+
noise_module = getattr(pipe, "unet", None) or getattr(pipe, "transformer", None)
|
92
|
+
vae_scale_factor = (
|
93
|
+
pipe.vae_scale_factor
|
94
|
+
if hasattr(pipe, "vae_scale_factor")
|
95
|
+
else 2 ** (len(pipe.vae.config.block_out_channels) - 1)
|
96
|
+
)
|
97
|
+
|
98
|
+
if noise_module is None:
|
99
|
+
raise AttributeError(
|
100
|
+
"Cannot find noise processing or predicting module attributes. ex. U-Net, Transformer, ..."
|
101
|
+
)
|
102
|
+
|
103
|
+
if (image_size[0] is None) != (image_size[1] is None):
|
104
|
+
raise ValueError("Both image height and image width must be given or not given")
|
105
|
+
|
106
|
+
elif image_size[0] is None and image_size[1] is None:
|
107
|
+
if rbln_config["img2img_pipeline"]:
|
108
|
+
sample_size = noise_module.config.sample_size
|
109
|
+
elif rbln_config["inpaint_pipeline"]:
|
110
|
+
sample_size = noise_module.config.sample_size * vae_scale_factor
|
111
|
+
else:
|
112
|
+
# In case of text2img, sample size of vae decoder is determined by unet.
|
113
|
+
noise_module_sample_size = noise_module.config.sample_size
|
114
|
+
if isinstance(noise_module_sample_size, int):
|
115
|
+
sample_size = noise_module_sample_size * vae_scale_factor
|
116
|
+
else:
|
117
|
+
sample_size = (
|
118
|
+
noise_module_sample_size[0] * vae_scale_factor,
|
119
|
+
noise_module_sample_size[1] * vae_scale_factor,
|
120
|
+
)
|
121
|
+
else:
|
122
|
+
sample_size = (image_size[0], image_size[1])
|
123
|
+
|
124
|
+
return sample_size
|
125
|
+
|
126
|
+
@classmethod
|
127
|
+
def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
|
128
|
+
rbln_config.update({"sample_size": cls.get_vae_sample_size(pipe, rbln_config)})
|
129
|
+
return rbln_config
|
102
130
|
|
103
131
|
@classmethod
|
104
132
|
def _get_rbln_config(
|
@@ -109,6 +137,8 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
109
137
|
) -> RBLNConfig:
|
110
138
|
rbln_batch_size = rbln_kwargs.get("batch_size")
|
111
139
|
sample_size = rbln_kwargs.get("sample_size")
|
140
|
+
is_img2img = rbln_kwargs.get("img2img_pipeline")
|
141
|
+
is_inpaint = rbln_kwargs.get("inpaint_pipeline")
|
112
142
|
|
113
143
|
if rbln_batch_size is None:
|
114
144
|
rbln_batch_size = 1
|
@@ -119,6 +149,8 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
119
149
|
if isinstance(sample_size, int):
|
120
150
|
sample_size = (sample_size, sample_size)
|
121
151
|
|
152
|
+
rbln_kwargs["sample_size"] = sample_size
|
153
|
+
|
122
154
|
if hasattr(model_config, "block_out_channels"):
|
123
155
|
vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
|
124
156
|
else:
|
@@ -128,7 +160,7 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
128
160
|
dec_shape = (sample_size[0] // vae_scale_factor, sample_size[1] // vae_scale_factor)
|
129
161
|
enc_shape = (sample_size[0], sample_size[1])
|
130
162
|
|
131
|
-
if
|
163
|
+
if is_img2img or is_inpaint:
|
132
164
|
vae_enc_input_info = [
|
133
165
|
(
|
134
166
|
"x",
|
@@ -173,15 +205,28 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
173
205
|
|
174
206
|
@classmethod
|
175
207
|
def _create_runtimes(
|
176
|
-
cls,
|
208
|
+
cls,
|
209
|
+
compiled_models: List[rebel.RBLNCompiledModel],
|
210
|
+
rbln_device_map: Dict[str, int],
|
211
|
+
activate_profiler: Optional[bool] = None,
|
177
212
|
) -> List[rebel.Runtime]:
|
178
213
|
if len(compiled_models) == 1:
|
214
|
+
if DEFAULT_COMPILED_MODEL_NAME not in rbln_device_map:
|
215
|
+
cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])
|
216
|
+
|
179
217
|
device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
|
180
|
-
return [
|
218
|
+
return [
|
219
|
+
compiled_models[0].create_runtime(
|
220
|
+
tensor_type="pt", device=device_val, activate_profiler=activate_profiler
|
221
|
+
)
|
222
|
+
]
|
223
|
+
|
224
|
+
if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
|
225
|
+
cls._raise_missing_compiled_file_error(["encoder", "decoder"])
|
181
226
|
|
182
227
|
device_vals = [rbln_device_map["encoder"], rbln_device_map["decoder"]]
|
183
228
|
return [
|
184
|
-
compiled_model.create_runtime(tensor_type="pt", device=device_val)
|
229
|
+
compiled_model.create_runtime(tensor_type="pt", device=device_val, activate_profiler=activate_profiler)
|
185
230
|
for compiled_model, device_val in zip(compiled_models, device_vals)
|
186
231
|
]
|
187
232
|
|
@@ -191,36 +236,3 @@ class RBLNAutoencoderKL(RBLNModel):
|
|
191
236
|
|
192
237
|
def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
193
238
|
return self.decoder.decode(z)
|
194
|
-
|
195
|
-
|
196
|
-
class _VAEDecoder(torch.nn.Module):
|
197
|
-
def __init__(self, vae: "AutoencoderKL"):
|
198
|
-
super().__init__()
|
199
|
-
self.vae = vae
|
200
|
-
|
201
|
-
def forward(self, z):
|
202
|
-
vae_out = self.vae.decode(z, return_dict=False)
|
203
|
-
return vae_out
|
204
|
-
|
205
|
-
|
206
|
-
class _VAEEncoder(torch.nn.Module):
|
207
|
-
def __init__(self, vae: "AutoencoderKL"):
|
208
|
-
super().__init__()
|
209
|
-
self.vae = vae
|
210
|
-
|
211
|
-
def encode(self, x: torch.FloatTensor, return_dict: bool = True):
|
212
|
-
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
213
|
-
return self.tiled_encode(x, return_dict=return_dict)
|
214
|
-
|
215
|
-
if self.use_slicing and x.shape[0] > 1:
|
216
|
-
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
217
|
-
h = torch.cat(encoded_slices)
|
218
|
-
else:
|
219
|
-
h = self.encoder(x)
|
220
|
-
|
221
|
-
moments = self.quant_conv(h)
|
222
|
-
return moments
|
223
|
-
|
224
|
-
def forward(self, x):
|
225
|
-
vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
|
226
|
-
return vae_out
|
@@ -0,0 +1,83 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
import logging
|
25
|
+
from typing import TYPE_CHECKING
|
26
|
+
|
27
|
+
import torch # noqa: I001
|
28
|
+
from diffusers import AutoencoderKL
|
29
|
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
30
|
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
31
|
+
|
32
|
+
from ....utils.runtime_utils import RBLNPytorchRuntime
|
33
|
+
|
34
|
+
|
35
|
+
if TYPE_CHECKING:
|
36
|
+
import torch
|
37
|
+
|
38
|
+
logger = logging.getLogger(__name__)
|
39
|
+
|
40
|
+
|
41
|
+
class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
|
42
|
+
def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
43
|
+
moments = self.forward(x.contiguous())
|
44
|
+
posterior = DiagonalGaussianDistribution(moments)
|
45
|
+
return AutoencoderKLOutput(latent_dist=posterior)
|
46
|
+
|
47
|
+
|
48
|
+
class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
|
49
|
+
def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
50
|
+
return (self.forward(z),)
|
51
|
+
|
52
|
+
|
53
|
+
class _VAEDecoder(torch.nn.Module):
|
54
|
+
def __init__(self, vae: "AutoencoderKL"):
|
55
|
+
super().__init__()
|
56
|
+
self.vae = vae
|
57
|
+
|
58
|
+
def forward(self, z):
|
59
|
+
vae_out = self.vae.decode(z, return_dict=False)
|
60
|
+
return vae_out
|
61
|
+
|
62
|
+
|
63
|
+
class _VAEEncoder(torch.nn.Module):
|
64
|
+
def __init__(self, vae: "AutoencoderKL"):
|
65
|
+
super().__init__()
|
66
|
+
self.vae = vae
|
67
|
+
|
68
|
+
def encode(self, x: torch.FloatTensor, return_dict: bool = True):
|
69
|
+
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
70
|
+
return self.tiled_encode(x, return_dict=return_dict)
|
71
|
+
|
72
|
+
if self.use_slicing and x.shape[0] > 1:
|
73
|
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
74
|
+
h = torch.cat(encoded_slices)
|
75
|
+
else:
|
76
|
+
h = self.encoder(x)
|
77
|
+
if self.quant_conv is not None:
|
78
|
+
h = self.quant_conv(h)
|
79
|
+
return h
|
80
|
+
|
81
|
+
def forward(self, x):
|
82
|
+
vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
|
83
|
+
return vae_out
|
@@ -21,6 +21,7 @@
|
|
21
21
|
# copied, modified, or distributed without prior written permission
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
|
+
import importlib
|
24
25
|
import logging
|
25
26
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
26
27
|
|
@@ -28,9 +29,9 @@ import torch
|
|
28
29
|
from diffusers import ControlNetModel
|
29
30
|
from transformers import PretrainedConfig
|
30
31
|
|
31
|
-
from ...
|
32
|
+
from ...modeling import RBLNModel
|
32
33
|
from ...modeling_config import RBLNCompileConfig, RBLNConfig
|
33
|
-
from
|
34
|
+
from ..modeling_diffusers import RBLNDiffusionMixin
|
34
35
|
|
35
36
|
|
36
37
|
if TYPE_CHECKING:
|
@@ -104,21 +105,15 @@ class _ControlNetModel_Cross_Attention(torch.nn.Module):
|
|
104
105
|
|
105
106
|
|
106
107
|
class RBLNControlNetModel(RBLNModel):
|
108
|
+
hf_library_name = "diffusers"
|
109
|
+
auto_model_class = ControlNetModel
|
110
|
+
|
107
111
|
def __post_init__(self, **kwargs):
|
108
112
|
super().__post_init__(**kwargs)
|
109
113
|
self.use_encoder_hidden_states = any(
|
110
114
|
item[0] == "encoder_hidden_states" for item in self.rbln_config.compile_cfgs[0].input_info
|
111
115
|
)
|
112
116
|
|
113
|
-
@classmethod
|
114
|
-
def from_pretrained(cls, *args, **kwargs):
|
115
|
-
with override_auto_classes(
|
116
|
-
config_func=ControlNetModel.load_config,
|
117
|
-
model_func=ControlNetModel.from_pretrained,
|
118
|
-
):
|
119
|
-
rt = super().from_pretrained(*args, **kwargs)
|
120
|
-
return rt
|
121
|
-
|
122
117
|
@classmethod
|
123
118
|
def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
|
124
119
|
use_encoder_hidden_states = False
|
@@ -131,6 +126,38 @@ class RBLNControlNetModel(RBLNModel):
|
|
131
126
|
else:
|
132
127
|
return _ControlNetModel(model).eval()
|
133
128
|
|
129
|
+
@classmethod
|
130
|
+
def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
|
131
|
+
rbln_vae_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.vae.__class__.__name__}")
|
132
|
+
rbln_unet_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.unet.__class__.__name__}")
|
133
|
+
text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
|
134
|
+
|
135
|
+
batch_size = rbln_config.get("batch_size")
|
136
|
+
if not batch_size:
|
137
|
+
do_classifier_free_guidance = (
|
138
|
+
rbln_config.get("guidance_scale", 5.0) > 1.0 and pipe.unet.config.time_cond_proj_dim is None
|
139
|
+
)
|
140
|
+
batch_size = 2 if do_classifier_free_guidance else 1
|
141
|
+
else:
|
142
|
+
if rbln_config.get("guidance_scale"):
|
143
|
+
logger.warning(
|
144
|
+
"guidance_scale is ignored because batch size is explicitly specified. "
|
145
|
+
"To ensure consistent behavior, consider removing the guidance scale or "
|
146
|
+
"adjusting the batch size configuration as needed."
|
147
|
+
)
|
148
|
+
|
149
|
+
rbln_config.update(
|
150
|
+
{
|
151
|
+
"max_seq_len": pipe.text_encoder.config.max_position_embeddings,
|
152
|
+
"text_model_hidden_size": text_model_hidden_size,
|
153
|
+
"vae_sample_size": rbln_vae_cls.get_vae_sample_size(pipe, rbln_config),
|
154
|
+
"unet_sample_size": rbln_unet_cls.get_unet_sample_size(pipe, rbln_config),
|
155
|
+
"batch_size": batch_size,
|
156
|
+
}
|
157
|
+
)
|
158
|
+
|
159
|
+
return rbln_config
|
160
|
+
|
134
161
|
@classmethod
|
135
162
|
def _get_rbln_config(
|
136
163
|
cls,
|
@@ -207,6 +234,10 @@ class RBLNControlNetModel(RBLNModel):
|
|
207
234
|
|
208
235
|
return rbln_config
|
209
236
|
|
237
|
+
@property
|
238
|
+
def compiled_batch_size(self):
|
239
|
+
return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
|
240
|
+
|
210
241
|
def forward(
|
211
242
|
self,
|
212
243
|
sample: torch.FloatTensor,
|
@@ -217,9 +248,18 @@ class RBLNControlNetModel(RBLNModel):
|
|
217
248
|
added_cond_kwargs: Dict[str, torch.Tensor] = {},
|
218
249
|
**kwargs,
|
219
250
|
):
|
220
|
-
|
221
|
-
|
222
|
-
|
251
|
+
sample_batch_size = sample.size()[0]
|
252
|
+
compiled_batch_size = self.compiled_batch_size
|
253
|
+
if sample_batch_size != compiled_batch_size and (
|
254
|
+
sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
|
255
|
+
):
|
256
|
+
raise ValueError(
|
257
|
+
f"Mismatch between ControlNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
|
258
|
+
"This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
|
259
|
+
"Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
|
260
|
+
"For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
|
261
|
+
)
|
262
|
+
|
223
263
|
added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
|
224
264
|
if self.use_encoder_hidden_states:
|
225
265
|
output = super().forward(
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .transformer_sd3 import RBLNSD3Transformer2DModel
|
@@ -0,0 +1,203 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
import logging
|
25
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
26
|
+
|
27
|
+
import torch
|
28
|
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
29
|
+
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
|
30
|
+
from transformers import PretrainedConfig
|
31
|
+
|
32
|
+
from ....modeling import RBLNModel
|
33
|
+
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
34
|
+
from ...modeling_diffusers import RBLNDiffusionMixin
|
35
|
+
|
36
|
+
|
37
|
+
if TYPE_CHECKING:
|
38
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
39
|
+
|
40
|
+
logger = logging.getLogger(__name__)
|
41
|
+
|
42
|
+
|
43
|
+
class SD3Transformer2DModelWrapper(torch.nn.Module):
|
44
|
+
def __init__(self, model: "SD3Transformer2DModel") -> None:
|
45
|
+
super().__init__()
|
46
|
+
self.model = model
|
47
|
+
|
48
|
+
def forward(
|
49
|
+
self,
|
50
|
+
hidden_states: torch.FloatTensor,
|
51
|
+
encoder_hidden_states: torch.FloatTensor = None,
|
52
|
+
pooled_projections: torch.FloatTensor = None,
|
53
|
+
timestep: torch.LongTensor = None,
|
54
|
+
# need controlnet support?
|
55
|
+
block_controlnet_hidden_states: List = None,
|
56
|
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
57
|
+
return_dict: bool = True,
|
58
|
+
):
|
59
|
+
return self.model(
|
60
|
+
hidden_states=hidden_states,
|
61
|
+
encoder_hidden_states=encoder_hidden_states,
|
62
|
+
pooled_projections=pooled_projections,
|
63
|
+
timestep=timestep,
|
64
|
+
return_dict=False,
|
65
|
+
)
|
66
|
+
|
67
|
+
|
68
|
+
class RBLNSD3Transformer2DModel(RBLNModel):
|
69
|
+
hf_library_name = "diffusers"
|
70
|
+
|
71
|
+
def __post_init__(self, **kwargs):
|
72
|
+
super().__post_init__(**kwargs)
|
73
|
+
|
74
|
+
@classmethod
|
75
|
+
def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
|
76
|
+
return SD3Transformer2DModelWrapper(model).eval()
|
77
|
+
|
78
|
+
@classmethod
|
79
|
+
def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
|
80
|
+
sample_size = rbln_config.get("sample_size", pipe.default_sample_size)
|
81
|
+
img_width = rbln_config.get("img_width")
|
82
|
+
img_height = rbln_config.get("img_height")
|
83
|
+
|
84
|
+
if (img_width is None) ^ (img_height is None):
|
85
|
+
raise RuntimeError
|
86
|
+
|
87
|
+
elif img_width and img_height:
|
88
|
+
sample_size = img_height // pipe.vae_scale_factor, img_width // pipe.vae_scale_factor
|
89
|
+
|
90
|
+
prompt_max_length = rbln_config.get("max_sequence_length", 256)
|
91
|
+
prompt_embed_length = pipe.tokenizer_max_length + prompt_max_length
|
92
|
+
|
93
|
+
batch_size = rbln_config.get("batch_size")
|
94
|
+
if not batch_size:
|
95
|
+
do_classifier_free_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0
|
96
|
+
batch_size = 2 if do_classifier_free_guidance else 1
|
97
|
+
else:
|
98
|
+
if rbln_config.get("guidance_scale"):
|
99
|
+
logger.warning(
|
100
|
+
"guidance_scale is ignored because batch size is explicitly specified. "
|
101
|
+
"To ensure consistent behavior, consider removing the guidance scale or "
|
102
|
+
"adjusting the batch size configuration as needed."
|
103
|
+
)
|
104
|
+
|
105
|
+
rbln_config.update(
|
106
|
+
{
|
107
|
+
"batch_size": batch_size,
|
108
|
+
"prompt_embed_length": prompt_embed_length,
|
109
|
+
"sample_size": sample_size,
|
110
|
+
}
|
111
|
+
)
|
112
|
+
|
113
|
+
return rbln_config
|
114
|
+
|
115
|
+
@classmethod
|
116
|
+
def _get_rbln_config(
|
117
|
+
cls,
|
118
|
+
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
119
|
+
model_config: "PretrainedConfig",
|
120
|
+
rbln_kwargs: Dict[str, Any] = {},
|
121
|
+
) -> RBLNConfig:
|
122
|
+
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
123
|
+
|
124
|
+
sample_size = rbln_kwargs.get("sample_size", model_config.sample_size)
|
125
|
+
if isinstance(sample_size, int):
|
126
|
+
sample_size = (sample_size, sample_size)
|
127
|
+
|
128
|
+
rbln_prompt_embed_length = rbln_kwargs.get("prompt_embed_length")
|
129
|
+
if rbln_prompt_embed_length is None:
|
130
|
+
raise ValueError("rbln_prompt_embed_length should be specified.")
|
131
|
+
|
132
|
+
input_info = [
|
133
|
+
(
|
134
|
+
"hidden_states",
|
135
|
+
[
|
136
|
+
rbln_batch_size,
|
137
|
+
model_config.in_channels,
|
138
|
+
sample_size[0],
|
139
|
+
sample_size[1],
|
140
|
+
],
|
141
|
+
"float32",
|
142
|
+
),
|
143
|
+
(
|
144
|
+
"encoder_hidden_states",
|
145
|
+
[
|
146
|
+
rbln_batch_size,
|
147
|
+
rbln_prompt_embed_length,
|
148
|
+
model_config.joint_attention_dim,
|
149
|
+
],
|
150
|
+
"float32",
|
151
|
+
),
|
152
|
+
(
|
153
|
+
"pooled_projections",
|
154
|
+
[
|
155
|
+
rbln_batch_size,
|
156
|
+
model_config.pooled_projection_dim,
|
157
|
+
],
|
158
|
+
"float32",
|
159
|
+
),
|
160
|
+
("timestep", [rbln_batch_size], "float32"),
|
161
|
+
]
|
162
|
+
|
163
|
+
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
164
|
+
|
165
|
+
rbln_config = RBLNConfig(
|
166
|
+
rbln_cls=cls.__name__,
|
167
|
+
compile_cfgs=[rbln_compile_config],
|
168
|
+
rbln_kwargs=rbln_kwargs,
|
169
|
+
)
|
170
|
+
|
171
|
+
rbln_config.model_cfg.update({"batch_size": rbln_batch_size})
|
172
|
+
|
173
|
+
return rbln_config
|
174
|
+
|
175
|
+
@property
|
176
|
+
def compiled_batch_size(self):
|
177
|
+
return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
|
178
|
+
|
179
|
+
def forward(
|
180
|
+
self,
|
181
|
+
hidden_states: torch.FloatTensor,
|
182
|
+
encoder_hidden_states: torch.FloatTensor = None,
|
183
|
+
pooled_projections: torch.FloatTensor = None,
|
184
|
+
timestep: torch.LongTensor = None,
|
185
|
+
block_controlnet_hidden_states: List = None,
|
186
|
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
187
|
+
return_dict: bool = True,
|
188
|
+
**kwargs,
|
189
|
+
):
|
190
|
+
sample_batch_size = hidden_states.size()[0]
|
191
|
+
compiled_batch_size = self.compiled_batch_size
|
192
|
+
if sample_batch_size != compiled_batch_size and (
|
193
|
+
sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
|
194
|
+
):
|
195
|
+
raise ValueError(
|
196
|
+
f"Mismatch between Transformers' runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
|
197
|
+
"This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
|
198
|
+
"Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
|
199
|
+
"For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
|
200
|
+
)
|
201
|
+
|
202
|
+
sample = super().forward(hidden_states, encoder_hidden_states, pooled_projections, timestep)
|
203
|
+
return Transformer2DModelOutput(sample=sample)
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .unet_2d_condition import RBLNUNet2DConditionModel
|