optimum-rbln 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. optimum/rbln/__init__.py +115 -0
  2. optimum/rbln/__version__.py +1 -0
  3. optimum/rbln/diffusers/__init__.py +64 -0
  4. optimum/rbln/diffusers/models/__init__.py +26 -0
  5. optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
  6. optimum/rbln/diffusers/models/controlnet.py +180 -0
  7. optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
  8. optimum/rbln/diffusers/pipelines/__init__.py +30 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
  10. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
  18. optimum/rbln/modeling.py +0 -0
  19. optimum/rbln/modeling_alias.py +49 -0
  20. optimum/rbln/modeling_base.py +645 -0
  21. optimum/rbln/modeling_config.py +169 -0
  22. optimum/rbln/modeling_seq2seq.py +469 -0
  23. optimum/rbln/transformers/__init__.py +59 -0
  24. optimum/rbln/transformers/generation/__init__.py +24 -0
  25. optimum/rbln/transformers/generation/streamers.py +122 -0
  26. optimum/rbln/transformers/models/__init__.py +28 -0
  27. optimum/rbln/transformers/models/bart/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +24 -0
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
  31. optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
  32. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
  33. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
  34. optimum/rbln/transformers/models/llama/__init__.py +24 -0
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
  37. optimum/rbln/transformers/models/t5/__init__.py +24 -0
  38. optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
  39. optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
  40. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
  41. optimum/rbln/transformers/models/whisper/__init__.py +24 -0
  42. optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
  43. optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
  44. optimum/rbln/utils/__init__.py +25 -0
  45. optimum/rbln/utils/import_utils.py +28 -0
  46. optimum/rbln/utils/runtime_utils.py +71 -0
  47. optimum/rbln/utils/save_utils.py +92 -0
  48. optimum_rbln-0.1.0.dist-info/METADATA +144 -0
  49. optimum_rbln-0.1.0.dist-info/RECORD +51 -0
  50. optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
  51. optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,180 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import logging
25
+ from pathlib import Path
26
+ from typing import TYPE_CHECKING, Optional, Union
27
+
28
+ import rebel
29
+ import torch
30
+ from diffusers import ControlNetModel
31
+ from optimum.exporters import TasksManager
32
+ from transformers import AutoConfig, AutoModel, PretrainedConfig
33
+
34
+ from ...modeling_base import RBLNModel
35
+ from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
36
+
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ if TYPE_CHECKING:
41
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
42
+
43
+
44
class _ControlNetModel(torch.nn.Module):
    """Tracing wrapper that gives a diffusers ControlNet a flat, positional-only call.

    ``forward`` invokes the wrapped ControlNet with ``return_dict=False``. It hands back
    the two items that call produces: the down-block residual samples and the
    mid-block residual sample.
    """

    def __init__(self, controlnet: "ControlNetModel"):
        super().__init__()
        self.controlnet = controlnet

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale,
    ):
        # Every argument is forwarded by keyword so the wrapper's positional order
        # can differ from the wrapped model's signature.
        call_kwargs = dict(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=controlnet_cond,
            conditioning_scale=conditioning_scale,
            return_dict=False,
        )
        down_residuals, mid_residual = self.controlnet(**call_kwargs)
        return down_residuals, mid_residual
66
+
67
+
68
class RBLNControlNetModel(RBLNModel):
    """Runs a diffusers ``ControlNetModel`` compiled with ``rebel.compile_from_torch``."""

    model_type = "rbln_model"
    auto_model_class = AutoModel  # feature extraction

    def __post_init__(self, **kwargs):
        self.dtype = torch.float32

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load (or export) a ControlNet through ``RBLNModel.from_pretrained``.

        The generic loader resolves configs and weights via ``AutoConfig``, ``AutoModel``
        and ``TasksManager``. While it runs, those entry points are temporarily
        redirected to the diffusers ``ControlNetModel`` loaders. They are always
        restored afterwards, including when loading raises, so the patch never leaks
        into unrelated callers in the same process.
        """

        def get_model_from_task(
            task: str,
            model_name_or_path: Union[str, Path],
            **kwargs,
        ):
            return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)

        tasktmp = TasksManager.get_model_from_task
        configtmp = AutoConfig.from_pretrained
        modeltmp = AutoModel.from_pretrained
        TasksManager.get_model_from_task = get_model_from_task
        AutoConfig.from_pretrained = ControlNetModel.load_config
        AutoModel.from_pretrained = ControlNetModel.from_pretrained
        try:
            return super().from_pretrained(*args, **kwargs)
        finally:
            # Undo the global monkey-patching even if loading/export failed.
            AutoConfig.from_pretrained = configtmp
            AutoModel.from_pretrained = modeltmp
            TasksManager.get_model_from_task = tasktmp

    @classmethod
    def compile(cls, model, rbln_runtime_config: Optional[RBLNRuntimeConfig] = None):
        """Compile ``model`` (wrapped in ``_ControlNetModel``) with the runtime config's input spec.

        ``rbln_runtime_config`` must be provided. Its ``input_info``, ``batch_size`` and
        ``fusion`` are passed straight to ``rebel.compile_from_torch``.
        """
        compiled_model = rebel.compile_from_torch(
            _ControlNetModel(model),
            input_info=rbln_runtime_config.input_info,
            batch_size=rbln_runtime_config.batch_size,
            fusion=rbln_runtime_config.fusion,
        )
        return compiled_model

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model_config: "PretrainedConfig",
        rbln_max_seq_len: Optional[int] = None,
        rbln_batch_size: Optional[int] = None,
        rbln_img_width: Optional[int] = None,
        rbln_img_height: Optional[int] = None,
        rbln_vae_scale_factor: Optional[int] = None,
    ) -> RBLNConfig:
        """Build the static input signature used to compile the ControlNet.

        Args:
            preprocessors: Unused; part of the shared ``_get_rbln_config`` signature.
            model_config: The ControlNet config (``in_channels`` and
                ``cross_attention_dim`` are read).
            rbln_max_seq_len: Text sequence length of ``encoder_hidden_states``
                (defaults to 77).
            rbln_batch_size: Compiled batch size (defaults to 1).
            rbln_img_width: Output image width in pixels. Required.
            rbln_img_height: Output image height in pixels. Required.
            rbln_vae_scale_factor: Pixel-to-latent downscale factor. Required.

        Returns:
            An ``RBLNConfig`` holding one runtime config. Spatial dimensions are ordered
            (height, width) to match the NCHW layout of PyTorch image tensors.

        Raises:
            ValueError: if the image size or VAE scale factor is missing.
        """
        meta = {"type": "controlnet"}

        if rbln_batch_size is None:
            rbln_batch_size = 1

        if rbln_max_seq_len is None:
            rbln_max_seq_len = 77

        if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
            raise ValueError(
                "`rbln_img_width`, `rbln_img_height` and `rbln_vae_scale_factor` are required "
                "to compile RBLNControlNetModel."
            )

        input_width = rbln_img_width // rbln_vae_scale_factor
        input_height = rbln_img_height // rbln_vae_scale_factor

        rbln_runtime_config = RBLNRuntimeConfig(
            input_info=[
                (
                    "sample",
                    [rbln_batch_size, model_config.in_channels, input_height, input_width],
                    "float32",
                ),
                ("timestep", [], "float32"),
                (
                    "encoder_hidden_states",
                    [rbln_batch_size, rbln_max_seq_len, model_config.cross_attention_dim],
                    "float32",
                ),
                # The conditioning image stays at full pixel resolution (3 channels).
                ("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32"),
                ("conditioning_scale", [], "float32"),
            ],
            batch_size=rbln_batch_size,
        )
        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
        return rbln_config

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        conditioning_scale: Union[float, torch.Tensor] = 1.0,
        **kwargs,
    ):
        """The [`ControlNetModel`] forward method.

        Extra keyword arguments are accepted for API compatibility and ignored.

        Returns:
            ``(down_block_res_samples, mid_block_res_sample)``. All outputs of the
            compiled model except the last form the down-block residuals; the last is
            the mid-block residual.
        """
        # Both scalars are compiled as float32; as_tensor also accepts Python numbers
        # (including ints) and existing tensors without a copy-construct warning.
        timestep = torch.as_tensor(timestep, dtype=torch.float32)
        conditioning_scale = torch.as_tensor(conditioning_scale, dtype=torch.float32)

        output = super().forward(
            sample.contiguous(),
            timestep,
            encoder_hidden_states,
            controlnet_cond,
            conditioning_scale,
        )
        down_block_res_samples = output[:-1]
        mid_block_res_sample = output[-1]

        return down_block_res_samples, mid_block_res_sample
@@ -0,0 +1,352 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import logging
25
+ from dataclasses import dataclass
26
+ from pathlib import Path
27
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
28
+
29
+ import torch
30
+ from diffusers.models.unet_2d_condition import UNet2DConditionModel
31
+ from optimum.exporters import TasksManager
32
+ from transformers import AutoConfig, AutoModel, PretrainedConfig
33
+
34
+ from ...modeling_base import RBLNModel
35
+ from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
36
+
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ if TYPE_CHECKING:
41
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
42
+
43
+
44
class _UNet_SD(torch.nn.Module):
    """Tracing wrapper exposing a Stable Diffusion UNet through a flat positional call.

    Optional ControlNet residuals arrive as trailing positional tensors. All but the
    last are forwarded as ``down_block_additional_residuals`` and the last one as
    ``mid_block_additional_residual``. ``text_embeds``/``time_ids`` are forwarded as
    ``added_cond_kwargs`` only when both are given. The wrapped UNet is called with
    ``return_dict=False`` and its result is returned unchanged.
    """

    def __init__(self, unet: "UNet2DConditionModel"):
        super().__init__()
        self.unet = unet

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        *down_and_mid_block_additional_residuals: Optional[Tuple[torch.Tensor]],
        text_embeds: Optional[torch.Tensor] = None,
        time_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        residuals = down_and_mid_block_additional_residuals
        # An empty tuple means "no ControlNet": pass None for both residual slots.
        down_residuals = residuals[:-1] if residuals else None
        mid_residual = residuals[-1] if residuals else None

        extra_conditions = {}
        if text_embeds is not None and time_ids is not None:
            extra_conditions = {"text_embeds": text_embeds, "time_ids": time_ids}

        return self.unet(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=down_residuals,
            mid_block_additional_residual=mid_residual,
            added_cond_kwargs=extra_conditions,
            return_dict=False,
        )
81
+
82
+
83
class _UNet_SDXL(torch.nn.Module):
    """Tracing wrapper exposing an SDXL UNet through a flat positional call.

    Positional order is ``sample, timestep, encoder_hidden_states, text_embeds,
    time_ids`` followed by optional ControlNet residuals. Of those residuals, all but
    the last become ``down_block_additional_residuals`` and the last becomes
    ``mid_block_additional_residual``. ``added_cond_kwargs`` is populated only when
    both ``text_embeds`` and ``time_ids`` are given. The wrapped UNet is called with
    ``return_dict=False`` and its result is returned unchanged.
    """

    def __init__(self, unet: "UNet2DConditionModel"):
        super().__init__()
        self.unet = unet

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        text_embeds: Optional[torch.Tensor] = None,
        time_ids: Optional[torch.Tensor] = None,
        *down_and_mid_block_additional_residuals: Optional[Tuple[torch.Tensor]],
    ) -> torch.Tensor:
        has_time_conditioning = text_embeds is not None and time_ids is not None
        extra_conditions = {"text_embeds": text_embeds, "time_ids": time_ids} if has_time_conditioning else {}

        residuals = down_and_mid_block_additional_residuals
        if residuals:
            down_residuals, mid_residual = residuals[:-1], residuals[-1]
        else:
            # No ControlNet residuals were supplied.
            down_residuals = mid_residual = None

        return self.unet(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=down_residuals,
            mid_block_additional_residual=mid_residual,
            added_cond_kwargs=extra_conditions,
            return_dict=False,
        )
120
+
121
+
122
class RBLNUNet2DConditionModel(RBLNModel):
    """Runs a diffusers ``UNet2DConditionModel`` compiled for RBLN.

    Covers Stable Diffusion UNets and SDXL UNets
    (``config.addition_embed_type == "text_time"``), optionally compiled with extra
    ControlNet residual inputs.
    """

    model_type = "rbln_model"
    auto_model_class = AutoModel  # feature extraction

    def __post_init__(self, **kwargs):
        self.dtype = torch.float32
        self.in_features = self.rbln_config.meta.get("in_features", None)
        if self.in_features is not None:
            # Mirror the attribute path ``add_embedding.linear_1.in_features`` of the
            # original UNet so code that reads it from the model keeps working.

            @dataclass
            class LINEAR1:
                in_features: int

            @dataclass
            class ADDEMBEDDING:
                linear_1: LINEAR1

            self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load (or export) a UNet through ``RBLNModel.from_pretrained``.

        ``AutoConfig``, ``AutoModel`` and ``TasksManager`` are temporarily redirected to
        the diffusers ``UNet2DConditionModel`` loaders. They are always restored
        afterwards, including when loading raises, so the patch never leaks into
        unrelated callers.
        """

        def get_model_from_task(
            task: str,
            model_name_or_path: Union[str, Path],
            **kwargs,
        ):
            return UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)

        tasktmp = TasksManager.get_model_from_task
        configtmp = AutoConfig.from_pretrained
        modeltmp = AutoModel.from_pretrained
        TasksManager.get_model_from_task = get_model_from_task
        if kwargs.get("export", None):
            # This is an ad-hoc to workaround save null values of the config.
            # if export, pure optimum(not optimum-rbln) loads config using AutoConfig
            # and diffusers model do not support loading by AutoConfig.
            AutoConfig.from_pretrained = lambda *args, **kwargs: None
        else:
            AutoConfig.from_pretrained = UNet2DConditionModel.load_config
        AutoModel.from_pretrained = UNet2DConditionModel.from_pretrained
        try:
            return super().from_pretrained(*args, **kwargs)
        finally:
            # Undo the global monkey-patching even if loading/export failed.
            AutoConfig.from_pretrained = configtmp
            AutoModel.from_pretrained = modeltmp
            TasksManager.get_model_from_task = tasktmp

    @classmethod
    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
        """Wrap ``model`` in the SDXL or SD tracing wrapper (in eval mode) based on its config."""
        if model.config.addition_embed_type == "text_time":
            return _UNet_SDXL(model).eval()
        else:
            return _UNet_SD(model).eval()

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model_config: "PretrainedConfig",
        rbln_max_seq_len: Optional[int] = None,
        rbln_text_model_hidden_size: Optional[int] = None,
        rbln_batch_size: Optional[int] = None,
        rbln_in_features: Optional[int] = None,
        rbln_use_encode: Optional[bool] = None,
        rbln_img_width: Optional[int] = None,
        rbln_img_height: Optional[int] = None,
        rbln_vae_scale_factor: Optional[int] = None,
        rbln_is_controlnet: Optional[bool] = None,
    ) -> RBLNConfig:
        """Build the static input signature used to compile the UNet.

        Args:
            preprocessors: Unused; part of the shared ``_get_rbln_config`` signature.
            model_config: The UNet config.
            rbln_max_seq_len: Text sequence length (defaults to 77).
            rbln_text_model_hidden_size: SDXL ``text_embeds`` width (defaults to 768).
            rbln_batch_size: Compiled batch size (defaults to 1).
            rbln_in_features: SDXL ``add_embedding.linear_1.in_features``
                (defaults to 2816); stored in the meta.
            rbln_use_encode: When truthy, the latent size is derived from the image size
                and VAE scale factor; otherwise ``model_config.sample_size`` is used.
            rbln_img_width: Image width in pixels, required with ``rbln_use_encode``.
            rbln_img_height: Image height in pixels, required with ``rbln_use_encode``.
            rbln_vae_scale_factor: Pixel-to-latent downscale factor, required with
                ``rbln_use_encode``.
            rbln_is_controlnet: Add the 12 down-block and 1 mid-block ControlNet
                residual inputs.

        Returns:
            An ``RBLNConfig`` holding one runtime config. Spatial dimensions are ordered
            (height, width) to match the NCHW layout of PyTorch image tensors.

        Raises:
            ValueError: if ``rbln_use_encode`` is set but the image size or VAE scale
                factor is missing.
        """
        meta = {"type": "unet"}
        if rbln_batch_size is None:
            rbln_batch_size = 1

        if rbln_max_seq_len is None:
            rbln_max_seq_len = 77

        meta["rbln_use_encode"] = rbln_use_encode

        if rbln_use_encode:
            if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
                raise ValueError(
                    "`rbln_img_width`, `rbln_img_height` and `rbln_vae_scale_factor` are required "
                    "when `rbln_use_encode` is set."
                )
            input_width = rbln_img_width // rbln_vae_scale_factor
            input_height = rbln_img_height // rbln_vae_scale_factor
        else:
            input_width, input_height = model_config.sample_size, model_config.sample_size

        input_info = [
            (
                "sample",
                [rbln_batch_size, model_config.in_channels, input_height, input_width],
                "float32",
            ),
            ("timestep", [], "float32"),
            (
                "encoder_hidden_states",
                [rbln_batch_size, rbln_max_seq_len, model_config.cross_attention_dim],
                "float32",
            ),
        ]
        if rbln_is_controlnet:
            channels = model_config.block_out_channels
            # (channel count, spatial downscale) for each down-block residual, in order.
            # This assumes a four-stage UNet with two resnets per stage and a
            # downsampler after each of the first three stages. Each downsampler keeps
            # the channel count of its own stage, so residuals 3, 6 and 9 use
            # channels[0], channels[1] and channels[2] respectively.
            residual_specs = [
                (channels[0], 1),
                (channels[0], 1),
                (channels[0], 1),
                (channels[0], 2),
                (channels[1], 2),
                (channels[1], 2),
                (channels[1], 4),
                (channels[2], 4),
                (channels[2], 4),
                (channels[2], 8),
                (channels[3], 8),
                (channels[3], 8),
            ]
            for i, (num_channels, downscale) in enumerate(residual_specs):
                input_info.append(
                    (
                        f"down_block_additional_residuals_{i}",
                        [rbln_batch_size, num_channels, input_height // downscale, input_width // downscale],
                        "float32",
                    )
                )
            input_info.append(
                (
                    "mid_block_additional_residual",
                    [rbln_batch_size, channels[3], input_height // 8, input_width // 8],
                    "float32",
                )
            )

        rbln_runtime_config = RBLNRuntimeConfig(
            input_info=input_info,
            batch_size=rbln_batch_size,
        )

        if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
            # In case of sdxl
            if rbln_text_model_hidden_size is None:
                rbln_text_model_hidden_size = 768
            if rbln_in_features is None:
                rbln_in_features = 2816
            meta["in_features"] = rbln_in_features
            rbln_runtime_config.input_info.append(
                ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
            )
            rbln_runtime_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))

        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
        return rbln_config

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """Run the compiled UNet.

        arg order : latent_model_input, t, prompt_embeds

        Only ``sample``, ``timestep``, ``encoder_hidden_states``, ``added_cond_kwargs``
        and the ControlNet residuals reach the compiled model. The other diffusers
        arguments (and ``return_dict``) are accepted for API compatibility and ignored.

        Returns:
            A one-element tuple containing the compiled model's output.
        """
        # ``None`` default instead of a shared mutable ``{}``.
        added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
        # The timestep input is compiled as a float32 scalar; as_tensor also accepts
        # plain Python numbers, which ``.float()`` did not.
        timestep = torch.as_tensor(timestep, dtype=torch.float32)

        if down_block_additional_residuals is not None:
            down_block_additional_residuals = [t.contiguous() for t in down_block_additional_residuals]
            return (
                super().forward(
                    sample.contiguous(),
                    timestep,
                    encoder_hidden_states,
                    *down_block_additional_residuals,
                    mid_block_additional_residual,
                    **added_cond_kwargs,
                ),
            )

        return (
            super().forward(
                sample.contiguous(),
                timestep,
                encoder_hidden_states,
                **added_cond_kwargs,
            ),
        )
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .controlnet import RBLNMultiControlNetModel
25
+ from .stable_diffusion import (
26
+ RBLNStableDiffusionControlNetImg2ImgPipeline,
27
+ RBLNStableDiffusionImg2ImgPipeline,
28
+ RBLNStableDiffusionPipeline,
29
+ )
30
+ from .stable_diffusion_xl import RBLNStableDiffusionXLImg2ImgPipeline, RBLNStableDiffusionXLPipeline
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .multicontrolnet import RBLNMultiControlNetModel