optimum-rbln 0.2.1a4__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -2
- optimum/rbln/__version__.py +9 -4
- optimum/rbln/diffusers/__init__.py +10 -0
- optimum/rbln/diffusers/modeling_diffusers.py +132 -25
- optimum/rbln/diffusers/models/__init__.py +7 -1
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +52 -2
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +159 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +174 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +57 -14
- optimum/rbln/diffusers/pipelines/__init__.py +10 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +83 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +22 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +22 -0
- optimum/rbln/modeling_base.py +10 -9
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/models/__init__.py +12 -2
- optimum/rbln/transformers/models/clip/__init__.py +6 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +26 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +3 -1
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +1 -1
- optimum/rbln/utils/import_utils.py +7 -0
- {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/METADATA +1 -1
- {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/RECORD +28 -22
- {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/diffusers/models/transformers/prior_transformer.py
ADDED
@@ -0,0 +1,174 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import torch
+from diffusers.models.transformers.prior_transformer import PriorTransformer, PriorTransformerOutput
+from transformers import PretrainedConfig, PreTrainedModel
+
+from ....modeling import RBLNModel
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....utils.logging import get_logger
+from ....utils.runtime_utils import RBLNPytorchRuntime
+from ...modeling_diffusers import RBLNDiffusionMixin
+
+
+logger = get_logger(__name__)
+
+
+class RBLNRuntimePriorTransformer(RBLNPytorchRuntime):
+    def forward(
+        self, hidden_states, timestep, proj_embedding, encoder_hidden_states, attention_mask, return_dict: bool = True
+    ):
+        predicted_image_embedding = super().forward(
+            hidden_states,
+            timestep,
+            proj_embedding,
+            encoder_hidden_states,
+            attention_mask,
+        )
+        if return_dict:
+            return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
+        else:
+            return (predicted_image_embedding,)
+
+
+class _PriorTransformer(torch.nn.Module):
+    def __init__(self, prior: PriorTransformer):
+        super().__init__()
+        self._prior = prior
+
+    def forward(
+        self,
+        hidden_states,
+        timestep,
+        proj_embedding,
+        encoder_hidden_states,
+        attention_mask,
+        return_dict=True,
+    ):
+        return self._prior.forward(
+            hidden_states,
+            timestep,
+            proj_embedding,
+            encoder_hidden_states,
+            attention_mask,
+            return_dict=False,
+        )
+
+
+class RBLNPriorTransformer(RBLNModel):
+    hf_library_name = "diffusers"
+    auto_model_class = PriorTransformer
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+        self.runtime = RBLNRuntimePriorTransformer(runtime=self.model[0])
+        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+        self.clip_mean = artifacts["clip_mean"]
+        self.clip_std = artifacts["clip_std"]
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+        return _PriorTransformer(model).eval()
+
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        batch_size = rbln_config.get("batch_size")
+        if not batch_size:
+            do_classifier_free_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0
+            batch_size = 2 if do_classifier_free_guidance else 1
+        else:
+            if rbln_config.get("guidance_scale"):
+                logger.warning(
+                    "guidance_scale is ignored because batch size is explicitly specified. "
+                    "To ensure consistent behavior, consider removing the guidance scale or "
+                    "adjusting the batch size configuration as needed."
+                )
+        embedding_dim = rbln_config.get("embedding_dim", pipe.prior.config.embedding_dim)
+        num_embeddings = rbln_config.get("num_embeddings", pipe.prior.config.num_embeddings)
+
+        rbln_config.update(
+            {
+                "batch_size": batch_size,
+                "embedding_dim": embedding_dim,
+                "num_embeddings": num_embeddings,
+            }
+        )
+
+        return rbln_config
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "PreTrainedModel",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNConfig,
+    ):
+        save_dict = {}
+        save_dict["clip_mean"] = model.clip_mean
+        save_dict["clip_std"] = model.clip_std
+        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors,
+        model_config: PretrainedConfig,
+        rbln_kwargs,
+    ) -> RBLNConfig:
+        batch_size = rbln_kwargs.get("batch_size") or 1
+        embedding_dim = rbln_kwargs.get("embedding_dim") or model_config.embedding_dim
+        num_embeddings = rbln_kwargs.get("num_embeddings") or model_config.num_embeddings
+
+        input_info = [
+            ("hidden_states", [batch_size, embedding_dim], "float32"),
+            ("timestep", [], "float32"),
+            ("proj_embedding", [batch_size, embedding_dim], "float32"),
+            ("encoder_hidden_states", [batch_size, num_embeddings, embedding_dim], "float32"),
+            ("attention_mask", [batch_size, num_embeddings], "float32"),
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        return rbln_config
+
+    def forward(
+        self,
+        hidden_states,
+        timestep: Union[torch.Tensor, float, int],
+        proj_embedding: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        return_dict: bool = True,
+    ):
+        return self.runtime.forward(
+            hidden_states.contiguous(),
+            timestep.float(),
+            proj_embedding,
+            encoder_hidden_states,
+            attention_mask.float(),
+            return_dict,
+        )
+
+    def post_process_latents(self, prior_latents):
+        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
+        return prior_latents
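For context, here is a minimal sketch (not part of the package) of dummy inputs matching the static input_info that `_get_rbln_config` declares above; the dimension values are illustrative stand-ins for what would normally come from the PriorTransformer config and the chosen batch size.

```python
import torch

# Illustrative values: in practice embedding_dim / num_embeddings come from
# model_config, and batch_size is 2 with classifier-free guidance, else 1.
batch_size, embedding_dim, num_embeddings = 2, 768, 77

dummy_inputs = (
    torch.zeros(batch_size, embedding_dim),                  # hidden_states
    torch.zeros([]),                                         # timestep (scalar, float32)
    torch.zeros(batch_size, embedding_dim),                  # proj_embedding
    torch.zeros(batch_size, num_embeddings, embedding_dim),  # encoder_hidden_states
    torch.ones(batch_size, num_embeddings),                  # attention_mask, passed as float32
)
```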

optimum/rbln/diffusers/models/unets/unet_2d_condition.py
CHANGED
@@ -115,6 +115,29 @@ class _UNet_SDXL(torch.nn.Module):
         return unet_out
 
 
+class _UNet_Kandinsky(torch.nn.Module):
+    def __init__(self, unet: "UNet2DConditionModel"):
+        super().__init__()
+        self.unet = unet
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        timestep: Union[torch.Tensor, float, int],
+        image_embeds: torch.Tensor,
+    ) -> torch.Tensor:
+        added_cond_kwargs = {"image_embeds": image_embeds}
+
+        unet_out = self.unet(
+            sample=sample,
+            timestep=timestep,
+            encoder_hidden_states=None,
+            added_cond_kwargs=added_cond_kwargs,
+            return_dict=False,
+        )
+        return unet_out
+
+
 class RBLNUNet2DConditionModel(RBLNModel):
     hf_library_name = "diffusers"
     auto_model_class = UNet2DConditionModel
@@ -138,6 +161,8 @@ class RBLNUNet2DConditionModel(RBLNModel):
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         if model.config.addition_embed_type == "text_time":
             return _UNet_SDXL(model).eval()
+        elif model.config.addition_embed_type == "image":
+            return _UNet_Kandinsky(model).eval()
         else:
             return _UNet_SD(model).eval()
 
@@ -146,6 +171,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
         cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]
     ) -> Union[int, Tuple[int, int]]:
         image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
+        scale_factor = pipe.movq_scale_factor if hasattr(pipe, "movq_scale_factor") else pipe.vae_scale_factor
         if (image_size[0] is None) != (image_size[1] is None):
             raise ValueError("Both image height and image width must be given or not given")
         elif image_size[0] is None and image_size[1] is None:
@@ -153,22 +179,23 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 # In case of img2img, sample size of unet is determined by vae encoder.
                 vae_sample_size = pipe.vae.config.sample_size
                 if isinstance(vae_sample_size, int):
-                    sample_size = vae_sample_size //
+                    sample_size = vae_sample_size // scale_factor
                 else:
                     sample_size = (
-                        vae_sample_size[0] //
-                        vae_sample_size[1] //
+                        vae_sample_size[0] // scale_factor,
+                        vae_sample_size[1] // scale_factor,
                     )
             else:
                 sample_size = pipe.unet.config.sample_size
         else:
-            sample_size = (image_size[0] //
+            sample_size = (image_size[0] // scale_factor, image_size[1] // scale_factor)
 
         return sample_size
 
     @classmethod
     def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
         text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
+        image_model_hidden_size = pipe.unet.config.encoder_hid_dim if hasattr(pipe, "unet") else None
 
         batch_size = rbln_config.get("batch_size")
         if not batch_size:
@@ -184,10 +211,12 @@ class RBLNUNet2DConditionModel(RBLNModel):
                     "adjusting the batch size configuration as needed."
                 )
 
+        max_seq_len = pipe.text_encoder.config.max_position_embeddings if hasattr(pipe, "text_encoder") else None
         rbln_config.update(
             {
-                "max_seq_len":
+                "max_seq_len": max_seq_len,
                 "text_model_hidden_size": text_model_hidden_size,
+                "image_model_hidden_size": image_model_hidden_size,
                 "sample_size": cls.get_unet_sample_size(pipe, rbln_config),
                 "batch_size": batch_size,
                 "is_controlnet": "controlnet" in pipe.config.keys(),
@@ -218,15 +247,16 @@ class RBLNUNet2DConditionModel(RBLNModel):
         if isinstance(sample_size, int):
             sample_size = (sample_size, sample_size)
 
-        if max_seq_len is None:
-            raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings) must be specified.")
-
         input_info = [
            ("sample", [batch_size, model_config.in_channels, sample_size[0], sample_size[1]], "float32"),
            ("timestep", [], "float32"),
-            ("encoder_hidden_states", [batch_size, max_seq_len, model_config.cross_attention_dim], "float32"),
         ]
 
+        if max_seq_len is not None:
+            input_info.append(
+                ("encoder_hidden_states", [batch_size, max_seq_len, model_config.cross_attention_dim], "float32"),
+            )
+
         if is_controlnet:
             # down block addtional residuals
             first_shape = [batch_size, model_config.block_out_channels[0], sample_size[0], sample_size[1]]
@@ -256,11 +286,15 @@ class RBLNUNet2DConditionModel(RBLNModel):
             ]
             input_info.append(("mid_block_additional_residual", shape, "float32"))
 
-        if hasattr(model_config, "addition_embed_type")
-
-
-
-
+        if hasattr(model_config, "addition_embed_type"):
+            if model_config.addition_embed_type == "text_time":
+                rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
+                rbln_in_features = model_config.projection_class_embeddings_input_dim
+                input_info.append(("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32"))
+                input_info.append(("time_ids", [batch_size, 6], "float32"))
+            elif model_config.addition_embed_type == "image":
+                rbln_image_model_hidden_size = rbln_kwargs["image_model_hidden_size"]
+                input_info.append(("image_embeds", [batch_size, rbln_image_model_hidden_size], "float32"))
 
         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
 
@@ -323,6 +357,15 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 ),
             )
 
+        if "image_embeds" in added_cond_kwargs:
+            return (
+                super().forward(
+                    sample.contiguous(),
+                    timestep.float(),
+                    **added_cond_kwargs,
+                ),
+            )
+
         return (
             super().forward(
                 sample.contiguous(),
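The get_unet_sample_size change above makes the latent-size computation pipeline-aware: the UNet is compiled for the requested pixel size divided by the pipeline's scale factor (movq_scale_factor for Kandinsky, vae_scale_factor otherwise). A quick arithmetic sketch with illustrative numbers; a factor of 8 is typical for both, but the value is always read from the pipeline at compile time.

```python
img_height, img_width = 768, 768   # requested output resolution (illustrative)
scale_factor = 8                   # assumption: pipe.movq_scale_factor or pipe.vae_scale_factor
sample_size = (img_height // scale_factor, img_width // scale_factor)
assert sample_size == (96, 96)     # static latent resolution the UNet gets compiled for
```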

optimum/rbln/diffusers/pipelines/__init__.py
CHANGED
@@ -25,6 +25,11 @@ _import_structure = {
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
         "RBLNStableDiffusionXLControlNetPipeline",
     ],
+    "kandinsky2_2": [
+        "RBLNKandinskyV22InpaintCombinedPipeline",
+        "RBLNKandinskyV22InpaintPipeline",
+        "RBLNKandinskyV22PriorPipeline",
+    ],
     "stable_diffusion": [
         "RBLNStableDiffusionImg2ImgPipeline",
         "RBLNStableDiffusionPipeline",
@@ -49,6 +54,11 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLControlNetImg2ImgPipeline,
         RBLNStableDiffusionXLControlNetPipeline,
     )
+    from .kandinsky2_2 import (
+        RBLNKandinskyV22InpaintCombinedPipeline,
+        RBLNKandinskyV22InpaintPipeline,
+        RBLNKandinskyV22PriorPipeline,
+    )
     from .stable_diffusion import (
         RBLNStableDiffusionImg2ImgPipeline,
         RBLNStableDiffusionInpaintPipeline,

optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py
ADDED
@@ -0,0 +1,17 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline_kandinsky2_2_combined import RBLNKandinskyV22InpaintCombinedPipeline
+from .pipeline_kandinsky2_2_inpaint import RBLNKandinskyV22InpaintPipeline
+from .pipeline_kandinsky2_2_prior import RBLNKandinskyV22PriorPipeline

optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
ADDED
@@ -0,0 +1,83 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from diffusers import (
+    DDPMScheduler,
+    KandinskyV22InpaintCombinedPipeline,
+    PriorTransformer,
+    UnCLIPScheduler,
+    UNet2DConditionModel,
+    VQModel,
+)
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
+
+from ...modeling_diffusers import RBLNDiffusionMixin
+from .pipeline_kandinsky2_2_inpaint import RBLNKandinskyV22InpaintPipeline
+from .pipeline_kandinsky2_2_prior import RBLNKandinskyV22PriorPipeline
+
+
+class RBLNKandinskyV22InpaintCombinedPipeline(RBLNDiffusionMixin, KandinskyV22InpaintCombinedPipeline):
+    original_class = KandinskyV22InpaintCombinedPipeline
+    _connected_classes = {"prior_pipe": RBLNKandinskyV22PriorPipeline, "decoder_pipe": RBLNKandinskyV22InpaintPipeline}
+    _submodules = ["prior_pipe", "decoder_pipe"]
+    _prefix = {"prior_pipe": "prior_"}
+
+    def __init__(
+        self,
+        unet: "UNet2DConditionModel",
+        scheduler: "DDPMScheduler",
+        movq: "VQModel",
+        prior_prior: "PriorTransformer",
+        prior_image_encoder: "CLIPVisionModelWithProjection",
+        prior_text_encoder: "CLIPTextModelWithProjection",
+        prior_tokenizer: "CLIPTokenizer",
+        prior_scheduler: "UnCLIPScheduler",
+        prior_image_processor: "CLIPImageProcessor",
+    ):
+        RBLNDiffusionMixin.__init__(self)
+        super(KandinskyV22InpaintCombinedPipeline, self).__init__()
+
+        self.register_modules(
+            unet=unet,
+            scheduler=scheduler,
+            movq=movq,
+            prior_prior=prior_prior,
+            prior_image_encoder=prior_image_encoder,
+            prior_text_encoder=prior_text_encoder,
+            prior_tokenizer=prior_tokenizer,
+            prior_scheduler=prior_scheduler,
+            prior_image_processor=prior_image_processor,
+        )
+
+        self.prior_pipe = RBLNKandinskyV22PriorPipeline(
+            prior=prior_prior,
+            image_encoder=prior_image_encoder,
+            text_encoder=prior_text_encoder,
+            tokenizer=prior_tokenizer,
+            scheduler=prior_scheduler,
+            image_processor=prior_image_processor,
+        )
+        self.decoder_pipe = RBLNKandinskyV22InpaintPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            movq=movq,
+        )
+
+    def get_compiled_image_size(self):
+        return self.movq.image_size

optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py
ADDED
@@ -0,0 +1,22 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from diffusers import KandinskyV22InpaintPipeline
+
+from ...modeling_diffusers import RBLNDiffusionMixin
+
+
+class RBLNKandinskyV22InpaintPipeline(RBLNDiffusionMixin, KandinskyV22InpaintPipeline):
+    original_class = KandinskyV22InpaintPipeline
+    _submodules = ["unet", "movq"]

optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
ADDED
@@ -0,0 +1,22 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from diffusers import KandinskyV22PriorPipeline
+
+from ...modeling_diffusers import RBLNDiffusionMixin
+
+
+class RBLNKandinskyV22PriorPipeline(RBLNDiffusionMixin, KandinskyV22PriorPipeline):
+    original_class = KandinskyV22PriorPipeline
+    _submodules = ["text_encoder", "image_encoder", "prior"]
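The three pipeline classes above only declare their RBLN submodules; compilation and inference go through the shared RBLNDiffusionMixin machinery. Below is a minimal usage sketch, assuming the Kandinsky pipelines follow the same from_pretrained(..., export=True) convention as the existing RBLN Stable Diffusion pipelines and are re-exported from optimum.rbln; the model id, image size, and rbln_* option names are illustrative assumptions, not taken from this diff.

```python
from PIL import Image

from optimum.rbln import RBLNKandinskyV22InpaintCombinedPipeline  # assumes a top-level re-export

# Compile the prior and decoder submodules for RBLN NPUs on first use (export=True),
# then persist the compiled pipeline so later runs can skip compilation.
pipe = RBLNKandinskyV22InpaintCombinedPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder-inpaint",  # illustrative model id
    export=True,
    rbln_img_height=512,  # assumed option names, mirroring the SD pipelines
    rbln_img_width=512,
)
pipe.save_pretrained("kandinsky-2-2-inpaint-rbln")

image = Image.open("photo.png").convert("RGB")
mask = Image.open("mask.png").convert("L")
result = pipe(prompt="a red brick wall", image=image, mask_image=mask).images[0]
result.save("inpainted.png")
```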
optimum/rbln/modeling_base.py
CHANGED
@@ -442,8 +442,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
-
-
+        # Normalize paths to handle relative paths and symlinks
+        real_save_dir = Path(self.model_save_dir).resolve() / self.subfolder
+        save_directory_path = Path(save_directory).resolve()
 
         if not os.path.exists(real_save_dir) or not os.path.isdir(real_save_dir):
             raise FileNotFoundError(
@@ -452,13 +453,13 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                 f"Please ensure the model directory exists and you have the necessary permissions to access it."
             )
 
-        if save_directory_path
+        if save_directory_path == real_save_dir:
             raise FileExistsError(
                 f"Cannot save model to '{save_directory}'. This directory already exists and contains the model files."
             )
 
-        # Create a temporary directory
-        tmp_dir =
+        # Create a temporary directory with normalized path
+        tmp_dir = str(save_directory_path) + ".tmp"
         try:
             # Remove temporary directory if it exists from a previous failed attempt
             if os.path.exists(tmp_dir):
@@ -473,9 +474,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                 self.generation_config.save_pretrained(tmp_dir)
 
             # If everything succeeded, atomically replace the target directory
-            if os.path.exists(
-                shutil.rmtree(
-            os.rename(tmp_dir,
+            if os.path.exists(save_directory_path):
+                shutil.rmtree(save_directory_path)
+            os.rename(tmp_dir, save_directory_path)
 
         except Exception as e:
             # Clean up the temporary directory if anything fails
@@ -484,7 +485,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             raise e  # Re-raise the exception after cleanup
 
         if push_to_hub:
-            return super().push_to_hub(
+            return super().push_to_hub(str(save_directory_path), **kwargs)
 
     @staticmethod
     def _raise_missing_compiled_file_error(missing_files: List[str]):
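With the normalization above, save_pretrained resolves both paths before comparing them and stages everything in a sibling ".tmp" directory that is only renamed over the target once every artifact has been written, so a failed save cannot leave a half-written directory behind. A small sketch of the resulting flow; the model class and paths are illustrative.

```python
from optimum.rbln import RBLNBertModel  # any RBLNBaseModel subclass behaves the same way

# Load an already-compiled model from disk, then save a copy elsewhere (paths illustrative).
model = RBLNBertModel.from_pretrained("./bert-base-rbln", export=False)
model.save_pretrained("./bert-base-rbln-copy")
# Internally the copy is first written to "./bert-base-rbln-copy.tmp" and only renamed over
# "./bert-base-rbln-copy" after all files were saved; saving back onto the directory the
# model was loaded from still raises FileExistsError, now also through symlinks or "./" paths.
```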

optimum/rbln/transformers/__init__.py
CHANGED
@@ -40,6 +40,7 @@ _import_structure = {
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
         "RBLNCLIPVisionModel",
+        "RBLNCLIPVisionModelWithProjection",
         "RBLNDPTForDepthEstimation",
         "RBLNExaoneForCausalLM",
         "RBLNGemmaForCausalLM",
@@ -99,6 +100,7 @@ if TYPE_CHECKING:
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
         RBLNCLIPVisionModel,
+        RBLNCLIPVisionModelWithProjection,
         RBLNDPTForDepthEstimation,
         RBLNExaoneForCausalLM,
         RBLNGemmaForCausalLM,

optimum/rbln/transformers/models/__init__.py
CHANGED
@@ -34,7 +34,12 @@ _import_structure = {
     ],
     "bart": ["RBLNBartForConditionalGeneration", "RBLNBartModel"],
     "bert": ["RBLNBertModel", "RBLNBertForQuestionAnswering", "RBLNBertForMaskedLM"],
-    "clip": [
+    "clip": [
+        "RBLNCLIPTextModel",
+        "RBLNCLIPTextModelWithProjection",
+        "RBLNCLIPVisionModel",
+        "RBLNCLIPVisionModelWithProjection",
+    ],
     "dpt": ["RBLNDPTForDepthEstimation"],
     "exaone": ["RBLNExaoneForCausalLM"],
     "gemma": ["RBLNGemmaForCausalLM"],
@@ -68,7 +73,12 @@ if TYPE_CHECKING:
     )
     from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
     from .bert import RBLNBertForMaskedLM, RBLNBertForQuestionAnswering, RBLNBertModel
-    from .clip import
+    from .clip import (
+        RBLNCLIPTextModel,
+        RBLNCLIPTextModelWithProjection,
+        RBLNCLIPVisionModel,
+        RBLNCLIPVisionModelWithProjection,
+    )
     from .dpt import RBLNDPTForDepthEstimation
     from .exaone import RBLNExaoneForCausalLM
     from .gemma import RBLNGemmaForCausalLM

optimum/rbln/transformers/models/clip/__init__.py
CHANGED
@@ -12,4 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .modeling_clip import
+from .modeling_clip import (
+    RBLNCLIPTextModel,
+    RBLNCLIPTextModelWithProjection,
+    RBLNCLIPVisionModel,
+    RBLNCLIPVisionModelWithProjection,
+)

optimum/rbln/transformers/models/clip/modeling_clip.py
CHANGED
@@ -22,7 +22,7 @@ from transformers import (
     CLIPVisionModel,
 )
 from transformers.modeling_outputs import BaseModelOutputWithPooling
-from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput, CLIPVisionModelOutput
 
 from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
 from ....modeling import RBLNModel
@@ -116,6 +116,10 @@ class RBLNCLIPVisionModel(RBLNModel):
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         return _VisionEncoder(model).eval()
 
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        return rbln_config
+
     @classmethod
     def _get_rbln_config(
         cls,
@@ -179,3 +183,24 @@ class RBLNCLIPVisionModel(RBLNModel):
             pooler_output=output[1],
             hidden_states=output[2:],
         )
+
+
+class RBLNCLIPVisionModelWithProjection(RBLNCLIPVisionModel):
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, CLIPVisionModelOutput]:
+        if len(kwargs) > 0 and any(kwargs.values()):
+            logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")
+
+        output = super().forward(pixel_values)
+        image_embeds = output[0]
+        last_hidden_state = output[1]
+        hidden_states = output[2:]
+
+        return CLIPVisionModelOutput(
+            image_embeds=image_embeds,
+            last_hidden_state=last_hidden_state,
+            hidden_states=hidden_states,
+        )

optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
CHANGED
@@ -427,12 +427,14 @@ class DecoderOnlyModel(nn.Module):
             cos, sin = None, None
 
         # (batch, seq_len) -> (batch,)
-        seq_positions = cache_position[:, 0]
         if self.attn_impl == "flash_attn":
+            seq_positions = cache_position[:, 0]
             max_seq_len = past_key_values[0][0].shape[-2]
             seq_positions = self.convert_sequence_positions_for_flash_attn(
                 seq_positions=seq_positions, max_seq_len=max_seq_len
             )
+        else:
+            seq_positions = cache_position[:, :1]
 
         present_key_values = past_key_values
         for layer in self.layers:
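The seq_positions change above keeps the two attention paths fed with different ranks: the flash-attention branch still gets a 1-D (batch,) tensor, while the default branch now receives a (batch, 1) slice. A small shape sketch with illustrative cache positions:

```python
import torch

cache_position = torch.tensor([[5], [9]])       # (batch=2, 1) during a decode step

flash_attn_positions = cache_position[:, 0]     # shape (2,)   -> flash_attn branch
default_positions = cache_position[:, :1]       # shape (2, 1) -> default branch

assert flash_attn_positions.shape == (2,)
assert default_positions.shape == (2, 1)
```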

optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
CHANGED
@@ -459,7 +459,7 @@ class Seq2SeqSelfAttention(nn.Module):
             ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
             past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
             past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
-            cache_position
+            cache_position,
             torch.tensor(1.0, dtype=torch.float32),  # scale
         )