optimum-rbln 0.2.1a4__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. optimum/rbln/__init__.py +14 -2
  2. optimum/rbln/__version__.py +9 -4
  3. optimum/rbln/diffusers/__init__.py +10 -0
  4. optimum/rbln/diffusers/modeling_diffusers.py +132 -25
  5. optimum/rbln/diffusers/models/__init__.py +7 -1
  6. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +52 -2
  8. optimum/rbln/diffusers/models/autoencoders/vq_model.py +159 -0
  9. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  10. optimum/rbln/diffusers/models/transformers/prior_transformer.py +174 -0
  11. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +57 -14
  12. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  13. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +17 -0
  14. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +83 -0
  15. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +22 -0
  16. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +22 -0
  17. optimum/rbln/modeling_base.py +10 -9
  18. optimum/rbln/transformers/__init__.py +2 -0
  19. optimum/rbln/transformers/models/__init__.py +12 -2
  20. optimum/rbln/transformers/models/clip/__init__.py +6 -1
  21. optimum/rbln/transformers/models/clip/modeling_clip.py +26 -1
  22. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +3 -1
  23. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +1 -1
  24. optimum/rbln/utils/import_utils.py +7 -0
  25. {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/METADATA +1 -1
  26. {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/RECORD +28 -22
  27. {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/WHEEL +0 -0
  28. {optimum_rbln-0.2.1a4.dist-info → optimum_rbln-0.7.2.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/diffusers/models/transformers/prior_transformer.py
@@ -0,0 +1,174 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from pathlib import Path
+ from typing import Any, Dict, Optional, Union
+
+ import torch
+ from diffusers.models.transformers.prior_transformer import PriorTransformer, PriorTransformerOutput
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ from ....modeling import RBLNModel
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
+ from ....utils.logging import get_logger
+ from ....utils.runtime_utils import RBLNPytorchRuntime
+ from ...modeling_diffusers import RBLNDiffusionMixin
+
+
+ logger = get_logger(__name__)
+
+
+ class RBLNRuntimePriorTransformer(RBLNPytorchRuntime):
+     def forward(
+         self, hidden_states, timestep, proj_embedding, encoder_hidden_states, attention_mask, return_dict: bool = True
+     ):
+         predicted_image_embedding = super().forward(
+             hidden_states,
+             timestep,
+             proj_embedding,
+             encoder_hidden_states,
+             attention_mask,
+         )
+         if return_dict:
+             return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
+         else:
+             return (predicted_image_embedding,)
+
+
+ class _PriorTransformer(torch.nn.Module):
+     def __init__(self, prior: PriorTransformer):
+         super().__init__()
+         self._prior = prior
+
+     def forward(
+         self,
+         hidden_states,
+         timestep,
+         proj_embedding,
+         encoder_hidden_states,
+         attention_mask,
+         return_dict=True,
+     ):
+         return self._prior.forward(
+             hidden_states,
+             timestep,
+             proj_embedding,
+             encoder_hidden_states,
+             attention_mask,
+             return_dict=False,
+         )
+
+
+ class RBLNPriorTransformer(RBLNModel):
+     hf_library_name = "diffusers"
+     auto_model_class = PriorTransformer
+
+     def __post_init__(self, **kwargs):
+         super().__post_init__(**kwargs)
+         self.runtime = RBLNRuntimePriorTransformer(runtime=self.model[0])
+         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+         self.clip_mean = artifacts["clip_mean"]
+         self.clip_std = artifacts["clip_std"]
+
+     @classmethod
+     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+         return _PriorTransformer(model).eval()
+
+     @classmethod
+     def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+         batch_size = rbln_config.get("batch_size")
+         if not batch_size:
+             do_classifier_free_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0
+             batch_size = 2 if do_classifier_free_guidance else 1
+         else:
+             if rbln_config.get("guidance_scale"):
+                 logger.warning(
+                     "guidance_scale is ignored because batch size is explicitly specified. "
+                     "To ensure consistent behavior, consider removing the guidance scale or "
+                     "adjusting the batch size configuration as needed."
+                 )
+         embedding_dim = rbln_config.get("embedding_dim", pipe.prior.config.embedding_dim)
+         num_embeddings = rbln_config.get("num_embeddings", pipe.prior.config.num_embeddings)
+
+         rbln_config.update(
+             {
+                 "batch_size": batch_size,
+                 "embedding_dim": embedding_dim,
+                 "num_embeddings": num_embeddings,
+             }
+         )
+
+         return rbln_config
+
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: "PreTrainedModel",
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNConfig,
+     ):
+         save_dict = {}
+         save_dict["clip_mean"] = model.clip_mean
+         save_dict["clip_std"] = model.clip_std
+         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+     @classmethod
+     def _get_rbln_config(
+         cls,
+         preprocessors,
+         model_config: PretrainedConfig,
+         rbln_kwargs,
+     ) -> RBLNConfig:
+         batch_size = rbln_kwargs.get("batch_size") or 1
+         embedding_dim = rbln_kwargs.get("embedding_dim") or model_config.embedding_dim
+         num_embeddings = rbln_kwargs.get("num_embeddings") or model_config.num_embeddings
+
+         input_info = [
+             ("hidden_states", [batch_size, embedding_dim], "float32"),
+             ("timestep", [], "float32"),
+             ("proj_embedding", [batch_size, embedding_dim], "float32"),
+             ("encoder_hidden_states", [batch_size, num_embeddings, embedding_dim], "float32"),
+             ("attention_mask", [batch_size, num_embeddings], "float32"),
+         ]
+
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config = RBLNConfig(
+             rbln_cls=cls.__name__,
+             compile_cfgs=[rbln_compile_config],
+             rbln_kwargs=rbln_kwargs,
+         )
+         return rbln_config
+
+     def forward(
+         self,
+         hidden_states,
+         timestep: Union[torch.Tensor, float, int],
+         proj_embedding: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.BoolTensor] = None,
+         return_dict: bool = True,
+     ):
+         return self.runtime.forward(
+             hidden_states.contiguous(),
+             timestep.float(),
+             proj_embedding,
+             encoder_hidden_states,
+             attention_mask.float(),
+             return_dict,
+         )
+
+     def post_process_latents(self, prior_latents):
+         prior_latents = (prior_latents * self.clip_std) + self.clip_mean
+         return prior_latents
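
Note: the compiled prior graph is traced with fully static shapes, so every tensor listed in input_info must be provided with exactly those dimensions at runtime. Below is a minimal sketch of the expected inputs, assuming batch_size=2 (classifier-free guidance) and illustrative values embedding_dim=1280 and num_embeddings=77; these numbers are examples, not values taken from this diff.

import torch

# Illustrative static shapes mirroring the input_info list above.
batch_size, embedding_dim, num_embeddings = 2, 1280, 77

dummy_inputs = {
    "hidden_states": torch.randn(batch_size, embedding_dim),
    "timestep": torch.tensor(999.0),  # scalar, shape []
    "proj_embedding": torch.randn(batch_size, embedding_dim),
    "encoder_hidden_states": torch.randn(batch_size, num_embeddings, embedding_dim),
    "attention_mask": torch.ones(batch_size, num_embeddings),  # passed as float32, not bool
}
for name, tensor in dummy_inputs.items():
    print(name, tuple(tensor.shape), tensor.dtype)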

optimum/rbln/diffusers/models/unets/unet_2d_condition.py
@@ -115,6 +115,29 @@ class _UNet_SDXL(torch.nn.Module):
          return unet_out


+ class _UNet_Kandinsky(torch.nn.Module):
+     def __init__(self, unet: "UNet2DConditionModel"):
+         super().__init__()
+         self.unet = unet
+
+     def forward(
+         self,
+         sample: torch.Tensor,
+         timestep: Union[torch.Tensor, float, int],
+         image_embeds: torch.Tensor,
+     ) -> torch.Tensor:
+         added_cond_kwargs = {"image_embeds": image_embeds}
+
+         unet_out = self.unet(
+             sample=sample,
+             timestep=timestep,
+             encoder_hidden_states=None,
+             added_cond_kwargs=added_cond_kwargs,
+             return_dict=False,
+         )
+         return unet_out
+
+
  class RBLNUNet2DConditionModel(RBLNModel):
      hf_library_name = "diffusers"
      auto_model_class = UNet2DConditionModel
@@ -138,6 +161,8 @@ class RBLNUNet2DConditionModel(RBLNModel):
      def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
          if model.config.addition_embed_type == "text_time":
              return _UNet_SDXL(model).eval()
+         elif model.config.addition_embed_type == "image":
+             return _UNet_Kandinsky(model).eval()
          else:
              return _UNet_SD(model).eval()

@@ -146,6 +171,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
          cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]
      ) -> Union[int, Tuple[int, int]]:
          image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
+         scale_factor = pipe.movq_scale_factor if hasattr(pipe, "movq_scale_factor") else pipe.vae_scale_factor
          if (image_size[0] is None) != (image_size[1] is None):
              raise ValueError("Both image height and image width must be given or not given")
          elif image_size[0] is None and image_size[1] is None:
@@ -153,22 +179,23 @@ class RBLNUNet2DConditionModel(RBLNModel):
                  # In case of img2img, sample size of unet is determined by vae encoder.
                  vae_sample_size = pipe.vae.config.sample_size
                  if isinstance(vae_sample_size, int):
-                     sample_size = vae_sample_size // pipe.vae_scale_factor
+                     sample_size = vae_sample_size // scale_factor
                  else:
                      sample_size = (
-                         vae_sample_size[0] // pipe.vae_scale_factor,
-                         vae_sample_size[1] // pipe.vae_scale_factor,
+                         vae_sample_size[0] // scale_factor,
+                         vae_sample_size[1] // scale_factor,
                      )
              else:
                  sample_size = pipe.unet.config.sample_size
          else:
-             sample_size = (image_size[0] // pipe.vae_scale_factor, image_size[1] // pipe.vae_scale_factor)
+             sample_size = (image_size[0] // scale_factor, image_size[1] // scale_factor)

          return sample_size

      @classmethod
      def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
          text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
+         image_model_hidden_size = pipe.unet.config.encoder_hid_dim if hasattr(pipe, "unet") else None

          batch_size = rbln_config.get("batch_size")
          if not batch_size:
@@ -184,10 +211,12 @@ class RBLNUNet2DConditionModel(RBLNModel):
                      "adjusting the batch size configuration as needed."
                  )

+         max_seq_len = pipe.text_encoder.config.max_position_embeddings if hasattr(pipe, "text_encoder") else None
          rbln_config.update(
              {
-                 "max_seq_len": pipe.text_encoder.config.max_position_embeddings,
+                 "max_seq_len": max_seq_len,
                  "text_model_hidden_size": text_model_hidden_size,
+                 "image_model_hidden_size": image_model_hidden_size,
                  "sample_size": cls.get_unet_sample_size(pipe, rbln_config),
                  "batch_size": batch_size,
                  "is_controlnet": "controlnet" in pipe.config.keys(),
@@ -218,15 +247,16 @@ class RBLNUNet2DConditionModel(RBLNModel):
          if isinstance(sample_size, int):
              sample_size = (sample_size, sample_size)

-         if max_seq_len is None:
-             raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings) must be specified.")
-
          input_info = [
              ("sample", [batch_size, model_config.in_channels, sample_size[0], sample_size[1]], "float32"),
              ("timestep", [], "float32"),
-             ("encoder_hidden_states", [batch_size, max_seq_len, model_config.cross_attention_dim], "float32"),
          ]

+         if max_seq_len is not None:
+             input_info.append(
+                 ("encoder_hidden_states", [batch_size, max_seq_len, model_config.cross_attention_dim], "float32"),
+             )
+
          if is_controlnet:
              # down block addtional residuals
              first_shape = [batch_size, model_config.block_out_channels[0], sample_size[0], sample_size[1]]
@@ -256,11 +286,15 @@ class RBLNUNet2DConditionModel(RBLNModel):
              ]
              input_info.append(("mid_block_additional_residual", shape, "float32"))

-         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-             rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
-             rbln_in_features = model_config.projection_class_embeddings_input_dim
-             input_info.append(("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32"))
-             input_info.append(("time_ids", [batch_size, 6], "float32"))
+         if hasattr(model_config, "addition_embed_type"):
+             if model_config.addition_embed_type == "text_time":
+                 rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
+                 rbln_in_features = model_config.projection_class_embeddings_input_dim
+                 input_info.append(("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32"))
+                 input_info.append(("time_ids", [batch_size, 6], "float32"))
+             elif model_config.addition_embed_type == "image":
+                 rbln_image_model_hidden_size = rbln_kwargs["image_model_hidden_size"]
+                 input_info.append(("image_embeds", [batch_size, rbln_image_model_hidden_size], "float32"))

          rbln_compile_config = RBLNCompileConfig(input_info=input_info)

@@ -323,6 +357,15 @@ class RBLNUNet2DConditionModel(RBLNModel):
                  ),
              )

+         if "image_embeds" in added_cond_kwargs:
+             return (
+                 super().forward(
+                     sample.contiguous(),
+                     timestep.float(),
+                     **added_cond_kwargs,
+                 ),
+             )
+
          return (
              super().forward(
                  sample.contiguous(),
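
Note: the UNet's latent sample size is derived by integer-dividing the requested image size by the pipeline's latent scale factor (movq_scale_factor when the pipeline has a MoVQ, vae_scale_factor otherwise). A quick sanity check of that arithmetic, assuming the common scale factor of 8; the real value comes from the pipeline's VAE/MoVQ config.

# Sketch of the sample-size derivation performed by get_unet_sample_size above.
def unet_sample_size(img_height: int, img_width: int, scale_factor: int = 8):
    return (img_height // scale_factor, img_width // scale_factor)

print(unet_sample_size(512, 512))  # (64, 64): latent resolution baked into the compiled UNet
print(unet_sample_size(768, 512))  # (96, 64)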

optimum/rbln/diffusers/pipelines/__init__.py
@@ -25,6 +25,11 @@ _import_structure = {
          "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
          "RBLNStableDiffusionXLControlNetPipeline",
      ],
+     "kandinsky2_2": [
+         "RBLNKandinskyV22InpaintCombinedPipeline",
+         "RBLNKandinskyV22InpaintPipeline",
+         "RBLNKandinskyV22PriorPipeline",
+     ],
      "stable_diffusion": [
          "RBLNStableDiffusionImg2ImgPipeline",
          "RBLNStableDiffusionPipeline",
@@ -49,6 +54,11 @@ if TYPE_CHECKING:
          RBLNStableDiffusionXLControlNetImg2ImgPipeline,
          RBLNStableDiffusionXLControlNetPipeline,
      )
+     from .kandinsky2_2 import (
+         RBLNKandinskyV22InpaintCombinedPipeline,
+         RBLNKandinskyV22InpaintPipeline,
+         RBLNKandinskyV22PriorPipeline,
+     )
      from .stable_diffusion import (
          RBLNStableDiffusionImg2ImgPipeline,
          RBLNStableDiffusionInpaintPipeline,

optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py
@@ -0,0 +1,17 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .pipeline_kandinsky2_2_combined import RBLNKandinskyV22InpaintCombinedPipeline
+ from .pipeline_kandinsky2_2_inpaint import RBLNKandinskyV22InpaintPipeline
+ from .pipeline_kandinsky2_2_prior import RBLNKandinskyV22PriorPipeline

optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
@@ -0,0 +1,83 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from diffusers import (
+     DDPMScheduler,
+     KandinskyV22InpaintCombinedPipeline,
+     PriorTransformer,
+     UnCLIPScheduler,
+     UNet2DConditionModel,
+     VQModel,
+ )
+ from transformers import (
+     CLIPImageProcessor,
+     CLIPTextModelWithProjection,
+     CLIPTokenizer,
+     CLIPVisionModelWithProjection,
+ )
+
+ from ...modeling_diffusers import RBLNDiffusionMixin
+ from .pipeline_kandinsky2_2_inpaint import RBLNKandinskyV22InpaintPipeline
+ from .pipeline_kandinsky2_2_prior import RBLNKandinskyV22PriorPipeline
+
+
+ class RBLNKandinskyV22InpaintCombinedPipeline(RBLNDiffusionMixin, KandinskyV22InpaintCombinedPipeline):
+     original_class = KandinskyV22InpaintCombinedPipeline
+     _connected_classes = {"prior_pipe": RBLNKandinskyV22PriorPipeline, "decoder_pipe": RBLNKandinskyV22InpaintPipeline}
+     _submodules = ["prior_pipe", "decoder_pipe"]
+     _prefix = {"prior_pipe": "prior_"}
+
+     def __init__(
+         self,
+         unet: "UNet2DConditionModel",
+         scheduler: "DDPMScheduler",
+         movq: "VQModel",
+         prior_prior: "PriorTransformer",
+         prior_image_encoder: "CLIPVisionModelWithProjection",
+         prior_text_encoder: "CLIPTextModelWithProjection",
+         prior_tokenizer: "CLIPTokenizer",
+         prior_scheduler: "UnCLIPScheduler",
+         prior_image_processor: "CLIPImageProcessor",
+     ):
+         RBLNDiffusionMixin.__init__(self)
+         super(KandinskyV22InpaintCombinedPipeline, self).__init__()
+
+         self.register_modules(
+             unet=unet,
+             scheduler=scheduler,
+             movq=movq,
+             prior_prior=prior_prior,
+             prior_image_encoder=prior_image_encoder,
+             prior_text_encoder=prior_text_encoder,
+             prior_tokenizer=prior_tokenizer,
+             prior_scheduler=prior_scheduler,
+             prior_image_processor=prior_image_processor,
+         )
+
+         self.prior_pipe = RBLNKandinskyV22PriorPipeline(
+             prior=prior_prior,
+             image_encoder=prior_image_encoder,
+             text_encoder=prior_text_encoder,
+             tokenizer=prior_tokenizer,
+             scheduler=prior_scheduler,
+             image_processor=prior_image_processor,
+         )
+         self.decoder_pipe = RBLNKandinskyV22InpaintPipeline(
+             unet=unet,
+             scheduler=scheduler,
+             movq=movq,
+         )
+
+     def get_compiled_image_size(self):
+         return self.movq.image_size
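
Note: for context, a hypothetical compile-and-run sketch for the combined inpaint pipeline. The model id, export flag, and rbln_* keyword names follow optimum-rbln's usual conventions but are assumptions here, not taken from this diff; check the library documentation for the exact export API.

from PIL import Image

from optimum.rbln import RBLNKandinskyV22InpaintCombinedPipeline

# Hypothetical usage sketch -- keyword names and the model id are assumptions.
pipe = RBLNKandinskyV22InpaintCombinedPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder-inpaint",
    export=True,          # compile prior, UNet, and MoVQ for RBLN NPUs
    rbln_img_height=512,  # image size is fixed at compile time
    rbln_img_width=512,
)

init_image = Image.new("RGB", (512, 512), "white")  # placeholder input image
mask = Image.new("L", (512, 512), 255)              # white = region to repaint

image = pipe(
    prompt="a red fox sitting on a park bench",
    image=init_image,
    mask_image=mask,
).images[0]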

optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py
@@ -0,0 +1,22 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from diffusers import KandinskyV22InpaintPipeline
+
+ from ...modeling_diffusers import RBLNDiffusionMixin
+
+
+ class RBLNKandinskyV22InpaintPipeline(RBLNDiffusionMixin, KandinskyV22InpaintPipeline):
+     original_class = KandinskyV22InpaintPipeline
+     _submodules = ["unet", "movq"]

optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
@@ -0,0 +1,22 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from diffusers import KandinskyV22PriorPipeline
+
+ from ...modeling_diffusers import RBLNDiffusionMixin
+
+
+ class RBLNKandinskyV22PriorPipeline(RBLNDiffusionMixin, KandinskyV22PriorPipeline):
+     original_class = KandinskyV22PriorPipeline
+     _submodules = ["text_encoder", "image_encoder", "prior"]

optimum/rbln/modeling_base.py
@@ -442,8 +442,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
              logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
              return

-         real_save_dir = self.model_save_dir / self.subfolder
-         save_directory_path = Path(save_directory)
+         # Normalize paths to handle relative paths and symlinks
+         real_save_dir = Path(self.model_save_dir).resolve() / self.subfolder
+         save_directory_path = Path(save_directory).resolve()

          if not os.path.exists(real_save_dir) or not os.path.isdir(real_save_dir):
              raise FileNotFoundError(
@@ -452,13 +453,13 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                  f"Please ensure the model directory exists and you have the necessary permissions to access it."
              )

-         if save_directory_path.absolute() == real_save_dir.absolute():
+         if save_directory_path == real_save_dir:
              raise FileExistsError(
                  f"Cannot save model to '{save_directory}'. This directory already exists and contains the model files."
              )

-         # Create a temporary directory next to the target directory
-         tmp_dir = save_directory + ".tmp"
+         # Create a temporary directory with normalized path
+         tmp_dir = str(save_directory_path) + ".tmp"
          try:
              # Remove temporary directory if it exists from a previous failed attempt
              if os.path.exists(tmp_dir):
@@ -473,9 +474,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                  self.generation_config.save_pretrained(tmp_dir)

              # If everything succeeded, atomically replace the target directory
-             if os.path.exists(save_directory):
-                 shutil.rmtree(save_directory)
-             os.rename(tmp_dir, save_directory)
+             if os.path.exists(save_directory_path):
+                 shutil.rmtree(save_directory_path)
+             os.rename(tmp_dir, save_directory_path)

          except Exception as e:
              # Clean up the temporary directory if anything fails
@@ -484,7 +485,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
              raise e  # Re-raise the exception after cleanup

          if push_to_hub:
-             return super().push_to_hub(save_directory, **kwargs)
+             return super().push_to_hub(str(save_directory_path), **kwargs)

      @staticmethod
      def _raise_missing_compiled_file_error(missing_files: List[str]):
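
Note: switching from Path.absolute() to Path.resolve() makes the "saving onto itself" check robust to ".." segments and symlinks: absolute() only prepends the working directory, while resolve() also normalizes the path. A small self-contained sketch of the difference (directory names are illustrative):

import os
from pathlib import Path

os.makedirs("compiled_model", exist_ok=True)
os.makedirs("subdir", exist_ok=True)

target = Path("compiled_model")
alias = Path("subdir/../compiled_model")  # same directory, reached through ".."

print(target.absolute() == alias.absolute())  # False: the ".." segment is kept verbatim
print(target.resolve() == alias.resolve())    # True: ".." and symlinks are normalized away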

optimum/rbln/transformers/__init__.py
@@ -40,6 +40,7 @@ _import_structure = {
          "RBLNCLIPTextModel",
          "RBLNCLIPTextModelWithProjection",
          "RBLNCLIPVisionModel",
+         "RBLNCLIPVisionModelWithProjection",
          "RBLNDPTForDepthEstimation",
          "RBLNExaoneForCausalLM",
          "RBLNGemmaForCausalLM",
@@ -99,6 +100,7 @@ if TYPE_CHECKING:
          RBLNCLIPTextModel,
          RBLNCLIPTextModelWithProjection,
          RBLNCLIPVisionModel,
+         RBLNCLIPVisionModelWithProjection,
          RBLNDPTForDepthEstimation,
          RBLNExaoneForCausalLM,
          RBLNGemmaForCausalLM,

optimum/rbln/transformers/models/__init__.py
@@ -34,7 +34,12 @@ _import_structure = {
      ],
      "bart": ["RBLNBartForConditionalGeneration", "RBLNBartModel"],
      "bert": ["RBLNBertModel", "RBLNBertForQuestionAnswering", "RBLNBertForMaskedLM"],
-     "clip": ["RBLNCLIPTextModel", "RBLNCLIPTextModelWithProjection", "RBLNCLIPVisionModel"],
+     "clip": [
+         "RBLNCLIPTextModel",
+         "RBLNCLIPTextModelWithProjection",
+         "RBLNCLIPVisionModel",
+         "RBLNCLIPVisionModelWithProjection",
+     ],
      "dpt": ["RBLNDPTForDepthEstimation"],
      "exaone": ["RBLNExaoneForCausalLM"],
      "gemma": ["RBLNGemmaForCausalLM"],
@@ -68,7 +73,12 @@ if TYPE_CHECKING:
      )
      from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
      from .bert import RBLNBertForMaskedLM, RBLNBertForQuestionAnswering, RBLNBertModel
-     from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
+     from .clip import (
+         RBLNCLIPTextModel,
+         RBLNCLIPTextModelWithProjection,
+         RBLNCLIPVisionModel,
+         RBLNCLIPVisionModelWithProjection,
+     )
      from .dpt import RBLNDPTForDepthEstimation
      from .exaone import RBLNExaoneForCausalLM
      from .gemma import RBLNGemmaForCausalLM

optimum/rbln/transformers/models/clip/__init__.py
@@ -12,4 +12,9 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
+ from .modeling_clip import (
+     RBLNCLIPTextModel,
+     RBLNCLIPTextModelWithProjection,
+     RBLNCLIPVisionModel,
+     RBLNCLIPVisionModelWithProjection,
+ )

optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -22,7 +22,7 @@ from transformers import (
      CLIPVisionModel,
  )
  from transformers.modeling_outputs import BaseModelOutputWithPooling
- from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+ from transformers.models.clip.modeling_clip import CLIPTextModelOutput, CLIPVisionModelOutput

  from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
  from ....modeling import RBLNModel
@@ -116,6 +116,10 @@ class RBLNCLIPVisionModel(RBLNModel):
      def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
          return _VisionEncoder(model).eval()

+     @classmethod
+     def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+         return rbln_config
+
      @classmethod
      def _get_rbln_config(
          cls,
@@ -179,3 +183,24 @@
              pooler_output=output[1],
              hidden_states=output[2:],
          )
+
+
+ class RBLNCLIPVisionModelWithProjection(RBLNCLIPVisionModel):
+     def forward(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         **kwargs,
+     ) -> Union[Tuple, CLIPVisionModelOutput]:
+         if len(kwargs) > 0 and any(kwargs.values()):
+             logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")
+
+         output = super().forward(pixel_values)
+         image_embeds = output[0]
+         last_hidden_state = output[1]
+         hidden_states = output[2:]
+
+         return CLIPVisionModelOutput(
+             image_embeds=image_embeds,
+             last_hidden_state=last_hidden_state,
+             hidden_states=hidden_states,
+         )

optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -427,12 +427,14 @@ class DecoderOnlyModel(nn.Module):
              cos, sin = None, None

          # (batch, seq_len) -> (batch,)
-         seq_positions = cache_position[:, 0]
          if self.attn_impl == "flash_attn":
+             seq_positions = cache_position[:, 0]
              max_seq_len = past_key_values[0][0].shape[-2]
              seq_positions = self.convert_sequence_positions_for_flash_attn(
                  seq_positions=seq_positions, max_seq_len=max_seq_len
              )
+         else:
+             seq_positions = cache_position[:, :1]

          present_key_values = past_key_values
          for layer in self.layers:
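
Note: with this change the non-flash-attention path keeps a trailing length-1 axis on the cached sequence positions instead of dropping it, so downstream attention code receives a 2-D (batch, 1) tensor. A quick shape check of the two slicings (tensor contents are made up):

import torch

# cache_position holds the per-batch write offsets into the KV cache: (batch, seq_len)
cache_position = torch.tensor([[5, 6, 7], [9, 10, 11]])

print(cache_position[:, 0].shape)   # torch.Size([2])    -> flash-attention path (then converted)
print(cache_position[:, :1].shape)  # torch.Size([2, 1]) -> standard path, kept 2-D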

optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
@@ -459,7 +459,7 @@ class Seq2SeqSelfAttention(nn.Module):
              ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
              past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
              past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
-             cache_position.squeeze(1),
+             cache_position,
              torch.tensor(1.0, dtype=torch.float32),  # scale
          )