optimum-rbln 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. optimum/rbln/__init__.py +22 -12
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +44 -58
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +54 -14
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +78 -16
  13. optimum/rbln/diffusers/pipelines/__init__.py +22 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +5 -26
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -0
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -0
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -0
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -11
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +14 -6
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +14 -6
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  30. optimum/rbln/modeling.py +572 -0
  31. optimum/rbln/modeling_alias.py +1 -1
  32. optimum/rbln/modeling_base.py +164 -758
  33. optimum/rbln/modeling_diffusers.py +51 -122
  34. optimum/rbln/transformers/__init__.py +0 -2
  35. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  36. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  37. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  38. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  39. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  40. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -3
  41. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +672 -412
  42. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +38 -155
  43. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  44. optimum/rbln/transformers/models/exaone/exaone_architecture.py +61 -45
  45. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  46. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  47. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  48. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  49. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -75
  50. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -242
  51. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -6
  52. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -261
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  54. optimum/rbln/transformers/models/t5/modeling_t5.py +102 -4
  55. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  57. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  58. optimum/rbln/transformers/utils/rbln_quantization.py +120 -3
  59. optimum/rbln/utils/decorator_utils.py +10 -6
  60. optimum/rbln/utils/hub.py +131 -0
  61. optimum/rbln/utils/import_utils.py +15 -1
  62. optimum/rbln/utils/model_utils.py +53 -0
  63. optimum/rbln/utils/runtime_utils.py +1 -1
  64. optimum/rbln/utils/submodule.py +114 -0
  65. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  66. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/RECORD +69 -66
  67. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  68. optimum/rbln/transformers/generation/streamers.py +0 -139
  69. optimum/rbln/transformers/generation/utils.py +0 -397
  70. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  71. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  72. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  73. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  74. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  75. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  76. optimum/rbln/utils/context.py +0 -58
  77. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  78. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  79. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,203 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import logging
25
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
26
+
27
+ import torch
28
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
29
+ from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
30
+ from transformers import PretrainedConfig
31
+
32
+ from ....modeling import RBLNModel
33
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
34
+ from ....modeling_diffusers import RBLNDiffusionMixin
35
+
36
+
37
+ if TYPE_CHECKING:
38
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+
43
+ class SD3Transformer2DModelWrapper(torch.nn.Module):
44
+ def __init__(self, model: "SD3Transformer2DModel") -> None:
45
+ super().__init__()
46
+ self.model = model
47
+
48
+ def forward(
49
+ self,
50
+ hidden_states: torch.FloatTensor,
51
+ encoder_hidden_states: torch.FloatTensor = None,
52
+ pooled_projections: torch.FloatTensor = None,
53
+ timestep: torch.LongTensor = None,
54
+ # need controlnet support?
55
+ block_controlnet_hidden_states: List = None,
56
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
57
+ return_dict: bool = True,
58
+ ):
59
+ return self.model(
60
+ hidden_states=hidden_states,
61
+ encoder_hidden_states=encoder_hidden_states,
62
+ pooled_projections=pooled_projections,
63
+ timestep=timestep,
64
+ return_dict=False,
65
+ )
66
+
67
+
68
+ class RBLNSD3Transformer2DModel(RBLNModel):
69
+ hf_library_name = "diffusers"
70
+
71
+ def __post_init__(self, **kwargs):
72
+ super().__post_init__(**kwargs)
73
+
74
+ @classmethod
75
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
76
+ return SD3Transformer2DModelWrapper(model).eval()
77
+
78
+ @classmethod
79
+ def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
80
+ sample_size = rbln_config.get("sample_size", pipe.default_sample_size)
81
+ img_width = rbln_config.get("img_width")
82
+ img_height = rbln_config.get("img_height")
83
+
84
+ if (img_width is None) ^ (img_height is None):
85
+ raise RuntimeError
86
+
87
+ elif img_width and img_height:
88
+ sample_size = img_height // pipe.vae_scale_factor, img_width // pipe.vae_scale_factor
89
+
90
+ prompt_max_length = rbln_config.get("max_sequence_length", 256)
91
+ prompt_embed_length = pipe.tokenizer_max_length + prompt_max_length
92
+
93
+ batch_size = rbln_config.get("batch_size")
94
+ if not batch_size:
95
+ do_classifier_free_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0
96
+ batch_size = 2 if do_classifier_free_guidance else 1
97
+ else:
98
+ if rbln_config.get("guidance_scale"):
99
+ logger.warning(
100
+ "guidance_scale is ignored because batch size is explicitly specified. "
101
+ "To ensure consistent behavior, consider removing the guidance scale or "
102
+ "adjusting the batch size configuration as needed."
103
+ )
104
+
105
+ rbln_config.update(
106
+ {
107
+ "batch_size": batch_size,
108
+ "prompt_embed_length": prompt_embed_length,
109
+ "sample_size": sample_size,
110
+ }
111
+ )
112
+
113
+ return rbln_config
114
+
115
+ @classmethod
116
+ def _get_rbln_config(
117
+ cls,
118
+ preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
119
+ model_config: "PretrainedConfig",
120
+ rbln_kwargs: Dict[str, Any] = {},
121
+ ) -> RBLNConfig:
122
+ rbln_batch_size = rbln_kwargs.get("batch_size", None)
123
+
124
+ sample_size = rbln_kwargs.get("sample_size", model_config.sample_size)
125
+ if isinstance(sample_size, int):
126
+ sample_size = (sample_size, sample_size)
127
+
128
+ rbln_prompt_embed_length = rbln_kwargs.get("prompt_embed_length")
129
+ if rbln_prompt_embed_length is None:
130
+ raise ValueError("rbln_prompt_embed_length should be specified.")
131
+
132
+ input_info = [
133
+ (
134
+ "hidden_states",
135
+ [
136
+ rbln_batch_size,
137
+ model_config.in_channels,
138
+ sample_size[0],
139
+ sample_size[1],
140
+ ],
141
+ "float32",
142
+ ),
143
+ (
144
+ "encoder_hidden_states",
145
+ [
146
+ rbln_batch_size,
147
+ rbln_prompt_embed_length,
148
+ model_config.joint_attention_dim,
149
+ ],
150
+ "float32",
151
+ ),
152
+ (
153
+ "pooled_projections",
154
+ [
155
+ rbln_batch_size,
156
+ model_config.pooled_projection_dim,
157
+ ],
158
+ "float32",
159
+ ),
160
+ ("timestep", [rbln_batch_size], "float32"),
161
+ ]
162
+
163
+ rbln_compile_config = RBLNCompileConfig(input_info=input_info)
164
+
165
+ rbln_config = RBLNConfig(
166
+ rbln_cls=cls.__name__,
167
+ compile_cfgs=[rbln_compile_config],
168
+ rbln_kwargs=rbln_kwargs,
169
+ )
170
+
171
+ rbln_config.model_cfg.update({"batch_size": rbln_batch_size})
172
+
173
+ return rbln_config
174
+
175
+ @property
176
+ def compiled_batch_size(self):
177
+ return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
178
+
179
+ def forward(
180
+ self,
181
+ hidden_states: torch.FloatTensor,
182
+ encoder_hidden_states: torch.FloatTensor = None,
183
+ pooled_projections: torch.FloatTensor = None,
184
+ timestep: torch.LongTensor = None,
185
+ block_controlnet_hidden_states: List = None,
186
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
187
+ return_dict: bool = True,
188
+ **kwargs,
189
+ ):
190
+ sample_batch_size = hidden_states.size()[0]
191
+ compiled_batch_size = self.compiled_batch_size
192
+ if sample_batch_size != compiled_batch_size and (
193
+ sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
194
+ ):
195
+ raise ValueError(
196
+ f"Mismatch between Transformers' runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
197
+ "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
198
+ "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
199
+ "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
200
+ )
201
+
202
+ sample = super().forward(hidden_states, encoder_hidden_states, pooled_projections, timestep)
203
+ return Transformer2DModelOutput(sample=sample)
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .unet_2d_condition import RBLNUNet2DConditionModel
@@ -29,9 +29,9 @@ import torch
29
29
  from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
30
30
  from transformers import PretrainedConfig
31
31
 
32
- from ...modeling_base import RBLNModel
33
- from ...modeling_config import RBLNCompileConfig, RBLNConfig
34
- from ...utils.context import override_auto_classes
32
+ from ....modeling import RBLNModel
33
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
34
+ from ....modeling_diffusers import RBLNDiffusionMixin
35
35
 
36
36
 
37
37
  if TYPE_CHECKING:
@@ -125,6 +125,9 @@ class _UNet_SDXL(torch.nn.Module):
125
125
 
126
126
 
127
127
  class RBLNUNet2DConditionModel(RBLNModel):
128
+ hf_library_name = "diffusers"
129
+ auto_model_class = UNet2DConditionModel
130
+
128
131
  def __post_init__(self, **kwargs):
129
132
  super().__post_init__(**kwargs)
130
133
  self.in_features = self.rbln_config.model_cfg.get("in_features", None)
@@ -140,15 +143,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
140
143
 
141
144
  self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))
142
145
 
143
- @classmethod
144
- def from_pretrained(cls, *args, **kwargs):
145
- with override_auto_classes(
146
- config_func=UNet2DConditionModel.load_config,
147
- model_func=UNet2DConditionModel.from_pretrained,
148
- ):
149
- rt = super().from_pretrained(*args, **kwargs)
150
- return rt
151
-
152
146
  @classmethod
153
147
  def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
154
148
  if model.config.addition_embed_type == "text_time":
@@ -156,6 +150,61 @@ class RBLNUNet2DConditionModel(RBLNModel):
156
150
  else:
157
151
  return _UNet_SD(model).eval()
158
152
 
153
+ @classmethod
154
+ def get_unet_sample_size(
155
+ cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]
156
+ ) -> Union[int, Tuple[int, int]]:
157
+ image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
158
+ if (image_size[0] is None) != (image_size[1] is None):
159
+ raise ValueError("Both image height and image width must be given or not given")
160
+ elif image_size[0] is None and image_size[1] is None:
161
+ if rbln_config["img2img_pipeline"]:
162
+ # In case of img2img, sample size of unet is determined by vae encoder.
163
+ vae_sample_size = pipe.vae.config.sample_size
164
+ if isinstance(vae_sample_size, int):
165
+ sample_size = vae_sample_size // pipe.vae_scale_factor
166
+ else:
167
+ sample_size = (
168
+ vae_sample_size[0] // pipe.vae_scale_factor,
169
+ vae_sample_size[1] // pipe.vae_scale_factor,
170
+ )
171
+ else:
172
+ sample_size = pipe.unet.config.sample_size
173
+ else:
174
+ sample_size = (image_size[0] // pipe.vae_scale_factor, image_size[1] // pipe.vae_scale_factor)
175
+
176
+ return sample_size
177
+
178
+ @classmethod
179
+ def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
180
+ text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
181
+
182
+ batch_size = rbln_config.get("batch_size")
183
+ if not batch_size:
184
+ do_classifier_free_guidance = (
185
+ rbln_config.get("guidance_scale", 5.0) > 1.0 and pipe.unet.config.time_cond_proj_dim is None
186
+ )
187
+ batch_size = 2 if do_classifier_free_guidance else 1
188
+ else:
189
+ if rbln_config.get("guidance_scale"):
190
+ logger.warning(
191
+ "guidance_scale is ignored because batch size is explicitly specified. "
192
+ "To ensure consistent behavior, consider removing the guidance scale or "
193
+ "adjusting the batch size configuration as needed."
194
+ )
195
+
196
+ rbln_config.update(
197
+ {
198
+ "max_seq_len": pipe.text_encoder.config.max_position_embeddings,
199
+ "text_model_hidden_size": text_model_hidden_size,
200
+ "sample_size": cls.get_unet_sample_size(pipe, rbln_config),
201
+ "batch_size": batch_size,
202
+ "is_controlnet": "controlnet" in pipe.config.keys(),
203
+ }
204
+ )
205
+
206
+ return rbln_config
207
+
159
208
  @classmethod
160
209
  def _get_rbln_config(
161
210
  cls,
@@ -179,7 +228,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
179
228
  sample_size = (sample_size, sample_size)
180
229
 
181
230
  if max_seq_len is None:
182
- raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings )must be specified")
231
+ raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings) must be specified.")
183
232
 
184
233
  input_info = [
185
234
  ("sample", [batch_size, model_config.in_channels, sample_size[0], sample_size[1]], "float32"),
@@ -237,6 +286,10 @@ class RBLNUNet2DConditionModel(RBLNModel):
237
286
 
238
287
  return rbln_config
239
288
 
289
+ @property
290
+ def compiled_batch_size(self):
291
+ return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
292
+
240
293
  def forward(
241
294
  self,
242
295
  sample: torch.Tensor,
@@ -254,9 +307,18 @@ class RBLNUNet2DConditionModel(RBLNModel):
254
307
  return_dict: bool = True,
255
308
  **kwargs,
256
309
  ):
257
- """
258
- arg order : latent_model_input, t, prompt_embeds
259
- """
310
+ sample_batch_size = sample.size()[0]
311
+ compiled_batch_size = self.compiled_batch_size
312
+ if sample_batch_size != compiled_batch_size and (
313
+ sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
314
+ ):
315
+ raise ValueError(
316
+ f"Mismatch between UNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
317
+ "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
318
+ "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
319
+ "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
320
+ )
321
+
260
322
  added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
261
323
 
262
324
  if down_block_additional_residuals is not None:
@@ -36,8 +36,18 @@ _import_structure = {
36
36
  "stable_diffusion": [
37
37
  "RBLNStableDiffusionImg2ImgPipeline",
38
38
  "RBLNStableDiffusionPipeline",
39
+ "RBLNStableDiffusionInpaintPipeline",
40
+ ],
41
+ "stable_diffusion_xl": [
42
+ "RBLNStableDiffusionXLImg2ImgPipeline",
43
+ "RBLNStableDiffusionXLPipeline",
44
+ "RBLNStableDiffusionXLInpaintPipeline",
45
+ ],
46
+ "stable_diffusion_3": [
47
+ "RBLNStableDiffusion3Pipeline",
48
+ "RBLNStableDiffusion3Img2ImgPipeline",
49
+ "RBLNStableDiffusion3InpaintPipeline",
39
50
  ],
40
- "stable_diffusion_xl": ["RBLNStableDiffusionXLImg2ImgPipeline", "RBLNStableDiffusionXLPipeline"],
41
51
  }
42
52
  if TYPE_CHECKING:
43
53
  from .controlnet import (
@@ -49,9 +59,19 @@ if TYPE_CHECKING:
49
59
  )
50
60
  from .stable_diffusion import (
51
61
  RBLNStableDiffusionImg2ImgPipeline,
62
+ RBLNStableDiffusionInpaintPipeline,
52
63
  RBLNStableDiffusionPipeline,
53
64
  )
54
- from .stable_diffusion_xl import RBLNStableDiffusionXLImg2ImgPipeline, RBLNStableDiffusionXLPipeline
65
+ from .stable_diffusion_3 import (
66
+ RBLNStableDiffusion3Img2ImgPipeline,
67
+ RBLNStableDiffusion3InpaintPipeline,
68
+ RBLNStableDiffusion3Pipeline,
69
+ )
70
+ from .stable_diffusion_xl import (
71
+ RBLNStableDiffusionXLImg2ImgPipeline,
72
+ RBLNStableDiffusionXLInpaintPipeline,
73
+ RBLNStableDiffusionXLPipeline,
74
+ )
55
75
  else:
56
76
  import sys
57
77
 
@@ -27,12 +27,9 @@ from pathlib import Path
27
27
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
28
28
 
29
29
  import torch
30
- from diffusers import ControlNetModel
31
30
  from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
32
- from optimum.exporters import TasksManager
33
- from transformers import AutoConfig, AutoModel
34
31
 
35
- from ....modeling_base import RBLNModel
32
+ from ....modeling import RBLNModel
36
33
  from ....modeling_config import RBLNConfig
37
34
  from ...models.controlnet import RBLNControlNetModel
38
35
 
@@ -44,6 +41,9 @@ logger = logging.getLogger(__name__)
44
41
 
45
42
 
46
43
  class RBLNMultiControlNetModel(RBLNModel):
44
+ hf_library_name = "diffusers"
45
+ _hf_class = MultiControlNetModel
46
+
47
47
  def __init__(
48
48
  self,
49
49
  models: List[RBLNControlNetModel],
@@ -59,27 +59,6 @@ class RBLNMultiControlNetModel(RBLNModel):
59
59
  cm.extend(net.compiled_models)
60
60
  return cm
61
61
 
62
- @classmethod
63
- def from_pretrained(cls, *args, **kwargs):
64
- def get_model_from_task(
65
- task: str,
66
- model_name_or_path: Union[str, Path],
67
- **kwargs,
68
- ):
69
- return MultiControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
70
-
71
- tasktmp = TasksManager.get_model_from_task
72
- configtmp = AutoConfig.from_pretrained
73
- modeltmp = AutoModel.from_pretrained
74
- TasksManager.get_model_from_task = get_model_from_task
75
- AutoConfig.from_pretrained = ControlNetModel.load_config
76
- AutoModel.from_pretrained = MultiControlNetModel.from_pretrained
77
- rt = super().from_pretrained(*args, **kwargs)
78
- AutoConfig.from_pretrained = configtmp
79
- AutoModel.from_pretrained = modeltmp
80
- TasksManager.get_model_from_task = tasktmp
81
- return rt
82
-
83
62
  @classmethod
84
63
  def _from_pretrained(
85
64
  cls,
@@ -118,7 +97,7 @@ class RBLNMultiControlNetModel(RBLNModel):
118
97
  sample: torch.FloatTensor,
119
98
  timestep: Union[torch.Tensor, float, int],
120
99
  encoder_hidden_states: torch.Tensor,
121
- controlnet_cond: List[torch.tensor],
100
+ controlnet_cond: List[torch.Tensor],
122
101
  conditioning_scale: List[float],
123
102
  class_labels: Optional[torch.Tensor] = None,
124
103
  timestep_cond: Optional[torch.Tensor] = None,
@@ -42,6 +42,7 @@ logger = logging.get_logger(__name__)
42
42
 
43
43
 
44
44
  class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionControlNetImg2ImgPipeline):
45
+ original_class = StableDiffusionControlNetImg2ImgPipeline
45
46
  _submodules = ["text_encoder", "unet", "vae", "controlnet"]
46
47
 
47
48
  def check_inputs(
@@ -42,6 +42,7 @@ logger = logging.get_logger(__name__)
42
42
 
43
43
 
44
44
  class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusionXLControlNetPipeline):
45
+ original_class = StableDiffusionXLControlNetPipeline
45
46
  _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
46
47
 
47
48
  def check_inputs(
@@ -42,6 +42,7 @@ logger = logging.get_logger(__name__)
42
42
 
43
43
 
44
44
  class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionXLControlNetImg2ImgPipeline):
45
+ original_class = StableDiffusionXLControlNetImg2ImgPipeline
45
46
  _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
46
47
 
47
48
  def check_inputs(
@@ -23,3 +23,4 @@
23
23
 
24
24
  from .pipeline_stable_diffusion import RBLNStableDiffusionPipeline
25
25
  from .pipeline_stable_diffusion_img2img import RBLNStableDiffusionImg2ImgPipeline
26
+ from .pipeline_stable_diffusion_inpaint import RBLNStableDiffusionInpaintPipeline
@@ -28,4 +28,5 @@ from ....modeling_diffusers import RBLNDiffusionMixin
28
28
 
29
29
 
30
30
  class RBLNStableDiffusionPipeline(RBLNDiffusionMixin, StableDiffusionPipeline):
31
+ original_class = StableDiffusionPipeline
31
32
  _submodules = ["text_encoder", "unet", "vae"]
@@ -28,16 +28,5 @@ from ....modeling_diffusers import RBLNDiffusionMixin
28
28
 
29
29
 
30
30
  class RBLNStableDiffusionImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionImg2ImgPipeline):
31
- """
32
- Pipeline for image-to-image generation using Stable Diffusion.
33
-
34
- This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods
35
- implemented for all pipelines (downloading, saving, running on a particular device, etc.).
36
-
37
- It implements the methods to convert a pre-trained Stable Diffusion pipeline into a RBLNStableDiffusion pipeline by:
38
- - transferring the checkpoint weights of the original into an optimized RBLN graph,
39
- - compiling the resulting graph using the RBLN compiler.
40
- """
41
-
42
31
  original_class = StableDiffusionImg2ImgPipeline
43
32
  _submodules = ["text_encoder", "unet", "vae"]
@@ -0,0 +1,32 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+ """RBLNStableDiffusionInpaintPipeline class for inference of diffusion models on rbln devices."""
24
+
25
+ from diffusers import StableDiffusionInpaintPipeline
26
+
27
+ from ....modeling_diffusers import RBLNDiffusionMixin
28
+
29
+
30
+ class RBLNStableDiffusionInpaintPipeline(RBLNDiffusionMixin, StableDiffusionInpaintPipeline):
31
+ original_class = StableDiffusionInpaintPipeline
32
+ _submodules = ["text_encoder", "unet", "vae"]
@@ -0,0 +1,26 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .pipeline_stable_diffusion_3 import RBLNStableDiffusion3Pipeline
25
+ from .pipeline_stable_diffusion_3_img2img import RBLNStableDiffusion3Img2ImgPipeline
26
+ from .pipeline_stable_diffusion_3_inpaint import RBLNStableDiffusion3InpaintPipeline
@@ -0,0 +1,32 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+ """RBLNStableDiffusion3Pipeline class for inference of diffusion models on rbln devices."""
24
+
25
+ from diffusers import StableDiffusion3Pipeline
26
+
27
+ from ....modeling_diffusers import RBLNDiffusionMixin
28
+
29
+
30
+ class RBLNStableDiffusion3Pipeline(RBLNDiffusionMixin, StableDiffusion3Pipeline):
31
+ original_class = StableDiffusion3Pipeline
32
+ _submodules = ["transformer", "text_encoder_3", "text_encoder", "text_encoder_2", "vae"]
@@ -0,0 +1,32 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+ """RBLNStableDiffusion3Img2ImgPipeline class for inference of diffusion models on rbln devices."""
24
+
25
+ from diffusers import StableDiffusion3Img2ImgPipeline
26
+
27
+ from ....modeling_diffusers import RBLNDiffusionMixin
28
+
29
+
30
+ class RBLNStableDiffusion3Img2ImgPipeline(RBLNDiffusionMixin, StableDiffusion3Img2ImgPipeline):
31
+ original_class = StableDiffusion3Img2ImgPipeline
32
+ _submodules = ["transformer", "text_encoder_3", "text_encoder", "text_encoder_2", "vae"]