optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (101)
  1. optimum/rbln/__init__.py +164 -36
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +107 -78
  31. optimum/rbln/transformers/__init__.py +87 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +108 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +115 -84
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +282 -216
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  76. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  77. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  78. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  79. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  80. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  81. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  82. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  83. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  84. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  85. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  86. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  87. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  88. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  89. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  90. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  91. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  92. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  93. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  94. optimum/rbln/utils/runtime_utils.py +33 -2
  95. optimum/rbln/utils/submodule.py +26 -43
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/METADATA +1 -1
  97. optimum_rbln-0.7.4a6.dist-info/RECORD +166 -0
  98. optimum/rbln/modeling_config.py +0 -310
  99. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  100. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/WHEEL +0 -0
  101. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/licenses/LICENSE +0 -0
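
The headline change in this release is the configuration rework: optimum/rbln/modeling_config.py (RBLNConfig plus loose rbln_* keyword dicts) is removed and replaced by optimum/rbln/configuration_utils.py (RBLNModelConfig, RBLNCompileConfig) and typed per-model config classes such as RBLNControlNetModelConfig. A minimal sketch of the new style, assuming the usual optimum-style from_pretrained(..., export=True) entry point and that the config classes accept these fields as constructor keywords (the field names are the ones visible in the hunks below; the checkpoint id is illustrative):

    from optimum.rbln import RBLNControlNetModel
    from optimum.rbln.diffusers.configurations import RBLNControlNetModelConfig

    # A typed config object replaces the old rbln_* keyword-argument dict.
    config = RBLNControlNetModelConfig(
        batch_size=2,                # 2 when classifier-free guidance doubles the batch
        max_seq_len=77,              # text encoder's max_position_embeddings
        unet_sample_size=(64, 64),   # latent height/width
        vae_sample_size=(512, 512),  # input image height/width
    )

    model = RBLNControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny",  # illustrative checkpoint
        export=True,
        rbln_config=config,
    )

Compared with the old dict-based rbln_kwargs, a typed config surfaces misspelled or missing fields at construction time rather than deep inside compilation.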
optimum/rbln/diffusers/models/controlnet.py

@@ -13,20 +13,22 @@
 # limitations under the License.

 import importlib
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Dict, Optional, Union

 import torch
 from diffusers import ControlNetModel
+from diffusers.models.controlnet import ControlNetOutput
 from transformers import PretrainedConfig

+from ...configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ...modeling import RBLNModel
-from ...modeling_config import RBLNCompileConfig, RBLNConfig
 from ...utils.logging import get_logger
-from ..modeling_diffusers import RBLNDiffusionMixin
+from ..configurations import RBLNControlNetModelConfig
+from ..modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig


 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel


 logger = get_logger(__name__)
@@ -98,6 +100,7 @@ class _ControlNetModel_Cross_Attention(torch.nn.Module):
 class RBLNControlNetModel(RBLNModel):
     hf_library_name = "diffusers"
     auto_model_class = ControlNetModel
+    output_class = ControlNetOutput

     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
@@ -106,7 +109,7 @@ class RBLNControlNetModel(RBLNModel):
         )

     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         use_encoder_hidden_states = False
         for down_block in model.down_blocks:
             if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
@@ -118,73 +121,50 @@ class RBLNControlNetModel(RBLNModel):
         return _ControlNetModel(model).eval()

     @classmethod
-    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+    def update_rbln_config_using_pipe(
+        cls,
+        pipe: RBLNDiffusionMixin,
+        rbln_config: "RBLNDiffusionMixinConfig",
+        submodule_name: str,
+    ) -> "RBLNDiffusionMixinConfig":
         rbln_vae_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.vae.__class__.__name__}")
         rbln_unet_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.unet.__class__.__name__}")
-        text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
-
-        batch_size = rbln_config.get("batch_size")
-        if not batch_size:
-            do_classifier_free_guidance = (
-                rbln_config.get("guidance_scale", 5.0) > 1.0 and pipe.unet.config.time_cond_proj_dim is None
-            )
-            batch_size = 2 if do_classifier_free_guidance else 1
-        else:
-            if rbln_config.get("guidance_scale"):
-                logger.warning(
-                    "guidance_scale is ignored because batch size is explicitly specified. "
-                    "To ensure consistent behavior, consider removing the guidance scale or "
-                    "adjusting the batch size configuration as needed."
-                )

-        rbln_config.update(
-            {
-                "max_seq_len": pipe.text_encoder.config.max_position_embeddings,
-                "text_model_hidden_size": text_model_hidden_size,
-                "vae_sample_size": rbln_vae_cls.get_vae_sample_size(pipe, rbln_config),
-                "unet_sample_size": rbln_unet_cls.get_unet_sample_size(pipe, rbln_config),
-                "batch_size": batch_size,
-            }
+        rbln_config.controlnet.max_seq_len = pipe.text_encoder.config.max_position_embeddings
+        text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
+        rbln_config.controlnet.text_model_hidden_size = text_model_hidden_size
+        rbln_config.controlnet.vae_sample_size = rbln_vae_cls.get_vae_sample_size(pipe, rbln_config.vae)
+        rbln_config.controlnet.unet_sample_size = rbln_unet_cls.get_unet_sample_size(
+            pipe, rbln_config.unet, image_size=rbln_config.image_size
        )

         return rbln_config

     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
         model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        batch_size = rbln_kwargs.get("batch_size")
-        max_seq_len = rbln_kwargs.get("max_seq_len")
-        unet_sample_size = rbln_kwargs.get("unet_sample_size")
-        vae_sample_size = rbln_kwargs.get("vae_sample_size")
-
-        if batch_size is None:
-            batch_size = 1
-
-        if unet_sample_size is None:
-            raise ValueError(
-                "`rbln_unet_sample_size` (latent height, widht) must be specified (ex. unet's sample_size)"
-            )
+        rbln_config: RBLNControlNetModelConfig,
+    ) -> RBLNModelConfig:
+        if rbln_config.unet_sample_size is None:
+            raise ValueError("`unet_sample_size` (latent height, width) must be specified (ex. unet's sample_size)")

-        if vae_sample_size is None:
-            raise ValueError(
-                "`rbln_vae_sample_size` (input image height, width) must be specified (ex. vae's sample_size)"
-            )
+        if rbln_config.vae_sample_size is None:
+            raise ValueError("`vae_sample_size` (input image height, width) must be specified (ex. vae's sample_size)")

-        if max_seq_len is None:
-            raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings )must be specified")
+        if rbln_config.max_seq_len is None:
+            raise ValueError("`max_seq_len` (ex. text_encoder's max_position_embeddings) must be specified")

         input_info = [
             (
                 "sample",
                 [
-                    batch_size,
+                    rbln_config.batch_size,
                     model_config.in_channels,
-                    unet_sample_size[0],
-                    unet_sample_size[1],
+                    rbln_config.unet_sample_size[0],
+                    rbln_config.unet_sample_size[1],
                 ],
                 "float32",
             ),
@@ -196,7 +176,7 @@ class RBLNControlNetModel(RBLNModel):
         input_info.append(
             (
                 "encoder_hidden_states",
-                [batch_size, max_seq_len, model_config.cross_attention_dim],
+                [rbln_config.batch_size, rbln_config.max_seq_len, model_config.cross_attention_dim],
                 "float32",
             )
         )
@@ -204,25 +184,18 @@ class RBLNControlNetModel(RBLNModel):
         input_info.append(
             (
                 "controlnet_cond",
-                [batch_size, 3, vae_sample_size[0], vae_sample_size[1]],
+                [rbln_config.batch_size, 3, rbln_config.vae_sample_size[0], rbln_config.vae_sample_size[1]],
                 "float32",
             )
         )
         input_info.append(("conditioning_scale", [], "float32"))

         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-            rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
-            input_info.append(("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32"))
-            input_info.append(("time_ids", [batch_size, 6], "float32"))
+            input_info.append(("text_embeds", [rbln_config.batch_size, rbln_config.text_model_hidden_size], "float32"))
+            input_info.append(("time_ids", [rbln_config.batch_size, 6], "float32"))

         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
+        rbln_config.set_compile_cfgs([rbln_compile_config])
         return rbln_config

     @property
@@ -237,6 +210,7 @@ class RBLNControlNetModel(RBLNModel):
         controlnet_cond: torch.FloatTensor,
         conditioning_scale: torch.Tensor = 1.0,
         added_cond_kwargs: Dict[str, torch.Tensor] = {},
+        return_dict: bool = True,
         **kwargs,
     ):
         sample_batch_size = sample.size()[0]
@@ -246,14 +220,14 @@ class RBLNControlNetModel(RBLNModel):
         ):
             raise ValueError(
                 f"Mismatch between ControlNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
-                "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
-                "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
+                "This may be caused by the 'guidance_scale' parameter, which doubles the runtime batch size of ControlNet in Stable Diffusion. "
+                "Adjust the batch size of ControlNet during compilation to match the runtime batch size.\n\n"
                 "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
             )

         added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
         if self.use_encoder_hidden_states:
-            output = super().forward(
+            output = self.model[0](
                 sample.contiguous(),
                 timestep.float(),
                 encoder_hidden_states,
@@ -262,14 +236,25 @@ class RBLNControlNetModel(RBLNModel):
                 **added_cond_kwargs,
             )
         else:
-            output = super().forward(
+            output = self.model[0](
                 sample.contiguous(),
                 timestep.float(),
                 controlnet_cond,
                 torch.tensor(conditioning_scale),
                 **added_cond_kwargs,
             )
+
         down_block_res_samples = output[:-1]
         mid_block_res_sample = output[-1]
+        output = (down_block_res_samples, mid_block_res_sample)
+        output = self._prepare_output(output, return_dict)
+        return output

-        return down_block_res_samples, mid_block_res_sample
+    def _prepare_output(self, output, return_dict):
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            return ControlNetOutput(
+                down_block_res_samples=output[:-1],
+                mid_block_res_sample=output[-1],
+            )
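
With output_class = ControlNetOutput and the new return_dict argument, the compiled ControlNet now mirrors the diffusers calling convention instead of returning a bare tuple. A sketch of the call site, assuming `controlnet` is an already-compiled RBLNControlNetModel (e.g. from the sketch above) and the tensor shapes match its compiled input_info:

    import torch

    batch, seq_len, cross_dim = 2, 77, 768
    sample = torch.randn(batch, 4, 64, 64)          # latents at the compiled unet_sample_size
    prompt_embeds = torch.randn(batch, seq_len, cross_dim)
    cond = torch.randn(batch, 3, 512, 512)          # conditioning image at the compiled vae_sample_size

    out = controlnet(
        sample=sample,
        timestep=torch.tensor(10),
        encoder_hidden_states=prompt_embeds,
        controlnet_cond=cond,
        conditioning_scale=1.0,
    )  # ControlNetOutput by default; pass return_dict=False for the plain tuple as before
    down, mid = out.down_block_res_samples, out.mid_block_res_sample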
optimum/rbln/diffusers/models/transformers/prior_transformer.py

@@ -13,37 +13,22 @@
 # limitations under the License.

 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Optional, Union

 import torch
 from diffusers.models.transformers.prior_transformer import PriorTransformer, PriorTransformerOutput
-from transformers import PretrainedConfig, PreTrainedModel

+from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
-from ....utils.runtime_utils import RBLNPytorchRuntime
-from ...modeling_diffusers import RBLNDiffusionMixin
+from ...configurations.models import RBLNPriorTransformerConfig
+from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig


-logger = get_logger(__name__)
-
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel

-class RBLNRuntimePriorTransformer(RBLNPytorchRuntime):
-    def forward(
-        self, hidden_states, timestep, proj_embedding, encoder_hidden_states, attention_mask, return_dict: bool = True
-    ):
-        predicted_image_embedding = super().forward(
-            hidden_states,
-            timestep,
-            proj_embedding,
-            encoder_hidden_states,
-            attention_mask,
-        )
-        if return_dict:
-            return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
-        else:
-            return (predicted_image_embedding,)
+logger = get_logger(__name__)


 class _PriorTransformer(torch.nn.Module):
@@ -73,51 +58,28 @@ class _PriorTransformer(torch.nn.Module):
 class RBLNPriorTransformer(RBLNModel):
     hf_library_name = "diffusers"
     auto_model_class = PriorTransformer
+    output_class = PriorTransformerOutput
+    output_key = "predicted_image_embedding"

     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
-        self.runtime = RBLNRuntimePriorTransformer(runtime=self.model[0])
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
         self.clip_mean = artifacts["clip_mean"]
         self.clip_std = artifacts["clip_std"]

     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         return _PriorTransformer(model).eval()

     @classmethod
-    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
-        batch_size = rbln_config.get("batch_size")
-        if not batch_size:
-            do_classifier_free_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0
-            batch_size = 2 if do_classifier_free_guidance else 1
-        else:
-            if rbln_config.get("guidance_scale"):
-                logger.warning(
-                    "guidance_scale is ignored because batch size is explicitly specified. "
-                    "To ensure consistent behavior, consider removing the guidance scale or "
-                    "adjusting the batch size configuration as needed."
-                )
-        embedding_dim = rbln_config.get("embedding_dim", pipe.prior.config.embedding_dim)
-        num_embeddings = rbln_config.get("num_embeddings", pipe.prior.config.num_embeddings)
-
-        rbln_config.update(
-            {
-                "batch_size": batch_size,
-                "embedding_dim": embedding_dim,
-                "num_embeddings": num_embeddings,
-            }
-        )
-
+    def update_rbln_config_using_pipe(
+        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
         return rbln_config

     @classmethod
     def save_torch_artifacts(
-        cls,
-        model: "PreTrainedModel",
-        save_dir_path: Path,
-        subfolder: str,
-        rbln_config: RBLNConfig,
+        cls, model: "PreTrainedModel", save_dir_path: Path, subfolder: str, rbln_config: RBLNModelConfig
     ):
         save_dict = {}
         save_dict["clip_mean"] = model.clip_mean
@@ -125,50 +87,51 @@ class RBLNPriorTransformer(RBLNModel):
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
-        preprocessors,
-        model_config: PretrainedConfig,
-        rbln_kwargs,
-    ) -> RBLNConfig:
-        batch_size = rbln_kwargs.get("batch_size") or 1
-        embedding_dim = rbln_kwargs.get("embedding_dim") or model_config.embedding_dim
-        num_embeddings = rbln_kwargs.get("num_embeddings") or model_config.num_embeddings
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
+        model_config: "PretrainedConfig",
+        rbln_config: RBLNPriorTransformerConfig,
+    ) -> RBLNPriorTransformerConfig:
+        rbln_config.embedding_dim = rbln_config.embedding_dim or model_config.embedding_dim
+        rbln_config.num_embeddings = rbln_config.num_embeddings or model_config.num_embeddings

         input_info = [
-            ("hidden_states", [batch_size, embedding_dim], "float32"),
+            ("hidden_states", [rbln_config.batch_size, rbln_config.embedding_dim], "float32"),
             ("timestep", [], "float32"),
-            ("proj_embedding", [batch_size, embedding_dim], "float32"),
-            ("encoder_hidden_states", [batch_size, num_embeddings, embedding_dim], "float32"),
-            ("attention_mask", [batch_size, num_embeddings], "float32"),
+            ("proj_embedding", [rbln_config.batch_size, rbln_config.embedding_dim], "float32"),
+            (
+                "encoder_hidden_states",
+                [rbln_config.batch_size, rbln_config.num_embeddings, rbln_config.embedding_dim],
+                "float32",
+            ),
+            ("attention_mask", [rbln_config.batch_size, rbln_config.num_embeddings], "float32"),
         ]

         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
+        rbln_config.set_compile_cfgs([rbln_compile_config])
        return rbln_config

+    def post_process_latents(self, prior_latents):
+        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
+        return prior_latents
+
     def forward(
         self,
         hidden_states,
         timestep: Union[torch.Tensor, float, int],
         proj_embedding: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.BoolTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
-        return self.runtime.forward(
-            hidden_states.contiguous(),
+        # Convert timestep(long) and attention_mask(bool) to float
+        return super().forward(
+            hidden_states,
             timestep.float(),
             proj_embedding,
             encoder_hidden_states,
             attention_mask.float(),
-            return_dict,
+            return_dict=return_dict,
        )
-
-    def post_process_latents(self, prior_latents):
-        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
-        return prior_latents
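
The per-model RBLNRuntimePriorTransformer wrapper is gone; the subclass now only declares output_class/output_key, which suggests the shared RBLNModel.forward does the output wrapping generically. A rough stand-in for that pattern (the names here are hypothetical illustrations, not the actual RBLNModel internals):

    from dataclasses import dataclass

    import torch


    @dataclass
    class PriorTransformerOutputLike:
        # Hypothetical stand-in for diffusers' PriorTransformerOutput.
        predicted_image_embedding: torch.Tensor


    def wrap_output(raw: torch.Tensor, return_dict: bool = True):
        # What `output_class`/`output_key` let a shared base forward do:
        # wrap the raw runtime tensor under the declared key, or return a tuple.
        if return_dict:
            return PriorTransformerOutputLike(predicted_image_embedding=raw)
        return (raw,)


    print(wrap_output(torch.zeros(1, 768)).predicted_image_embedding.shape)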
optimum/rbln/diffusers/models/transformers/transformer_sd3.py

@@ -19,14 +19,16 @@ from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
 from transformers import PretrainedConfig

+from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
-from ...modeling_diffusers import RBLNDiffusionMixin
+from ...configurations import RBLNSD3Transformer2DModelConfig


 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+    from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig

 logger = get_logger(__name__)

@@ -58,84 +60,64 @@ class SD3Transformer2DModelWrapper(torch.nn.Module):

 class RBLNSD3Transformer2DModel(RBLNModel):
     hf_library_name = "diffusers"
+    auto_model_class = SD3Transformer2DModel
+    output_class = Transformer2DModelOutput
+    output_key = "sample"

     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)

     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         return SD3Transformer2DModelWrapper(model).eval()

     @classmethod
-    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
-        sample_size = rbln_config.get("sample_size", pipe.default_sample_size)
-        img_width = rbln_config.get("img_width")
-        img_height = rbln_config.get("img_height")
-
-        if (img_width is None) ^ (img_height is None):
-            raise RuntimeError
-
-        elif img_width and img_height:
-            sample_size = img_height // pipe.vae_scale_factor, img_width // pipe.vae_scale_factor
-
-        prompt_max_length = rbln_config.get("max_sequence_length", 256)
-        prompt_embed_length = pipe.tokenizer_max_length + prompt_max_length
-
-        batch_size = rbln_config.get("batch_size")
-        if not batch_size:
-            do_classifier_free_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0
-            batch_size = 2 if do_classifier_free_guidance else 1
-        else:
-            if rbln_config.get("guidance_scale"):
-                logger.warning(
-                    "guidance_scale is ignored because batch size is explicitly specified. "
-                    "To ensure consistent behavior, consider removing the guidance scale or "
-                    "adjusting the batch size configuration as needed."
+    def update_rbln_config_using_pipe(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
+        if rbln_config.sample_size is None:
+            if rbln_config.image_size is not None:
+                rbln_config.transformer.sample_size = (
+                    rbln_config.image_size[0] // pipe.vae_scale_factor,
+                    rbln_config.image_size[1] // pipe.vae_scale_factor,
                )
+            else:
+                rbln_config.transformer.sample_size = pipe.default_sample_size

-        rbln_config.update(
-            {
-                "batch_size": batch_size,
-                "prompt_embed_length": prompt_embed_length,
-                "sample_size": sample_size,
-            }
-        )
-
+        prompt_embed_length = pipe.tokenizer_max_length + rbln_config.max_seq_len
+        rbln_config.transformer.prompt_embed_length = prompt_embed_length
         return rbln_config

     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
         model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_config: RBLNSD3Transformer2DModelConfig,
+    ) -> RBLNSD3Transformer2DModelConfig:
+        if rbln_config.sample_size is None:
+            rbln_config.sample_size = model_config.sample_size

-        sample_size = rbln_kwargs.get("sample_size", model_config.sample_size)
-        if isinstance(sample_size, int):
-            sample_size = (sample_size, sample_size)
-
-        rbln_prompt_embed_length = rbln_kwargs.get("prompt_embed_length")
-        if rbln_prompt_embed_length is None:
-            raise ValueError("rbln_prompt_embed_length should be specified.")
+        if isinstance(rbln_config.sample_size, int):
+            rbln_config.sample_size = (rbln_config.sample_size, rbln_config.sample_size)

         input_info = [
             (
                 "hidden_states",
                 [
-                    rbln_batch_size,
+                    rbln_config.batch_size,
                     model_config.in_channels,
-                    sample_size[0],
-                    sample_size[1],
+                    rbln_config.sample_size[0],
+                    rbln_config.sample_size[1],
                 ],
                 "float32",
             ),
             (
                 "encoder_hidden_states",
                 [
-                    rbln_batch_size,
-                    rbln_prompt_embed_length,
+                    rbln_config.batch_size,
+                    rbln_config.prompt_embed_length,
                     model_config.joint_attention_dim,
                 ],
                 "float32",
@@ -143,24 +125,16 @@ class RBLNSD3Transformer2DModel(RBLNModel):
             (
                 "pooled_projections",
                 [
-                    rbln_batch_size,
+                    rbln_config.batch_size,
                     model_config.pooled_projection_dim,
                 ],
                 "float32",
             ),
-            ("timestep", [rbln_batch_size], "float32"),
+            ("timestep", [rbln_config.batch_size], "float32"),
         ]

-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update({"batch_size": rbln_batch_size})
-
+        compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config.set_compile_cfgs([compile_config])
         return rbln_config

     @property
@@ -184,11 +158,12 @@ class RBLNSD3Transformer2DModel(RBLNModel):
             sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
         ):
             raise ValueError(
-                f"Mismatch between Transformers' runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
+                f"Mismatch between transformer's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
                 "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
-                "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
+                "Adjust the batch size of transformer during compilation.\n\n"
                 "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
             )

-        sample = super().forward(hidden_states, encoder_hidden_states, pooled_projections, timestep)
-        return Transformer2DModelOutput(sample=sample)
+        return super().forward(
+            hidden_states, encoder_hidden_states, pooled_projections, timestep, return_dict=return_dict
+        )
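
For context on the batch-size checks that both forward methods keep: classifier-free guidance runs the unconditional and conditional branches in a single batch, so a pipeline invoked with guidance_scale > 1.0 feeds the transformer twice the user-visible batch size. A self-contained illustration of the doubling (generic Stable Diffusion logic, not optimum-rbln code):

    import torch

    guidance_scale = 7.0
    do_cfg = guidance_scale > 1.0

    latents = torch.randn(1, 16, 64, 64)                       # user-visible batch size 1
    model_input = torch.cat([latents] * 2) if do_cfg else latents
    print(model_input.shape[0])  # 2 -> compile the transformer with batch_size=2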