optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. optimum/rbln/__init__.py +27 -13
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +85 -65
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
  13. optimum/rbln/diffusers/pipelines/__init__.py +60 -12
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  31. optimum/rbln/modeling.py +572 -0
  32. optimum/rbln/modeling_alias.py +1 -1
  33. optimum/rbln/modeling_base.py +176 -763
  34. optimum/rbln/modeling_diffusers.py +329 -0
  35. optimum/rbln/transformers/__init__.py +2 -2
  36. optimum/rbln/transformers/cache_utils.py +5 -9
  37. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  38. optimum/rbln/transformers/models/__init__.py +80 -31
  39. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  40. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
  44. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
  45. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
  46. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
  47. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  48. optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
  49. optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
  50. optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
  51. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  53. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
  54. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  55. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
  56. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
  57. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
  58. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  59. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  60. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
  61. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  62. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  63. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  64. optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
  65. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  66. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  67. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  68. optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
  69. optimum/rbln/utils/decorator_utils.py +59 -0
  70. optimum/rbln/utils/hub.py +131 -0
  71. optimum/rbln/utils/import_utils.py +21 -0
  72. optimum/rbln/utils/model_utils.py +53 -0
  73. optimum/rbln/utils/runtime_utils.py +5 -5
  74. optimum/rbln/utils/submodule.py +114 -0
  75. optimum/rbln/utils/timer_utils.py +2 -2
  76. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  77. optimum_rbln-0.1.15.dist-info/RECORD +110 -0
  78. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  79. optimum/rbln/transformers/generation/streamers.py +0 -139
  80. optimum/rbln/transformers/generation/utils.py +0 -397
  81. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  82. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  83. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  84. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  85. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  86. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  87. optimum_rbln-0.1.12.dist-info/METADATA +0 -119
  88. optimum_rbln-0.1.12.dist-info/RECORD +0 -103
  89. optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
  90. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
@@ -21,17 +21,17 @@
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
23
 
24
+ import importlib
24
25
  import logging
25
- from pathlib import Path
26
26
  from typing import TYPE_CHECKING, Any, Dict, Optional, Union
27
27
 
28
28
  import torch
29
29
  from diffusers import ControlNetModel
30
- from optimum.exporters import TasksManager
31
- from transformers import AutoConfig, AutoModel, PretrainedConfig
30
+ from transformers import PretrainedConfig
32
31
 
33
- from ...modeling_base import RBLNModel
32
+ from ...modeling import RBLNModel
34
33
  from ...modeling_config import RBLNCompileConfig, RBLNConfig
34
+ from ...modeling_diffusers import RBLNDiffusionMixin
35
35
 
36
36
 
37
37
  if TYPE_CHECKING:
@@ -105,33 +105,15 @@ class _ControlNetModel_Cross_Attention(torch.nn.Module):
105
105
 
106
106
 
107
107
  class RBLNControlNetModel(RBLNModel):
108
+ hf_library_name = "diffusers"
109
+ auto_model_class = ControlNetModel
110
+
108
111
  def __post_init__(self, **kwargs):
109
112
  super().__post_init__(**kwargs)
110
113
  self.use_encoder_hidden_states = any(
111
114
  item[0] == "encoder_hidden_states" for item in self.rbln_config.compile_cfgs[0].input_info
112
115
  )
113
116
 
114
- @classmethod
115
- def from_pretrained(cls, *args, **kwargs):
116
- def get_model_from_task(
117
- task: str,
118
- model_name_or_path: Union[str, Path],
119
- **kwargs,
120
- ):
121
- return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
122
-
123
- tasktmp = TasksManager.get_model_from_task
124
- configtmp = AutoConfig.from_pretrained
125
- modeltmp = AutoModel.from_pretrained
126
- TasksManager.get_model_from_task = get_model_from_task
127
- AutoConfig.from_pretrained = ControlNetModel.load_config
128
- AutoModel.from_pretrained = ControlNetModel.from_pretrained
129
- rt = super().from_pretrained(*args, **kwargs)
130
- AutoConfig.from_pretrained = configtmp
131
- AutoModel.from_pretrained = modeltmp
132
- TasksManager.get_model_from_task = tasktmp
133
- return rt
134
-
135
117
  @classmethod
136
118
  def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
137
119
  use_encoder_hidden_states = False
@@ -144,6 +126,38 @@ class RBLNControlNetModel(RBLNModel):
144
126
  else:
145
127
  return _ControlNetModel(model).eval()
146
128
 
129
+ @classmethod
130
+ def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
131
+ rbln_vae_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.vae.__class__.__name__}")
132
+ rbln_unet_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.unet.__class__.__name__}")
133
+ text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
134
+
135
+ batch_size = rbln_config.get("batch_size")
136
+ if not batch_size:
137
+ do_classifier_free_guidance = (
138
+ rbln_config.get("guidance_scale", 5.0) > 1.0 and pipe.unet.config.time_cond_proj_dim is None
139
+ )
140
+ batch_size = 2 if do_classifier_free_guidance else 1
141
+ else:
142
+ if rbln_config.get("guidance_scale"):
143
+ logger.warning(
144
+ "guidance_scale is ignored because batch size is explicitly specified. "
145
+ "To ensure consistent behavior, consider removing the guidance scale or "
146
+ "adjusting the batch size configuration as needed."
147
+ )
148
+
149
+ rbln_config.update(
150
+ {
151
+ "max_seq_len": pipe.text_encoder.config.max_position_embeddings,
152
+ "text_model_hidden_size": text_model_hidden_size,
153
+ "vae_sample_size": rbln_vae_cls.get_vae_sample_size(pipe, rbln_config),
154
+ "unet_sample_size": rbln_unet_cls.get_unet_sample_size(pipe, rbln_config),
155
+ "batch_size": batch_size,
156
+ }
157
+ )
158
+
159
+ return rbln_config
160
+
147
161
  @classmethod
148
162
  def _get_rbln_config(
149
163
  cls,
@@ -151,33 +165,35 @@ class RBLNControlNetModel(RBLNModel):
151
165
  model_config: "PretrainedConfig",
152
166
  rbln_kwargs: Dict[str, Any] = {},
153
167
  ) -> RBLNConfig:
154
- rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
155
- rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
156
- rbln_batch_size = rbln_kwargs.get("batch_size", None)
157
- rbln_img_width = rbln_kwargs.get("img_width", None)
158
- rbln_img_height = rbln_kwargs.get("img_height", None)
159
- rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
168
+ batch_size = rbln_kwargs.get("batch_size")
169
+ max_seq_len = rbln_kwargs.get("max_seq_len")
170
+ unet_sample_size = rbln_kwargs.get("unet_sample_size")
171
+ vae_sample_size = rbln_kwargs.get("vae_sample_size")
160
172
 
161
- if rbln_batch_size is None:
162
- rbln_batch_size = 1
173
+ if batch_size is None:
174
+ batch_size = 1
163
175
 
164
- if rbln_max_seq_len is None:
165
- rbln_max_seq_len = 77
176
+ if unet_sample_size is None:
177
+ raise ValueError(
178
+ "`rbln_unet_sample_size` (latent height, widht) must be specified (ex. unet's sample_size)"
179
+ )
166
180
 
167
- if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
168
- raise ValueError("rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided")
181
+ if vae_sample_size is None:
182
+ raise ValueError(
183
+ "`rbln_vae_sample_size` (input image height, width) must be specified (ex. vae's sample_size)"
184
+ )
169
185
 
170
- input_width = rbln_img_width // rbln_vae_scale_factor
171
- input_height = rbln_img_height // rbln_vae_scale_factor
186
+ if max_seq_len is None:
187
+ raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings )must be specified")
172
188
 
173
189
  input_info = [
174
190
  (
175
191
  "sample",
176
192
  [
177
- rbln_batch_size,
193
+ batch_size,
178
194
  model_config.in_channels,
179
- input_height,
180
- input_width,
195
+ unet_sample_size[0],
196
+ unet_sample_size[1],
181
197
  ],
182
198
  "float32",
183
199
  ),
@@ -189,23 +205,24 @@ class RBLNControlNetModel(RBLNModel):
189
205
  input_info.append(
190
206
  (
191
207
  "encoder_hidden_states",
192
- [
193
- rbln_batch_size,
194
- rbln_max_seq_len,
195
- model_config.cross_attention_dim,
196
- ],
208
+ [batch_size, max_seq_len, model_config.cross_attention_dim],
197
209
  "float32",
198
210
  )
199
211
  )
200
212
 
201
- input_info.append(("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32"))
213
+ input_info.append(
214
+ (
215
+ "controlnet_cond",
216
+ [batch_size, 3, vae_sample_size[0], vae_sample_size[1]],
217
+ "float32",
218
+ )
219
+ )
202
220
  input_info.append(("conditioning_scale", [], "float32"))
203
221
 
204
222
  if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
205
- if rbln_text_model_hidden_size is None:
206
- rbln_text_model_hidden_size = 768
207
- input_info.append(("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32"))
208
- input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
223
+ rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
224
+ input_info.append(("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32"))
225
+ input_info.append(("time_ids", [batch_size, 6], "float32"))
209
226
 
210
227
  rbln_compile_config = RBLNCompileConfig(input_info=input_info)
211
228
 
@@ -215,18 +232,12 @@ class RBLNControlNetModel(RBLNModel):
215
232
  rbln_kwargs=rbln_kwargs,
216
233
  )
217
234
 
218
- rbln_config.model_cfg.update(
219
- {
220
- "max_seq_len": rbln_max_seq_len,
221
- "batch_size": rbln_batch_size,
222
- "img_width": rbln_img_width,
223
- "img_height": rbln_img_height,
224
- "vae_scale_factor": rbln_vae_scale_factor,
225
- }
226
- )
227
-
228
235
  return rbln_config
229
236
 
237
+ @property
238
+ def compiled_batch_size(self):
239
+ return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
240
+
230
241
  def forward(
231
242
  self,
232
243
  sample: torch.FloatTensor,
@@ -237,9 +248,18 @@ class RBLNControlNetModel(RBLNModel):
237
248
  added_cond_kwargs: Dict[str, torch.Tensor] = {},
238
249
  **kwargs,
239
250
  ):
240
- """
241
- The [`ControlNetModel`] forward method.
242
- """
251
+ sample_batch_size = sample.size()[0]
252
+ compiled_batch_size = self.compiled_batch_size
253
+ if sample_batch_size != compiled_batch_size and (
254
+ sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
255
+ ):
256
+ raise ValueError(
257
+ f"Mismatch between ControlNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
258
+ "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
259
+ "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
260
+ "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
261
+ )
262
+
243
263
  added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
244
264
  if self.use_encoder_hidden_states:
245
265
  output = super().forward(
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .transformer_sd3 import RBLNSD3Transformer2DModel
@@ -0,0 +1,203 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import logging
25
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
26
+
27
+ import torch
28
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
29
+ from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
30
+ from transformers import PretrainedConfig
31
+
32
+ from ....modeling import RBLNModel
33
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
34
+ from ....modeling_diffusers import RBLNDiffusionMixin
35
+
36
+
37
+ if TYPE_CHECKING:
38
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+
43
+ class SD3Transformer2DModelWrapper(torch.nn.Module):
44
+ def __init__(self, model: "SD3Transformer2DModel") -> None:
45
+ super().__init__()
46
+ self.model = model
47
+
48
+ def forward(
49
+ self,
50
+ hidden_states: torch.FloatTensor,
51
+ encoder_hidden_states: torch.FloatTensor = None,
52
+ pooled_projections: torch.FloatTensor = None,
53
+ timestep: torch.LongTensor = None,
54
+ # need controlnet support?
55
+ block_controlnet_hidden_states: List = None,
56
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
57
+ return_dict: bool = True,
58
+ ):
59
+ return self.model(
60
+ hidden_states=hidden_states,
61
+ encoder_hidden_states=encoder_hidden_states,
62
+ pooled_projections=pooled_projections,
63
+ timestep=timestep,
64
+ return_dict=False,
65
+ )
66
+
67
+
68
+ class RBLNSD3Transformer2DModel(RBLNModel):
69
+ hf_library_name = "diffusers"
70
+
71
+ def __post_init__(self, **kwargs):
72
+ super().__post_init__(**kwargs)
73
+
74
+ @classmethod
75
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
76
+ return SD3Transformer2DModelWrapper(model).eval()
77
+
78
+ @classmethod
79
+ def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
80
+ sample_size = rbln_config.get("sample_size", pipe.default_sample_size)
81
+ img_width = rbln_config.get("img_width")
82
+ img_height = rbln_config.get("img_height")
83
+
84
+ if (img_width is None) ^ (img_height is None):
85
+ raise RuntimeError
86
+
87
+ elif img_width and img_height:
88
+ sample_size = img_height // pipe.vae_scale_factor, img_width // pipe.vae_scale_factor
89
+
90
+ prompt_max_length = rbln_config.get("max_sequence_length", 256)
91
+ prompt_embed_length = pipe.tokenizer_max_length + prompt_max_length
92
+
93
+ batch_size = rbln_config.get("batch_size")
94
+ if not batch_size:
95
+ do_classifier_free_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0
96
+ batch_size = 2 if do_classifier_free_guidance else 1
97
+ else:
98
+ if rbln_config.get("guidance_scale"):
99
+ logger.warning(
100
+ "guidance_scale is ignored because batch size is explicitly specified. "
101
+ "To ensure consistent behavior, consider removing the guidance scale or "
102
+ "adjusting the batch size configuration as needed."
103
+ )
104
+
105
+ rbln_config.update(
106
+ {
107
+ "batch_size": batch_size,
108
+ "prompt_embed_length": prompt_embed_length,
109
+ "sample_size": sample_size,
110
+ }
111
+ )
112
+
113
+ return rbln_config
114
+
115
+ @classmethod
116
+ def _get_rbln_config(
117
+ cls,
118
+ preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
119
+ model_config: "PretrainedConfig",
120
+ rbln_kwargs: Dict[str, Any] = {},
121
+ ) -> RBLNConfig:
122
+ rbln_batch_size = rbln_kwargs.get("batch_size", None)
123
+
124
+ sample_size = rbln_kwargs.get("sample_size", model_config.sample_size)
125
+ if isinstance(sample_size, int):
126
+ sample_size = (sample_size, sample_size)
127
+
128
+ rbln_prompt_embed_length = rbln_kwargs.get("prompt_embed_length")
129
+ if rbln_prompt_embed_length is None:
130
+ raise ValueError("rbln_prompt_embed_length should be specified.")
131
+
132
+ input_info = [
133
+ (
134
+ "hidden_states",
135
+ [
136
+ rbln_batch_size,
137
+ model_config.in_channels,
138
+ sample_size[0],
139
+ sample_size[1],
140
+ ],
141
+ "float32",
142
+ ),
143
+ (
144
+ "encoder_hidden_states",
145
+ [
146
+ rbln_batch_size,
147
+ rbln_prompt_embed_length,
148
+ model_config.joint_attention_dim,
149
+ ],
150
+ "float32",
151
+ ),
152
+ (
153
+ "pooled_projections",
154
+ [
155
+ rbln_batch_size,
156
+ model_config.pooled_projection_dim,
157
+ ],
158
+ "float32",
159
+ ),
160
+ ("timestep", [rbln_batch_size], "float32"),
161
+ ]
162
+
163
+ rbln_compile_config = RBLNCompileConfig(input_info=input_info)
164
+
165
+ rbln_config = RBLNConfig(
166
+ rbln_cls=cls.__name__,
167
+ compile_cfgs=[rbln_compile_config],
168
+ rbln_kwargs=rbln_kwargs,
169
+ )
170
+
171
+ rbln_config.model_cfg.update({"batch_size": rbln_batch_size})
172
+
173
+ return rbln_config
174
+
175
+ @property
176
+ def compiled_batch_size(self):
177
+ return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
178
+
179
+ def forward(
180
+ self,
181
+ hidden_states: torch.FloatTensor,
182
+ encoder_hidden_states: torch.FloatTensor = None,
183
+ pooled_projections: torch.FloatTensor = None,
184
+ timestep: torch.LongTensor = None,
185
+ block_controlnet_hidden_states: List = None,
186
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
187
+ return_dict: bool = True,
188
+ **kwargs,
189
+ ):
190
+ sample_batch_size = hidden_states.size()[0]
191
+ compiled_batch_size = self.compiled_batch_size
192
+ if sample_batch_size != compiled_batch_size and (
193
+ sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
194
+ ):
195
+ raise ValueError(
196
+ f"Mismatch between Transformers' runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
197
+ "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
198
+ "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
199
+ "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
200
+ )
201
+
202
+ sample = super().forward(hidden_states, encoder_hidden_states, pooled_projections, timestep)
203
+ return Transformer2DModelOutput(sample=sample)
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .unet_2d_condition import RBLNUNet2DConditionModel