optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a5__py3-none-any.whl

This diff shows the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
Files changed (96)
  1. optimum/rbln/__init__.py +156 -36
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +85 -75
  31. optimum/rbln/transformers/__init__.py +79 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +96 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  73. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  74. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  75. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  76. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  77. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  78. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  79. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  80. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  81. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  82. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  83. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  84. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  85. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  86. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  87. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  88. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  89. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  90. optimum/rbln/utils/submodule.py +26 -43
  91. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
  92. optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
  93. optimum/rbln/modeling_config.py +0 -310
  94. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  95. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/models/unets/unet_2d_condition.py CHANGED
@@ -16,17 +16,18 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

 import torch
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput
 from transformers import PretrainedConfig

+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
-from ...modeling_diffusers import RBLNDiffusionMixin
+from ...configurations import RBLNUNet2DConditionModelConfig
+from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig


 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel


 logger = get_logger(__name__)

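The compile-time configuration API moves out of optimum/rbln/modeling_config.py (deleted, file 93) into the new optimum/rbln/configuration_utils.py (file 3), with typed per-model configs under diffusers/configurations/. A minimal import sketch of the migration; the module paths come from the hunk above, but treat the exact public re-exports as assumptions:

# 0.7.4a4 (removed):
# from optimum.rbln.modeling_config import RBLNCompileConfig, RBLNConfig

# 0.7.4a5: a generic compile config plus a typed config per model
from optimum.rbln.configuration_utils import RBLNCompileConfig
from optimum.rbln.diffusers.configurations import RBLNUNet2DConditionModelConfig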
@@ -141,10 +142,13 @@ class _UNet_Kandinsky(torch.nn.Module):
 class RBLNUNet2DConditionModel(RBLNModel):
     hf_library_name = "diffusers"
     auto_model_class = UNet2DConditionModel
+    _rbln_config_class = RBLNUNet2DConditionModelConfig
+    output_class = UNet2DConditionOutput
+    output_key = "sample"

     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
-        self.in_features = self.rbln_config.model_cfg.get("in_features", None)
+        self.in_features = self.rbln_config.in_features
         if self.in_features is not None:

             @dataclass
@@ -158,7 +162,9 @@ class RBLNUNet2DConditionModel(RBLNModel):
         self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))

     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: RBLNUNet2DConditionModelConfig
+    ) -> torch.nn.Module:
         if model.config.addition_embed_type == "text_time":
             return _UNet_SDXL(model).eval()
         elif model.config.addition_embed_type == "image":
@@ -168,117 +174,117 @@ class RBLNUNet2DConditionModel(RBLNModel):

     @classmethod
     def get_unet_sample_size(
-        cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]
-    ) -> Union[int, Tuple[int, int]]:
-        image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
+        cls,
+        pipe: RBLNDiffusionMixin,
+        rbln_config: RBLNUNet2DConditionModelConfig,
+        image_size: Optional[Tuple[int, int]] = None,
+    ) -> Tuple[int, int]:
         scale_factor = pipe.movq_scale_factor if hasattr(pipe, "movq_scale_factor") else pipe.vae_scale_factor
-        if (image_size[0] is None) != (image_size[1] is None):
-            raise ValueError("Both image height and image width must be given or not given")
-        elif image_size[0] is None and image_size[1] is None:
-            if rbln_config["img2img_pipeline"]:
+
+        if image_size is None:
+            if "Img2Img" in pipe.__class__.__name__:
                 if hasattr(pipe, "vae"):
                     # In case of img2img, sample size of unet is determined by vae encoder.
                     vae_sample_size = pipe.vae.config.sample_size
                     if isinstance(vae_sample_size, int):
-                        sample_size = vae_sample_size // scale_factor
-                    else:
-                        sample_size = (
-                            vae_sample_size[0] // scale_factor,
-                            vae_sample_size[1] // scale_factor,
-                        )
+                        vae_sample_size = (vae_sample_size, vae_sample_size)
+
+                    sample_size = (
+                        vae_sample_size[0] // scale_factor,
+                        vae_sample_size[1] // scale_factor,
+                    )
                 elif hasattr(pipe, "movq"):
                     logger.warning(
-                        "RBLN config 'img_height' and 'img_width' should have been provided for this pipeline. "
+                        "RBLN config 'image_size' should have been provided for this pipeline. "
                         "Both variable will be set 512 by default."
                     )
                     sample_size = (512 // scale_factor, 512 // scale_factor)
             else:
                 sample_size = pipe.unet.config.sample_size
+                if isinstance(sample_size, int):
+                    sample_size = (sample_size, sample_size)
         else:
             sample_size = (image_size[0] // scale_factor, image_size[1] // scale_factor)

         return sample_size

     @classmethod
-    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
-        text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
-        image_model_hidden_size = pipe.unet.config.encoder_hid_dim if hasattr(pipe, "unet") else None
-
-        batch_size = rbln_config.get("batch_size")
-        if not batch_size:
-            do_classifier_free_guidance = (
-                rbln_config.get("guidance_scale", 5.0) > 1.0 and pipe.unet.config.time_cond_proj_dim is None
-            )
-            batch_size = 2 if do_classifier_free_guidance else 1
-        else:
-            if rbln_config.get("guidance_scale"):
-                logger.warning(
-                    "guidance_scale is ignored because batch size is explicitly specified. "
-                    "To ensure consistent behavior, consider removing the guidance scale or "
-                    "adjusting the batch size configuration as needed."
-                )
+    def update_rbln_config_using_pipe(
+        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
+        rbln_config.unet.text_model_hidden_size = (
+            pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
+        )
+        rbln_config.unet.image_model_hidden_size = pipe.unet.config.encoder_hid_dim if hasattr(pipe, "unet") else None

-        max_seq_len = pipe.text_encoder.config.max_position_embeddings if hasattr(pipe, "text_encoder") else None
-        rbln_config.update(
-            {
-                "max_seq_len": max_seq_len,
-                "text_model_hidden_size": text_model_hidden_size,
-                "image_model_hidden_size": image_model_hidden_size,
-                "sample_size": cls.get_unet_sample_size(pipe, rbln_config),
-                "batch_size": batch_size,
-                "is_controlnet": "controlnet" in pipe.config.keys(),
-            }
+        rbln_config.unet.max_seq_len = (
+            pipe.text_encoder.config.max_position_embeddings if hasattr(pipe, "text_encoder") else None
         )

+        rbln_config.unet.sample_size = cls.get_unet_sample_size(
+            pipe, rbln_config.unet, image_size=rbln_config.image_size
+        )
+        rbln_config.unet.use_additional_residuals = "controlnet" in pipe.config.keys()
+
         return rbln_config

     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
         model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        batch_size = rbln_kwargs.get("batch_size")
-        max_seq_len = rbln_kwargs.get("max_seq_len")
-        sample_size = rbln_kwargs.get("sample_size")
-        is_controlnet = rbln_kwargs.get("is_controlnet")
-        rbln_in_features = None
-
-        if batch_size is None:
-            batch_size = 1
-
-        if sample_size is None:
-            sample_size = model_config.sample_size
+        rbln_config: RBLNUNet2DConditionModelConfig,
+    ) -> RBLNUNet2DConditionModelConfig:
+        if rbln_config.sample_size is None:
+            rbln_config.sample_size = model_config.sample_size

-        if isinstance(sample_size, int):
-            sample_size = (sample_size, sample_size)
+        if isinstance(rbln_config.sample_size, int):
+            rbln_config.sample_size = (rbln_config.sample_size, rbln_config.sample_size)

         input_info = [
-            ("sample", [batch_size, model_config.in_channels, sample_size[0], sample_size[1]], "float32"),
+            (
+                "sample",
+                [
+                    rbln_config.batch_size,
+                    model_config.in_channels,
+                    rbln_config.sample_size[0],
+                    rbln_config.sample_size[1],
+                ],
+                "float32",
+            ),
             ("timestep", [], "float32"),
         ]

-        if max_seq_len is not None:
+        if rbln_config.max_seq_len is not None:
             input_info.append(
-                ("encoder_hidden_states", [batch_size, max_seq_len, model_config.cross_attention_dim], "float32"),
+                (
+                    "encoder_hidden_states",
+                    [rbln_config.batch_size, rbln_config.max_seq_len, model_config.cross_attention_dim],
+                    "float32",
+                ),
             )

-        if is_controlnet:
+        if rbln_config.use_additional_residuals:
             # down block addtional residuals
-            first_shape = [batch_size, model_config.block_out_channels[0], sample_size[0], sample_size[1]]
-            height, width = sample_size[0], sample_size[1]
+            first_shape = [
+                rbln_config.batch_size,
+                model_config.block_out_channels[0],
+                rbln_config.sample_size[0],
+                rbln_config.sample_size[1],
+            ]
+            height, width = rbln_config.sample_size[0], rbln_config.sample_size[1]
             input_info.append(("down_block_additional_residuals_0", first_shape, "float32"))
             name_idx = 1
             for idx, _ in enumerate(model_config.down_block_types):
-                shape = [batch_size, model_config.block_out_channels[idx], height, width]
+                shape = [rbln_config.batch_size, model_config.block_out_channels[idx], height, width]
                 for _ in range(model_config.layers_per_block):
                     input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
                     name_idx += 1
                 if idx != len(model_config.down_block_types) - 1:
                     height = height // 2
                     width = width // 2
-                    shape = [batch_size, model_config.block_out_channels[idx], height, width]
+                    shape = [rbln_config.batch_size, model_config.block_out_channels[idx], height, width]
                     input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
                     name_idx += 1

@@ -286,33 +292,27 @@ class RBLNUNet2DConditionModel(RBLNModel):
             num_cross_attn_blocks = model_config.down_block_types.count("CrossAttnDownBlock2D")
             out_channels = model_config.block_out_channels[-1]
             shape = [
-                batch_size,
+                rbln_config.batch_size,
                 out_channels,
-                sample_size[0] // 2**num_cross_attn_blocks,
-                sample_size[1] // 2**num_cross_attn_blocks,
+                rbln_config.sample_size[0] // 2**num_cross_attn_blocks,
+                rbln_config.sample_size[1] // 2**num_cross_attn_blocks,
             ]
             input_info.append(("mid_block_additional_residual", shape, "float32"))

         if hasattr(model_config, "addition_embed_type"):
             if model_config.addition_embed_type == "text_time":
-                rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
-                rbln_in_features = model_config.projection_class_embeddings_input_dim
-                input_info.append(("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32"))
-                input_info.append(("time_ids", [batch_size, 6], "float32"))
+                rbln_config.in_features = model_config.projection_class_embeddings_input_dim
+                input_info.append(
+                    ("text_embeds", [rbln_config.batch_size, rbln_config.text_model_hidden_size], "float32")
+                )
+                input_info.append(("time_ids", [rbln_config.batch_size, 6], "float32"))
             elif model_config.addition_embed_type == "image":
-                rbln_image_model_hidden_size = rbln_kwargs["image_model_hidden_size"]
-                input_info.append(("image_embeds", [batch_size, rbln_image_model_hidden_size], "float32"))
+                input_info.append(
+                    ("image_embeds", [rbln_config.batch_size, rbln_config.image_model_hidden_size], "float32")
+                )

         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        if rbln_in_features is not None:
-            rbln_config.model_cfg["in_features"] = rbln_in_features
+        rbln_config.set_compile_cfgs([rbln_compile_config])

         return rbln_config

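With this change, `_update_rbln_config` reads everything from a typed `RBLNUNet2DConditionModelConfig` instead of a loose `rbln_kwargs` dict. A hypothetical construction sketch; the keyword names mirror the attributes assigned in the hunks above, but the constructor signature itself is an assumption (it lives in configuration_unet_2d_condition.py, file 11, not shown here):

from optimum.rbln.diffusers.configurations import RBLNUNet2DConditionModelConfig

unet_config = RBLNUNet2DConditionModelConfig(
    batch_size=2,                    # compiled batch size (2 when CFG doubles the batch)
    sample_size=(64, 64),            # latent size, i.e. image size // vae_scale_factor
    max_seq_len=77,                  # text encoder max_position_embeddings
    use_additional_residuals=False,  # True when the pipeline carries a controlnet
)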
@@ -336,7 +336,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
         encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[UNet2DConditionOutput, Tuple]:
         sample_batch_size = sample.size()[0]
         compiled_batch_size = self.compiled_batch_size
         if sample_batch_size != compiled_batch_size and (
@@ -344,8 +344,8 @@ class RBLNUNet2DConditionModel(RBLNModel):
     ):
         raise ValueError(
             f"Mismatch between UNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
-            "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
-            "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
+            "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size of UNet in Stable Diffusion. "
+            "Adjust the batch size of UNet during compilation to match the runtime batch size.\n\n"
             "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
         )

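Classifier-free guidance runs the unconditional and conditional prompts through the UNet stacked along the batch axis, which is why a pipeline invoked with guidance_scale > 1.0 can double the UNet's runtime batch size relative to what was compiled. A small arithmetic sketch of choosing the compile-time batch size (values are illustrative):

num_prompts = 1
uses_cfg = True  # guidance_scale > 1.0
unet_batch_size = num_prompts * (2 if uses_cfg else 1)  # -> 2
# Compile the UNet with batch_size=unet_batch_size so the check above passes.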
@@ -353,31 +353,28 @@ class RBLNUNet2DConditionModel(RBLNModel):

         if down_block_additional_residuals is not None:
             down_block_additional_residuals = [t.contiguous() for t in down_block_additional_residuals]
-            return (
-                super().forward(
-                    sample.contiguous(),
-                    timestep.float(),
-                    encoder_hidden_states,
-                    *down_block_additional_residuals,
-                    mid_block_additional_residual,
-                    **added_cond_kwargs,
-                ),
+            return super().forward(
+                sample.contiguous(),
+                timestep.float(),
+                encoder_hidden_states,
+                *down_block_additional_residuals,
+                mid_block_additional_residual,
+                **added_cond_kwargs,
+                return_dict=return_dict,
             )

         if "image_embeds" in added_cond_kwargs:
-            return (
-                super().forward(
-                    sample.contiguous(),
-                    timestep.float(),
-                    **added_cond_kwargs,
-                ),
-            )
-
-        return (
-            super().forward(
+            return super().forward(
                 sample.contiguous(),
                 timestep.float(),
-                encoder_hidden_states,
                 **added_cond_kwargs,
-            ),
+                return_dict=return_dict,
+            )
+
+        return super().forward(
+            sample.contiguous(),
+            timestep.float(),
+            encoder_hidden_states,
+            **added_cond_kwargs,
+            return_dict=return_dict,
         )
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py CHANGED
@@ -81,7 +81,7 @@ class RBLNMultiControlNetModel(RBLNModel):
         model.save_pretrained(real_save_path)

     @classmethod
-    def _get_rbln_config(cls, **rbln_config_kwargs):
+    def _update_rbln_config(cls, **rbln_config_kwargs):
         pass

     def forward(
@@ -100,16 +100,15 @@ class RBLNMultiControlNetModel(RBLNModel):
         return_dict: bool = True,
     ):
         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
-            output = controlnet.model[0](
+            down_samples, mid_sample = controlnet(
                 sample=sample.contiguous(),
                 timestep=timestep.float(),
                 encoder_hidden_states=encoder_hidden_states,
                 controlnet_cond=image,
                 conditioning_scale=torch.tensor(scale),
+                return_dict=return_dict,
             )

-            down_samples, mid_sample = output[:-1], output[-1]
-
             # merge samples
             if i == 0:
                 down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py CHANGED
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+
 from diffusers import StableDiffusionPipeline

 from ...modeling_diffusers import RBLNDiffusionMixin
@@ -19,4 +20,4 @@ from ...modeling_diffusers import RBLNDiffusionMixin

 class RBLNStableDiffusionPipeline(RBLNDiffusionMixin, StableDiffusionPipeline):
     original_class = StableDiffusionPipeline
-    _submodules = ["text_encoder", "unet", "vae"]
+    _submodules = ["vae", "text_encoder", "unet"]
optimum/rbln/modeling.py CHANGED
@@ -14,15 +14,16 @@

 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

 import rebel
 import torch
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from transformers import AutoConfig, PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutput

+from .configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNModelConfig
 from .modeling_base import RBLNBaseModel
-from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, use_rbln_config
 from .utils.logging import get_logger


@@ -48,6 +49,9 @@ class RBLNModel(RBLNBaseModel):
     ```
     """

+    output_class = None
+    output_key = "last_hidden_state"
+
     @classmethod
     def update_kwargs(cls, kwargs):
         """
@@ -56,12 +60,7 @@ class RBLNModel(RBLNBaseModel):
        For example, `torchscript`=True should be set because torch.jit
        does not support `transformers` output instances as module output;
        """
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-            }
-        )
+        kwargs.update({"torchscript": True})
         return kwargs

     @classmethod
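As the docstring above says, torch.jit does not accept `transformers` output instances as module outputs, so the source model is loaded with torchscript=True and emits plain tensors/tuples; re-wrapping into output classes now happens at runtime in forward/_prepare_output (further down) instead of by forcing return_dict=False at load time. A minimal illustration of the flag (the model id is just an example):

from transformers import AutoModel

# torchscript=True makes the loaded model trace-friendly: it returns plain
# tuples instead of ModelOutput dataclasses when called.
model = AutoModel.from_pretrained("bert-base-uncased", torchscript=True)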
@@ -70,7 +69,7 @@ class RBLNModel(RBLNBaseModel):
         model: "PreTrainedModel",
         save_dir_path: Path,
         subfolder: str,
-        rbln_config: RBLNConfig,
+        rbln_config: RBLNModelConfig,
     ):
         """
         If you are unavoidably running on a CPU rather than an RBLN device,
@@ -78,30 +77,29 @@ class RBLNModel(RBLNBaseModel):
         """

     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         # Wrap the model if needed.
         return model

     @classmethod
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         model = cls.wrap_model_if_needed(model, rbln_config)
         rbln_compile_config = rbln_config.compile_cfgs[0]
         compiled_model = cls.compile(model, rbln_compile_config=rbln_compile_config)
         return compiled_model

     @classmethod
-    @use_rbln_config
     def from_model(
         cls,
         model: "PreTrainedModel",
         config: Optional[PretrainedConfig] = None,
-        rbln_config: Dict[str, Any] = {},
+        rbln_config: Optional[RBLNModelConfig] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         subfolder: str = "",
         **kwargs,
     ):
         preprocessors = kwargs.pop("preprocessors", [])
-        rbln_kwargs = rbln_config
+        rbln_config, kwargs = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)

         # Directory to save compile artifacts(.rbln) and original configs
         if model_save_dir is None:
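`from_model` now accepts a typed config object, and `prepare_rbln_config` folds any loose keyword arguments into it. A hypothetical call-side sketch; `RBLNBertModel`, `RBLNBertModelConfig`, and the `max_seq_len` field are assumed from the BERT files in the list above (files 40-42) and are not shown in this diff:

from transformers import AutoModel
from optimum.rbln import RBLNBertModel, RBLNBertModelConfig  # assumed exports

hf_model = AutoModel.from_pretrained("bert-base-uncased")
rbln_model = RBLNBertModel.from_model(
    hf_model,
    rbln_config=RBLNBertModelConfig(max_seq_len=128),  # hypothetical field
)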
@@ -141,14 +139,21 @@ class RBLNModel(RBLNBaseModel):
         for preprocessor in preprocessors:
             preprocessor.save_pretrained(save_dir_path / subfolder)

-        # ad-hoc
-        rbln_kwargs["n_model_params"] = sum(p.numel() for p in model.parameters())
+        # Load submodules
+        if len(cls._rbln_submodules) > 0:
+            rbln_submodules = cls._load_submodules(
+                model=model,
+                model_save_dir=save_dir,
+                rbln_config=rbln_config,
+                **kwargs,
+            )
+        else:
+            rbln_submodules = []

         # Get compilation arguments (e.g. input_info)
-        rbln_config: RBLNConfig = cls.get_rbln_config(
-            preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs
+        rbln_config: RBLNModelConfig = cls.update_rbln_config(
+            preprocessors=preprocessors, model=model, model_config=config, rbln_config=rbln_config
         )
-        # rbln_config.update_runtime_cfg(rbln_kwargs) # This is done in get_rbln_config

         compiled_model: Union[rebel.RBLNCompiledModel, Dict[str, rebel.RBLNCompiledModel]] = cls.get_compiled_model(
             model, rbln_config=rbln_config
@@ -167,17 +172,6 @@ class RBLNModel(RBLNBaseModel):
         # Save torch artifacts (e.g. embedding matrix if needed.)
         cls.save_torch_artifacts(model, save_dir_path=save_dir_path, subfolder=subfolder, rbln_config=rbln_config)

-        # Load submodules
-        if len(cls._rbln_submodules) > 0:
-            rbln_submodules = cls._load_submodules(
-                model=model,
-                model_save_dir=save_dir,
-                rbln_kwargs=rbln_kwargs,
-                **kwargs,
-            )
-        else:
-            rbln_submodules = []
-
         # Instantiate
         return cls._from_pretrained(
             model_id=save_dir_path,
@@ -201,8 +195,8 @@ class RBLNModel(RBLNBaseModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
-        # Some rbln-kwargs should be applied before loading torch module (i.e. quantized llm)
-        rbln_kwargs: Optional[Dict[str, Any]] = None,
+        # Some rbln-config should be applied before loading torch module (i.e. quantized llm)
+        rbln_config: Optional[RBLNModelConfig] = None,
         **kwargs,
     ) -> "PreTrainedModel":
         kwargs = cls.update_kwargs(kwargs)
@@ -222,18 +216,43 @@ class RBLNModel(RBLNBaseModel):
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_device_map: Dict[str, int],
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNModelConfig,
     ) -> List[rebel.Runtime]:
-        if DEFAULT_COMPILED_MODEL_NAME not in rbln_device_map:
+        if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
             cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])

-        device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [
-            compiled_model.create_runtime(tensor_type="pt", device=device, activate_profiler=activate_profiler)
+            rebel.Runtime(
+                compiled_model,
+                tensor_type="pt",
+                device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
+                activate_profiler=rbln_config.activate_profiler,
+            )
             for compiled_model in compiled_models
         ]

-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+    def forward(self, *args, return_dict: Optional[bool] = None, **kwargs):
+        if self.hf_library_name == "transformers":
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        else:
+            return_dict = True if return_dict is None else return_dict
+
+        # Get output from the model
         output = self.model[0](*args, **kwargs)
-        return output
+
+        # Format output according to task requirements
+        return self._prepare_output(output, return_dict)
+
+    def _prepare_output(self, output, return_dict):
+        """
+        Prepare model output based on return_dict flag.
+        This method can be overridden by subclasses to provide task-specific output handling.
+        """
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            if self.output_class is None:
+                return BaseModelOutput(last_hidden_state=output)
+
+            # Create output with the appropriate class and key
+            return self.output_class(**{self.output_key: output})
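`_prepare_output` gives every RBLNModel a uniform return contract: a plain tuple when return_dict=False, otherwise an output dataclass selected by the subclass through output_class/output_key (UNet2DConditionOutput with key "sample" above; BaseModelOutput with key "last_hidden_state" by default). A small sketch mirroring the default branch, with a stand-in tensor in place of the runtime output:

import torch
from transformers.modeling_outputs import BaseModelOutput

raw = torch.zeros(1, 4, 8)  # stand-in for self.model[0](*args, **kwargs)

as_tuple = (raw,)                                   # return_dict=False
as_output = BaseModelOutput(last_hidden_state=raw)  # return_dict=True, output_class=None

assert as_output.last_hidden_state is as_tuple[0]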