optimum-rbln 0.7.3.post1__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133) hide show
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +11 -86
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -118
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +23 -151
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post1.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
@@ -16,17 +16,18 @@ from dataclasses import dataclass
16
16
  from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
17
17
 
18
18
  import torch
19
- from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
19
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput
20
20
  from transformers import PretrainedConfig
21
21
 
22
+ from ....configuration_utils import RBLNCompileConfig
22
23
  from ....modeling import RBLNModel
23
- from ....modeling_config import RBLNCompileConfig, RBLNConfig
24
24
  from ....utils.logging import get_logger
25
- from ...modeling_diffusers import RBLNDiffusionMixin
25
+ from ...configurations import RBLNUNet2DConditionModelConfig
26
+ from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
26
27
 
27
28
 
28
29
  if TYPE_CHECKING:
29
- from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
30
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
30
31
 
31
32
  logger = get_logger(__name__)
32
33
 
@@ -141,10 +142,13 @@ class _UNet_Kandinsky(torch.nn.Module):
141
142
  class RBLNUNet2DConditionModel(RBLNModel):
142
143
  hf_library_name = "diffusers"
143
144
  auto_model_class = UNet2DConditionModel
145
+ _rbln_config_class = RBLNUNet2DConditionModelConfig
146
+ output_class = UNet2DConditionOutput
147
+ output_key = "sample"
144
148
 
145
149
  def __post_init__(self, **kwargs):
146
150
  super().__post_init__(**kwargs)
147
- self.in_features = self.rbln_config.model_cfg.get("in_features", None)
151
+ self.in_features = self.rbln_config.in_features
148
152
  if self.in_features is not None:
149
153
 
150
154
  @dataclass
@@ -158,7 +162,9 @@ class RBLNUNet2DConditionModel(RBLNModel):
158
162
  self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))
159
163
 
160
164
  @classmethod
161
- def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
165
+ def wrap_model_if_needed(
166
+ cls, model: torch.nn.Module, rbln_config: RBLNUNet2DConditionModelConfig
167
+ ) -> torch.nn.Module:
162
168
  if model.config.addition_embed_type == "text_time":
163
169
  return _UNet_SDXL(model).eval()
164
170
  elif model.config.addition_embed_type == "image":
@@ -168,117 +174,117 @@ class RBLNUNet2DConditionModel(RBLNModel):
168
174
 
169
175
  @classmethod
170
176
  def get_unet_sample_size(
171
- cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]
172
- ) -> Union[int, Tuple[int, int]]:
173
- image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
177
+ cls,
178
+ pipe: RBLNDiffusionMixin,
179
+ rbln_config: RBLNUNet2DConditionModelConfig,
180
+ image_size: Optional[Tuple[int, int]] = None,
181
+ ) -> Tuple[int, int]:
174
182
  scale_factor = pipe.movq_scale_factor if hasattr(pipe, "movq_scale_factor") else pipe.vae_scale_factor
175
- if (image_size[0] is None) != (image_size[1] is None):
176
- raise ValueError("Both image height and image width must be given or not given")
177
- elif image_size[0] is None and image_size[1] is None:
178
- if rbln_config["img2img_pipeline"]:
183
+
184
+ if image_size is None:
185
+ if "Img2Img" in pipe.__class__.__name__:
179
186
  if hasattr(pipe, "vae"):
180
187
  # In case of img2img, sample size of unet is determined by vae encoder.
181
188
  vae_sample_size = pipe.vae.config.sample_size
182
189
  if isinstance(vae_sample_size, int):
183
- sample_size = vae_sample_size // scale_factor
184
- else:
185
- sample_size = (
186
- vae_sample_size[0] // scale_factor,
187
- vae_sample_size[1] // scale_factor,
188
- )
190
+ vae_sample_size = (vae_sample_size, vae_sample_size)
191
+
192
+ sample_size = (
193
+ vae_sample_size[0] // scale_factor,
194
+ vae_sample_size[1] // scale_factor,
195
+ )
189
196
  elif hasattr(pipe, "movq"):
190
197
  logger.warning(
191
- "RBLN config 'img_height' and 'img_width' should have been provided for this pipeline. "
198
+ "RBLN config 'image_size' should have been provided for this pipeline. "
192
199
  "Both variable will be set 512 by default."
193
200
  )
194
201
  sample_size = (512 // scale_factor, 512 // scale_factor)
195
202
  else:
196
203
  sample_size = pipe.unet.config.sample_size
204
+ if isinstance(sample_size, int):
205
+ sample_size = (sample_size, sample_size)
197
206
  else:
198
207
  sample_size = (image_size[0] // scale_factor, image_size[1] // scale_factor)
199
208
 
200
209
  return sample_size
201
210
 
202
211
  @classmethod
203
- def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
204
- text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
205
- image_model_hidden_size = pipe.unet.config.encoder_hid_dim if hasattr(pipe, "unet") else None
206
-
207
- batch_size = rbln_config.get("batch_size")
208
- if not batch_size:
209
- do_classifier_free_guidance = (
210
- rbln_config.get("guidance_scale", 5.0) > 1.0 and pipe.unet.config.time_cond_proj_dim is None
211
- )
212
- batch_size = 2 if do_classifier_free_guidance else 1
213
- else:
214
- if rbln_config.get("guidance_scale"):
215
- logger.warning(
216
- "guidance_scale is ignored because batch size is explicitly specified. "
217
- "To ensure consistent behavior, consider removing the guidance scale or "
218
- "adjusting the batch size configuration as needed."
219
- )
212
+ def update_rbln_config_using_pipe(
213
+ cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
214
+ ) -> "RBLNDiffusionMixinConfig":
215
+ rbln_config.unet.text_model_hidden_size = (
216
+ pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
217
+ )
218
+ rbln_config.unet.image_model_hidden_size = pipe.unet.config.encoder_hid_dim if hasattr(pipe, "unet") else None
220
219
 
221
- max_seq_len = pipe.text_encoder.config.max_position_embeddings if hasattr(pipe, "text_encoder") else None
222
- rbln_config.update(
223
- {
224
- "max_seq_len": max_seq_len,
225
- "text_model_hidden_size": text_model_hidden_size,
226
- "image_model_hidden_size": image_model_hidden_size,
227
- "sample_size": cls.get_unet_sample_size(pipe, rbln_config),
228
- "batch_size": batch_size,
229
- "is_controlnet": "controlnet" in pipe.config.keys(),
230
- }
220
+ rbln_config.unet.max_seq_len = (
221
+ pipe.text_encoder.config.max_position_embeddings if hasattr(pipe, "text_encoder") else None
231
222
  )
232
223
 
224
+ rbln_config.unet.sample_size = cls.get_unet_sample_size(
225
+ pipe, rbln_config.unet, image_size=rbln_config.image_size
226
+ )
227
+ rbln_config.unet.use_additional_residuals = "controlnet" in pipe.config.keys()
228
+
233
229
  return rbln_config
234
230
 
235
231
  @classmethod
236
- def _get_rbln_config(
232
+ def _update_rbln_config(
237
233
  cls,
238
234
  preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
235
+ model: "PreTrainedModel",
239
236
  model_config: "PretrainedConfig",
240
- rbln_kwargs: Dict[str, Any] = {},
241
- ) -> RBLNConfig:
242
- batch_size = rbln_kwargs.get("batch_size")
243
- max_seq_len = rbln_kwargs.get("max_seq_len")
244
- sample_size = rbln_kwargs.get("sample_size")
245
- is_controlnet = rbln_kwargs.get("is_controlnet")
246
- rbln_in_features = None
247
-
248
- if batch_size is None:
249
- batch_size = 1
250
-
251
- if sample_size is None:
252
- sample_size = model_config.sample_size
237
+ rbln_config: RBLNUNet2DConditionModelConfig,
238
+ ) -> RBLNUNet2DConditionModelConfig:
239
+ if rbln_config.sample_size is None:
240
+ rbln_config.sample_size = model_config.sample_size
253
241
 
254
- if isinstance(sample_size, int):
255
- sample_size = (sample_size, sample_size)
242
+ if isinstance(rbln_config.sample_size, int):
243
+ rbln_config.sample_size = (rbln_config.sample_size, rbln_config.sample_size)
256
244
 
257
245
  input_info = [
258
- ("sample", [batch_size, model_config.in_channels, sample_size[0], sample_size[1]], "float32"),
246
+ (
247
+ "sample",
248
+ [
249
+ rbln_config.batch_size,
250
+ model_config.in_channels,
251
+ rbln_config.sample_size[0],
252
+ rbln_config.sample_size[1],
253
+ ],
254
+ "float32",
255
+ ),
259
256
  ("timestep", [], "float32"),
260
257
  ]
261
258
 
262
- if max_seq_len is not None:
259
+ if rbln_config.max_seq_len is not None:
263
260
  input_info.append(
264
- ("encoder_hidden_states", [batch_size, max_seq_len, model_config.cross_attention_dim], "float32"),
261
+ (
262
+ "encoder_hidden_states",
263
+ [rbln_config.batch_size, rbln_config.max_seq_len, model_config.cross_attention_dim],
264
+ "float32",
265
+ ),
265
266
  )
266
267
 
267
- if is_controlnet:
268
+ if rbln_config.use_additional_residuals:
268
269
  # down block addtional residuals
269
- first_shape = [batch_size, model_config.block_out_channels[0], sample_size[0], sample_size[1]]
270
- height, width = sample_size[0], sample_size[1]
270
+ first_shape = [
271
+ rbln_config.batch_size,
272
+ model_config.block_out_channels[0],
273
+ rbln_config.sample_size[0],
274
+ rbln_config.sample_size[1],
275
+ ]
276
+ height, width = rbln_config.sample_size[0], rbln_config.sample_size[1]
271
277
  input_info.append(("down_block_additional_residuals_0", first_shape, "float32"))
272
278
  name_idx = 1
273
279
  for idx, _ in enumerate(model_config.down_block_types):
274
- shape = [batch_size, model_config.block_out_channels[idx], height, width]
280
+ shape = [rbln_config.batch_size, model_config.block_out_channels[idx], height, width]
275
281
  for _ in range(model_config.layers_per_block):
276
282
  input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
277
283
  name_idx += 1
278
284
  if idx != len(model_config.down_block_types) - 1:
279
285
  height = height // 2
280
286
  width = width // 2
281
- shape = [batch_size, model_config.block_out_channels[idx], height, width]
287
+ shape = [rbln_config.batch_size, model_config.block_out_channels[idx], height, width]
282
288
  input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
283
289
  name_idx += 1
284
290
 
@@ -286,33 +292,27 @@ class RBLNUNet2DConditionModel(RBLNModel):
286
292
  num_cross_attn_blocks = model_config.down_block_types.count("CrossAttnDownBlock2D")
287
293
  out_channels = model_config.block_out_channels[-1]
288
294
  shape = [
289
- batch_size,
295
+ rbln_config.batch_size,
290
296
  out_channels,
291
- sample_size[0] // 2**num_cross_attn_blocks,
292
- sample_size[1] // 2**num_cross_attn_blocks,
297
+ rbln_config.sample_size[0] // 2**num_cross_attn_blocks,
298
+ rbln_config.sample_size[1] // 2**num_cross_attn_blocks,
293
299
  ]
294
300
  input_info.append(("mid_block_additional_residual", shape, "float32"))
295
301
 
296
302
  if hasattr(model_config, "addition_embed_type"):
297
303
  if model_config.addition_embed_type == "text_time":
298
- rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
299
- rbln_in_features = model_config.projection_class_embeddings_input_dim
300
- input_info.append(("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32"))
301
- input_info.append(("time_ids", [batch_size, 6], "float32"))
304
+ rbln_config.in_features = model_config.projection_class_embeddings_input_dim
305
+ input_info.append(
306
+ ("text_embeds", [rbln_config.batch_size, rbln_config.text_model_hidden_size], "float32")
307
+ )
308
+ input_info.append(("time_ids", [rbln_config.batch_size, 6], "float32"))
302
309
  elif model_config.addition_embed_type == "image":
303
- rbln_image_model_hidden_size = rbln_kwargs["image_model_hidden_size"]
304
- input_info.append(("image_embeds", [batch_size, rbln_image_model_hidden_size], "float32"))
310
+ input_info.append(
311
+ ("image_embeds", [rbln_config.batch_size, rbln_config.image_model_hidden_size], "float32")
312
+ )
305
313
 
306
314
  rbln_compile_config = RBLNCompileConfig(input_info=input_info)
307
-
308
- rbln_config = RBLNConfig(
309
- rbln_cls=cls.__name__,
310
- compile_cfgs=[rbln_compile_config],
311
- rbln_kwargs=rbln_kwargs,
312
- )
313
-
314
- if rbln_in_features is not None:
315
- rbln_config.model_cfg["in_features"] = rbln_in_features
315
+ rbln_config.set_compile_cfgs([rbln_compile_config])
316
316
 
317
317
  return rbln_config
318
318
 
@@ -336,7 +336,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
336
336
  encoder_attention_mask: Optional[torch.Tensor] = None,
337
337
  return_dict: bool = True,
338
338
  **kwargs,
339
- ):
339
+ ) -> Union[UNet2DConditionOutput, Tuple]:
340
340
  sample_batch_size = sample.size()[0]
341
341
  compiled_batch_size = self.compiled_batch_size
342
342
  if sample_batch_size != compiled_batch_size and (
@@ -344,40 +344,37 @@ class RBLNUNet2DConditionModel(RBLNModel):
344
344
  ):
345
345
  raise ValueError(
346
346
  f"Mismatch between UNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
347
- "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
348
- "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
349
- "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
347
+ "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size of UNet in Stable Diffusion. "
348
+ "Adjust the batch size of UNet during compilation to match the runtime batch size.\n\n"
349
+ "For details, see: https://docs.rbln.ai/software/optimum/model_api/diffusers/pipelines/stable_diffusion.html#important-batch-size-configuration-for-guidance-scale"
350
350
  )
351
351
 
352
352
  added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
353
353
 
354
354
  if down_block_additional_residuals is not None:
355
355
  down_block_additional_residuals = [t.contiguous() for t in down_block_additional_residuals]
356
- return (
357
- super().forward(
358
- sample.contiguous(),
359
- timestep.float(),
360
- encoder_hidden_states,
361
- *down_block_additional_residuals,
362
- mid_block_additional_residual,
363
- **added_cond_kwargs,
364
- ),
356
+ return super().forward(
357
+ sample.contiguous(),
358
+ timestep.float(),
359
+ encoder_hidden_states,
360
+ *down_block_additional_residuals,
361
+ mid_block_additional_residual,
362
+ **added_cond_kwargs,
363
+ return_dict=return_dict,
365
364
  )
366
365
 
367
366
  if "image_embeds" in added_cond_kwargs:
368
- return (
369
- super().forward(
370
- sample.contiguous(),
371
- timestep.float(),
372
- **added_cond_kwargs,
373
- ),
374
- )
375
-
376
- return (
377
- super().forward(
367
+ return super().forward(
378
368
  sample.contiguous(),
379
369
  timestep.float(),
380
- encoder_hidden_states,
381
370
  **added_cond_kwargs,
382
- ),
371
+ return_dict=return_dict,
372
+ )
373
+
374
+ return super().forward(
375
+ sample.contiguous(),
376
+ timestep.float(),
377
+ encoder_hidden_states,
378
+ **added_cond_kwargs,
379
+ return_dict=return_dict,
383
380
  )
@@ -81,7 +81,7 @@ class RBLNMultiControlNetModel(RBLNModel):
81
81
  model.save_pretrained(real_save_path)
82
82
 
83
83
  @classmethod
84
- def _get_rbln_config(cls, **rbln_config_kwargs):
84
+ def _update_rbln_config(cls, **rbln_config_kwargs):
85
85
  pass
86
86
 
87
87
  def forward(
@@ -100,16 +100,15 @@ class RBLNMultiControlNetModel(RBLNModel):
100
100
  return_dict: bool = True,
101
101
  ):
102
102
  for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
103
- output = controlnet.model[0](
103
+ down_samples, mid_sample = controlnet(
104
104
  sample=sample.contiguous(),
105
105
  timestep=timestep.float(),
106
106
  encoder_hidden_states=encoder_hidden_states,
107
107
  controlnet_cond=image,
108
108
  conditioning_scale=torch.tensor(scale),
109
+ return_dict=return_dict,
109
110
  )
110
111
 
111
- down_samples, mid_sample = output[:-1], output[-1]
112
-
113
112
  # merge samples
114
113
  if i == 0:
115
114
  down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
@@ -39,6 +39,7 @@ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
39
39
 
40
40
  from ....utils.decorator_utils import remove_compile_time_kwargs
41
41
  from ....utils.logging import get_logger
42
+ from ...configurations import RBLNStableDiffusionControlNetPipelineConfig
42
43
  from ...modeling_diffusers import RBLNDiffusionMixin
43
44
  from ...models import RBLNControlNetModel
44
45
  from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
@@ -49,6 +50,7 @@ logger = get_logger(__name__)
49
50
 
50
51
  class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionControlNetPipeline):
51
52
  original_class = StableDiffusionControlNetPipeline
53
+ _rbln_config_class = RBLNStableDiffusionControlNetPipelineConfig
52
54
  _submodules = ["text_encoder", "unet", "vae", "controlnet"]
53
55
 
54
56
  # Almost copied from diffusers.pipelines.controlnet.pipeline_controlnet.py
@@ -37,6 +37,7 @@ from diffusers.utils import deprecate, logging
37
37
  from diffusers.utils.torch_utils import is_compiled_module
38
38
 
39
39
  from ....utils.decorator_utils import remove_compile_time_kwargs
40
+ from ...configurations import RBLNStableDiffusionControlNetImg2ImgPipelineConfig
40
41
  from ...modeling_diffusers import RBLNDiffusionMixin
41
42
  from ...models import RBLNControlNetModel
42
43
  from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
@@ -48,6 +49,7 @@ logger = logging.get_logger(__name__)
48
49
  class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionControlNetImg2ImgPipeline):
49
50
  original_class = StableDiffusionControlNetImg2ImgPipeline
50
51
  _submodules = ["text_encoder", "unet", "vae", "controlnet"]
52
+ _rbln_config_class = RBLNStableDiffusionControlNetImg2ImgPipelineConfig
51
53
 
52
54
  # Almost copied from diffusers.pipelines.controlnet.pipeline_controlnet_img2img.py
53
55
  def check_inputs(
@@ -37,6 +37,7 @@ from diffusers.utils import deprecate, logging
37
37
  from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
38
38
 
39
39
  from ....utils.decorator_utils import remove_compile_time_kwargs
40
+ from ...configurations import RBLNStableDiffusionXLControlNetPipelineConfig
40
41
  from ...modeling_diffusers import RBLNDiffusionMixin
41
42
  from ...models import RBLNControlNetModel
42
43
  from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
@@ -47,6 +48,7 @@ logger = logging.get_logger(__name__)
47
48
 
48
49
  class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusionXLControlNetPipeline):
49
50
  original_class = StableDiffusionXLControlNetPipeline
51
+ _rbln_config_class = RBLNStableDiffusionXLControlNetPipelineConfig
50
52
  _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
51
53
 
52
54
  # Almost copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.py
@@ -37,6 +37,7 @@ from diffusers.utils import deprecate, logging
37
37
  from diffusers.utils.torch_utils import is_compiled_module
38
38
 
39
39
  from ....utils.decorator_utils import remove_compile_time_kwargs
40
+ from ...configurations import RBLNStableDiffusionXLControlNetImg2ImgPipelineConfig
40
41
  from ...modeling_diffusers import RBLNDiffusionMixin
41
42
  from ...models import RBLNControlNetModel
42
43
  from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
@@ -47,6 +48,7 @@ logger = logging.get_logger(__name__)
47
48
 
48
49
  class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionXLControlNetImg2ImgPipeline):
49
50
  original_class = StableDiffusionXLControlNetImg2ImgPipeline
51
+ _rbln_config_class = RBLNStableDiffusionXLControlNetImg2ImgPipelineConfig
50
52
  _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
51
53
 
52
54
  # Almost copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img.py
@@ -14,11 +14,13 @@
14
14
 
15
15
  from diffusers import KandinskyV22Pipeline
16
16
 
17
+ from ...configurations import RBLNKandinskyV22PipelineConfig
17
18
  from ...modeling_diffusers import RBLNDiffusionMixin
18
19
 
19
20
 
20
21
  class RBLNKandinskyV22Pipeline(RBLNDiffusionMixin, KandinskyV22Pipeline):
21
22
  original_class = KandinskyV22Pipeline
23
+ _rbln_config_class = RBLNKandinskyV22PipelineConfig
22
24
  _submodules = ["unet", "movq"]
23
25
 
24
26
  def get_compiled_image_size(self):
@@ -29,6 +29,7 @@ from transformers import (
29
29
  CLIPVisionModelWithProjection,
30
30
  )
31
31
 
32
+ from ...configurations import RBLNKandinskyV22CombinedPipelineConfig
32
33
  from ...modeling_diffusers import RBLNDiffusionMixin
33
34
  from .pipeline_kandinsky2_2 import RBLNKandinskyV22Pipeline
34
35
  from .pipeline_kandinsky2_2_img2img import RBLNKandinskyV22Img2ImgPipeline
@@ -38,6 +39,7 @@ from .pipeline_kandinsky2_2_prior import RBLNKandinskyV22PriorPipeline
38
39
 
39
40
  class RBLNKandinskyV22CombinedPipeline(RBLNDiffusionMixin, KandinskyV22CombinedPipeline):
40
41
  original_class = KandinskyV22CombinedPipeline
42
+ _rbln_config_class = RBLNKandinskyV22CombinedPipelineConfig
41
43
  _connected_classes = {"prior_pipe": RBLNKandinskyV22PriorPipeline, "decoder_pipe": RBLNKandinskyV22Pipeline}
42
44
  _submodules = ["prior_image_encoder", "prior_text_encoder", "prior_prior", "unet", "movq"]
43
45
  _prefix = {"prior_pipe": "prior_"}
@@ -14,11 +14,13 @@
14
14
 
15
15
  from diffusers import KandinskyV22Img2ImgPipeline
16
16
 
17
+ from ...configurations import RBLNKandinskyV22Img2ImgPipelineConfig
17
18
  from ...modeling_diffusers import RBLNDiffusionMixin
18
19
 
19
20
 
20
21
  class RBLNKandinskyV22Img2ImgPipeline(RBLNDiffusionMixin, KandinskyV22Img2ImgPipeline):
21
22
  original_class = KandinskyV22Img2ImgPipeline
23
+ _rbln_config_class = RBLNKandinskyV22Img2ImgPipelineConfig
22
24
  _submodules = ["unet", "movq"]
23
25
 
24
26
  def get_compiled_image_size(self):
@@ -14,11 +14,13 @@
14
14
 
15
15
  from diffusers import KandinskyV22InpaintPipeline
16
16
 
17
+ from ...configurations import RBLNKandinskyV22InpaintPipelineConfig
17
18
  from ...modeling_diffusers import RBLNDiffusionMixin
18
19
 
19
20
 
20
21
  class RBLNKandinskyV22InpaintPipeline(RBLNDiffusionMixin, KandinskyV22InpaintPipeline):
21
22
  original_class = KandinskyV22InpaintPipeline
23
+ _rbln_config_class = RBLNKandinskyV22InpaintPipelineConfig
22
24
  _submodules = ["unet", "movq"]
23
25
 
24
26
  def get_compiled_image_size(self):
@@ -14,9 +14,11 @@
14
14
 
15
15
  from diffusers import KandinskyV22PriorPipeline
16
16
 
17
+ from ...configurations import RBLNKandinskyV22PriorPipelineConfig
17
18
  from ...modeling_diffusers import RBLNDiffusionMixin
18
19
 
19
20
 
20
21
  class RBLNKandinskyV22PriorPipeline(RBLNDiffusionMixin, KandinskyV22PriorPipeline):
21
22
  original_class = KandinskyV22PriorPipeline
23
+ _rbln_config_class = RBLNKandinskyV22PriorPipelineConfig
22
24
  _submodules = ["text_encoder", "image_encoder", "prior"]
@@ -12,11 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+
15
16
  from diffusers import StableDiffusionPipeline
16
17
 
18
+ from ...configurations import RBLNStableDiffusionPipelineConfig
17
19
  from ...modeling_diffusers import RBLNDiffusionMixin
18
20
 
19
21
 
20
22
  class RBLNStableDiffusionPipeline(RBLNDiffusionMixin, StableDiffusionPipeline):
21
23
  original_class = StableDiffusionPipeline
22
- _submodules = ["text_encoder", "unet", "vae"]
24
+ _rbln_config_class = RBLNStableDiffusionPipelineConfig
25
+ _submodules = ["vae", "text_encoder", "unet"]
@@ -14,9 +14,11 @@
14
14
 
15
15
  from diffusers import StableDiffusionImg2ImgPipeline
16
16
 
17
+ from ...configurations import RBLNStableDiffusionImg2ImgPipelineConfig
17
18
  from ...modeling_diffusers import RBLNDiffusionMixin
18
19
 
19
20
 
20
21
  class RBLNStableDiffusionImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionImg2ImgPipeline):
21
22
  original_class = StableDiffusionImg2ImgPipeline
23
+ _rbln_config_class = RBLNStableDiffusionImg2ImgPipelineConfig
22
24
  _submodules = ["text_encoder", "unet", "vae"]
@@ -14,9 +14,11 @@
14
14
 
15
15
  from diffusers import StableDiffusionInpaintPipeline
16
16
 
17
+ from ...configurations import RBLNStableDiffusionInpaintPipelineConfig
17
18
  from ...modeling_diffusers import RBLNDiffusionMixin
18
19
 
19
20
 
20
21
  class RBLNStableDiffusionInpaintPipeline(RBLNDiffusionMixin, StableDiffusionInpaintPipeline):
21
22
  original_class = StableDiffusionInpaintPipeline
23
+ _rbln_config_class = RBLNStableDiffusionInpaintPipelineConfig
22
24
  _submodules = ["text_encoder", "unet", "vae"]
@@ -14,9 +14,11 @@
14
14
 
15
15
  from diffusers import StableDiffusion3Pipeline
16
16
 
17
+ from ...configurations import RBLNStableDiffusion3PipelineConfig
17
18
  from ...modeling_diffusers import RBLNDiffusionMixin
18
19
 
19
20
 
20
21
  class RBLNStableDiffusion3Pipeline(RBLNDiffusionMixin, StableDiffusion3Pipeline):
21
22
  original_class = StableDiffusion3Pipeline
23
+ _rbln_config_class = RBLNStableDiffusion3PipelineConfig
22
24
  _submodules = ["transformer", "text_encoder_3", "text_encoder", "text_encoder_2", "vae"]
@@ -14,9 +14,11 @@
14
14
 
15
15
  from diffusers import StableDiffusion3Img2ImgPipeline
16
16
 
17
+ from ...configurations import RBLNStableDiffusion3Img2ImgPipelineConfig
17
18
  from ...modeling_diffusers import RBLNDiffusionMixin
18
19
 
19
20
 
20
21
  class RBLNStableDiffusion3Img2ImgPipeline(RBLNDiffusionMixin, StableDiffusion3Img2ImgPipeline):
21
22
  original_class = StableDiffusion3Img2ImgPipeline
23
+ _rbln_config_class = RBLNStableDiffusion3Img2ImgPipelineConfig
22
24
  _submodules = ["transformer", "text_encoder_3", "text_encoder", "text_encoder_2", "vae"]
@@ -14,9 +14,11 @@
14
14
 
15
15
  from diffusers import StableDiffusion3InpaintPipeline
16
16
 
17
+ from ...configurations import RBLNStableDiffusion3InpaintPipelineConfig
17
18
  from ...modeling_diffusers import RBLNDiffusionMixin
18
19
 
19
20
 
20
21
  class RBLNStableDiffusion3InpaintPipeline(RBLNDiffusionMixin, StableDiffusion3InpaintPipeline):
21
22
  original_class = StableDiffusion3InpaintPipeline
23
+ _rbln_config_class = RBLNStableDiffusion3InpaintPipelineConfig
22
24
  _submodules = ["transformer", "text_encoder_3", "text_encoder", "text_encoder_2", "vae"]
@@ -14,9 +14,11 @@
14
14
 
15
15
  from diffusers import StableDiffusionXLPipeline
16
16
 
17
+ from ...configurations import RBLNStableDiffusionXLPipelineConfig
17
18
  from ...modeling_diffusers import RBLNDiffusionMixin
18
19
 
19
20
 
20
21
  class RBLNStableDiffusionXLPipeline(RBLNDiffusionMixin, StableDiffusionXLPipeline):
21
22
  original_class = StableDiffusionXLPipeline
23
+ _rbln_config_class = RBLNStableDiffusionXLPipelineConfig
22
24
  _submodules = ["text_encoder", "text_encoder_2", "unet", "vae"]
@@ -14,9 +14,11 @@
14
14
 
15
15
  from diffusers import StableDiffusionXLImg2ImgPipeline
16
16
 
17
+ from ...configurations import RBLNStableDiffusionXLImg2ImgPipelineConfig
17
18
  from ...modeling_diffusers import RBLNDiffusionMixin
18
19
 
19
20
 
20
21
  class RBLNStableDiffusionXLImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionXLImg2ImgPipeline):
21
22
  original_class = StableDiffusionXLImg2ImgPipeline
23
+ _rbln_config_class = RBLNStableDiffusionXLImg2ImgPipelineConfig
22
24
  _submodules = ["text_encoder", "text_encoder_2", "unet", "vae"]
@@ -14,9 +14,11 @@
14
14
 
15
15
  from diffusers import StableDiffusionXLInpaintPipeline
16
16
 
17
+ from ...configurations import RBLNStableDiffusionXLInpaintPipelineConfig
17
18
  from ...modeling_diffusers import RBLNDiffusionMixin
18
19
 
19
20
 
20
21
  class RBLNStableDiffusionXLInpaintPipeline(RBLNDiffusionMixin, StableDiffusionXLInpaintPipeline):
21
22
  original_class = StableDiffusionXLInpaintPipeline
23
+ _rbln_config_class = RBLNStableDiffusionXLInpaintPipelineConfig
22
24
  _submodules = ["text_encoder", "text_encoder_2", "unet", "vae"]