optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. optimum/rbln/__init__.py +27 -13
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +85 -65
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
  13. optimum/rbln/diffusers/pipelines/__init__.py +60 -12
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  31. optimum/rbln/modeling.py +572 -0
  32. optimum/rbln/modeling_alias.py +1 -1
  33. optimum/rbln/modeling_base.py +176 -763
  34. optimum/rbln/modeling_diffusers.py +329 -0
  35. optimum/rbln/transformers/__init__.py +2 -2
  36. optimum/rbln/transformers/cache_utils.py +5 -9
  37. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  38. optimum/rbln/transformers/models/__init__.py +80 -31
  39. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  40. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
  44. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
  45. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
  46. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
  47. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  48. optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
  49. optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
  50. optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
  51. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  53. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
  54. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  55. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
  56. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
  57. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
  58. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  59. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  60. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
  61. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  62. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  63. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  64. optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
  65. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  66. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  67. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  68. optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
  69. optimum/rbln/utils/decorator_utils.py +59 -0
  70. optimum/rbln/utils/hub.py +131 -0
  71. optimum/rbln/utils/import_utils.py +21 -0
  72. optimum/rbln/utils/model_utils.py +53 -0
  73. optimum/rbln/utils/runtime_utils.py +5 -5
  74. optimum/rbln/utils/submodule.py +114 -0
  75. optimum/rbln/utils/timer_utils.py +2 -2
  76. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  77. optimum_rbln-0.1.15.dist-info/RECORD +110 -0
  78. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  79. optimum/rbln/transformers/generation/streamers.py +0 -139
  80. optimum/rbln/transformers/generation/utils.py +0 -397
  81. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  82. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  83. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  84. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  85. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  86. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  87. optimum_rbln-0.1.12.dist-info/METADATA +0 -119
  88. optimum_rbln-0.1.12.dist-info/RECORD +0 -103
  89. optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
  90. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
@@ -23,16 +23,15 @@
23
23
 
24
24
  import logging
25
25
  from dataclasses import dataclass
26
- from pathlib import Path
27
26
  from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
28
27
 
29
28
  import torch
30
29
  from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
31
- from optimum.exporters import TasksManager
32
- from transformers import AutoConfig, AutoModel, PretrainedConfig
30
+ from transformers import PretrainedConfig
33
31
 
34
- from ...modeling_base import RBLNModel
35
- from ...modeling_config import RBLNCompileConfig, RBLNConfig
32
+ from ....modeling import RBLNModel
33
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
34
+ from ....modeling_diffusers import RBLNDiffusionMixin
36
35
 
37
36
 
38
37
  if TYPE_CHECKING:
@@ -126,6 +125,9 @@ class _UNet_SDXL(torch.nn.Module):
126
125
 
127
126
 
128
127
  class RBLNUNet2DConditionModel(RBLNModel):
128
+ hf_library_name = "diffusers"
129
+ auto_model_class = UNet2DConditionModel
130
+
129
131
  def __post_init__(self, **kwargs):
130
132
  super().__post_init__(**kwargs)
131
133
  self.in_features = self.rbln_config.model_cfg.get("in_features", None)
@@ -141,33 +143,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
141
143
 
142
144
  self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))
143
145
 
144
- @classmethod
145
- def from_pretrained(cls, *args, **kwargs):
146
- def get_model_from_task(
147
- task: str,
148
- model_name_or_path: Union[str, Path],
149
- **kwargs,
150
- ):
151
- return UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
152
-
153
- tasktmp = TasksManager.get_model_from_task
154
- configtmp = AutoConfig.from_pretrained
155
- modeltmp = AutoModel.from_pretrained
156
- TasksManager.get_model_from_task = get_model_from_task
157
- if kwargs.get("export", None):
158
- # This is an ad-hoc to workaround save null values of the config.
159
- # if export, pure optimum(not optimum-rbln) loads config using AutoConfig
160
- # and diffusers model do not support loading by AutoConfig.
161
- AutoConfig.from_pretrained = lambda *args, **kwargs: None
162
- else:
163
- AutoConfig.from_pretrained = UNet2DConditionModel.load_config
164
- AutoModel.from_pretrained = UNet2DConditionModel.from_pretrained
165
- rt = super().from_pretrained(*args, **kwargs)
166
- AutoConfig.from_pretrained = configtmp
167
- AutoModel.from_pretrained = modeltmp
168
- TasksManager.get_model_from_task = tasktmp
169
- return rt
170
-
171
146
  @classmethod
172
147
  def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
173
148
  if model.config.addition_embed_type == "text_time":
@@ -175,6 +150,61 @@ class RBLNUNet2DConditionModel(RBLNModel):
175
150
  else:
176
151
  return _UNet_SD(model).eval()
177
152
 
153
+ @classmethod
154
+ def get_unet_sample_size(
155
+ cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]
156
+ ) -> Union[int, Tuple[int, int]]:
157
+ image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
158
+ if (image_size[0] is None) != (image_size[1] is None):
159
+ raise ValueError("Both image height and image width must be given or not given")
160
+ elif image_size[0] is None and image_size[1] is None:
161
+ if rbln_config["img2img_pipeline"]:
162
+ # In case of img2img, sample size of unet is determined by vae encoder.
163
+ vae_sample_size = pipe.vae.config.sample_size
164
+ if isinstance(vae_sample_size, int):
165
+ sample_size = vae_sample_size // pipe.vae_scale_factor
166
+ else:
167
+ sample_size = (
168
+ vae_sample_size[0] // pipe.vae_scale_factor,
169
+ vae_sample_size[1] // pipe.vae_scale_factor,
170
+ )
171
+ else:
172
+ sample_size = pipe.unet.config.sample_size
173
+ else:
174
+ sample_size = (image_size[0] // pipe.vae_scale_factor, image_size[1] // pipe.vae_scale_factor)
175
+
176
+ return sample_size
177
+
178
+ @classmethod
179
+ def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
180
+ text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
181
+
182
+ batch_size = rbln_config.get("batch_size")
183
+ if not batch_size:
184
+ do_classifier_free_guidance = (
185
+ rbln_config.get("guidance_scale", 5.0) > 1.0 and pipe.unet.config.time_cond_proj_dim is None
186
+ )
187
+ batch_size = 2 if do_classifier_free_guidance else 1
188
+ else:
189
+ if rbln_config.get("guidance_scale"):
190
+ logger.warning(
191
+ "guidance_scale is ignored because batch size is explicitly specified. "
192
+ "To ensure consistent behavior, consider removing the guidance scale or "
193
+ "adjusting the batch size configuration as needed."
194
+ )
195
+
196
+ rbln_config.update(
197
+ {
198
+ "max_seq_len": pipe.text_encoder.config.max_position_embeddings,
199
+ "text_model_hidden_size": text_model_hidden_size,
200
+ "sample_size": cls.get_unet_sample_size(pipe, rbln_config),
201
+ "batch_size": batch_size,
202
+ "is_controlnet": "controlnet" in pipe.config.keys(),
203
+ }
204
+ )
205
+
206
+ return rbln_config
207
+
178
208
  @classmethod
179
209
  def _get_rbln_config(
180
210
  cls,
@@ -182,137 +212,68 @@ class RBLNUNet2DConditionModel(RBLNModel):
182
212
  model_config: "PretrainedConfig",
183
213
  rbln_kwargs: Dict[str, Any] = {},
184
214
  ) -> RBLNConfig:
185
- rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
186
- rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
187
- rbln_batch_size = rbln_kwargs.get("batch_size", None)
188
- rbln_in_features = rbln_kwargs.get("in_features", None)
189
- rbln_use_encode = rbln_kwargs.get("use_encode", None)
190
- rbln_img_width = rbln_kwargs.get("img_width", None)
191
- rbln_img_height = rbln_kwargs.get("img_height", None)
192
- rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
193
- rbln_is_controlnet = rbln_kwargs.get("is_controlnet", None)
194
-
195
- if rbln_max_seq_len is None:
196
- rbln_max_seq_len = 77
197
- if rbln_batch_size is None:
198
- rbln_batch_size = 1
199
-
200
- if rbln_use_encode:
201
- if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
202
- raise ValueError(
203
- "rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided when rbln_use_encode is True"
204
- )
205
- input_width = rbln_img_width // rbln_vae_scale_factor
206
- input_height = rbln_img_height // rbln_vae_scale_factor
207
- else:
208
- input_width, input_height = model_config.sample_size, model_config.sample_size
215
+ batch_size = rbln_kwargs.get("batch_size")
216
+ max_seq_len = rbln_kwargs.get("max_seq_len")
217
+ sample_size = rbln_kwargs.get("sample_size")
218
+ is_controlnet = rbln_kwargs.get("is_controlnet")
219
+ rbln_in_features = None
220
+
221
+ if batch_size is None:
222
+ batch_size = 1
223
+
224
+ if sample_size is None:
225
+ sample_size = model_config.sample_size
226
+
227
+ if isinstance(sample_size, int):
228
+ sample_size = (sample_size, sample_size)
229
+
230
+ if max_seq_len is None:
231
+ raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings) must be specified.")
209
232
 
210
233
  input_info = [
211
- (
212
- "sample",
213
- [
214
- rbln_batch_size,
215
- model_config.in_channels,
216
- input_height,
217
- input_width,
218
- ],
219
- "float32",
220
- ),
234
+ ("sample", [batch_size, model_config.in_channels, sample_size[0], sample_size[1]], "float32"),
221
235
  ("timestep", [], "float32"),
222
- (
223
- "encoder_hidden_states",
224
- [
225
- rbln_batch_size,
226
- rbln_max_seq_len,
227
- model_config.cross_attention_dim,
228
- ],
229
- "float32",
230
- ),
236
+ ("encoder_hidden_states", [batch_size, max_seq_len, model_config.cross_attention_dim], "float32"),
231
237
  ]
232
238
 
233
- if rbln_is_controlnet:
234
- if len(model_config.block_out_channels) > 0:
235
- input_info.extend(
236
- [
237
- (
238
- f"down_block_additional_residuals_{i}",
239
- [rbln_batch_size, model_config.block_out_channels[0], input_height, input_width],
240
- "float32",
241
- )
242
- for i in range(3)
243
- ]
244
- )
245
- if len(model_config.block_out_channels) > 1:
246
- input_info.append(
247
- (
248
- "down_block_additional_residuals_3",
249
- [rbln_batch_size, model_config.block_out_channels[0], input_height // 2, input_width // 2],
250
- "float32",
251
- )
252
- )
253
- input_info.extend(
254
- [
255
- (
256
- f"down_block_additional_residuals_{i}",
257
- [rbln_batch_size, model_config.block_out_channels[1], input_height // 2, input_width // 2],
258
- "float32",
259
- )
260
- for i in range(4, 6)
261
- ]
262
- )
263
- if len(model_config.block_out_channels) > 2:
264
- input_info.append(
265
- (
266
- f"down_block_additional_residuals_{6}",
267
- [rbln_batch_size, model_config.block_out_channels[1], input_height // 4, input_width // 4],
268
- "float32",
269
- )
270
- )
271
- input_info.extend(
272
- [
273
- (
274
- f"down_block_additional_residuals_{i}",
275
- [rbln_batch_size, model_config.block_out_channels[2], input_height // 4, input_width // 4],
276
- "float32",
277
- )
278
- for i in range(7, 9)
279
- ]
280
- )
281
- if len(model_config.block_out_channels) > 3:
282
- input_info.extend(
283
- [
284
- (
285
- f"down_block_additional_residuals_{i}",
286
- [rbln_batch_size, model_config.block_out_channels[3], input_height // 8, input_width // 8],
287
- "float32",
288
- )
289
- for i in range(9, 12)
290
- ]
291
- )
292
- input_info.append(
293
- (
294
- "mid_block_additional_residual",
295
- [
296
- rbln_batch_size,
297
- model_config.block_out_channels[-1],
298
- input_height // 2 ** (len(model_config.block_out_channels) - 1),
299
- input_width // 2 ** (len(model_config.block_out_channels) - 1),
300
- ],
301
- "float32",
302
- )
303
- )
239
+ if is_controlnet:
240
+ # down block additional residuals
241
+ first_shape = [batch_size, model_config.block_out_channels[0], sample_size[0], sample_size[1]]
242
+ height, width = sample_size[0], sample_size[1]
243
+ input_info.append(("down_block_additional_residuals_0", first_shape, "float32"))
244
+ name_idx = 1
245
+ for idx, _ in enumerate(model_config.down_block_types):
246
+ shape = [batch_size, model_config.block_out_channels[idx], height, width]
247
+ for _ in range(model_config.layers_per_block):
248
+ input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
249
+ name_idx += 1
250
+ if idx != len(model_config.down_block_types) - 1:
251
+ height = height // 2
252
+ width = width // 2
253
+ shape = [batch_size, model_config.block_out_channels[idx], height, width]
254
+ input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
255
+ name_idx += 1
256
+
257
+ # mid block additional residual
258
+ num_cross_attn_blocks = model_config.down_block_types.count("CrossAttnDownBlock2D")
259
+ out_channels = model_config.block_out_channels[-1]
260
+ shape = [
261
+ batch_size,
262
+ out_channels,
263
+ sample_size[0] // 2**num_cross_attn_blocks,
264
+ sample_size[1] // 2**num_cross_attn_blocks,
265
+ ]
266
+ input_info.append(("mid_block_additional_residual", shape, "float32"))
304
267
 
305
268
  rbln_compile_config = RBLNCompileConfig(input_info=input_info)
306
269
 
307
270
  if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
308
- if rbln_text_model_hidden_size is None:
309
- rbln_text_model_hidden_size = 768
310
- if rbln_in_features is None:
311
- rbln_in_features = model_config.projection_class_embeddings_input_dim
271
+ rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
272
+ rbln_in_features = model_config.projection_class_embeddings_input_dim
312
273
  rbln_compile_config.input_info.append(
313
- ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
274
+ ("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32")
314
275
  )
315
- rbln_compile_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
276
+ rbln_compile_config.input_info.append(("time_ids", [batch_size, 6], "float32"))
316
277
 
317
278
  rbln_config = RBLNConfig(
318
279
  rbln_cls=cls.__name__,
@@ -320,19 +281,15 @@ class RBLNUNet2DConditionModel(RBLNModel):
320
281
  rbln_kwargs=rbln_kwargs,
321
282
  )
322
283
 
323
- rbln_config.model_cfg.update(
324
- {
325
- "max_seq_len": rbln_max_seq_len,
326
- "batch_size": rbln_batch_size,
327
- "use_encode": rbln_use_encode,
328
- }
329
- )
330
-
331
284
  if rbln_in_features is not None:
332
285
  rbln_config.model_cfg["in_features"] = rbln_in_features
333
286
 
334
287
  return rbln_config
335
288
 
289
+ @property
290
+ def compiled_batch_size(self):
291
+ return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
292
+
336
293
  def forward(
337
294
  self,
338
295
  sample: torch.Tensor,
@@ -350,9 +307,18 @@ class RBLNUNet2DConditionModel(RBLNModel):
350
307
  return_dict: bool = True,
351
308
  **kwargs,
352
309
  ):
353
- """
354
- arg order : latent_model_input, t, prompt_embeds
355
- """
310
+ sample_batch_size = sample.size()[0]
311
+ compiled_batch_size = self.compiled_batch_size
312
+ if sample_batch_size != compiled_batch_size and (
313
+ sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
314
+ ):
315
+ raise ValueError(
316
+ f"Mismatch between UNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
317
+ "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
318
+ "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
319
+ "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
320
+ )
321
+
356
322
  added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
357
323
 
358
324
  if down_block_additional_residuals is not None:
@@ -20,16 +20,64 @@
20
20
  # are the intellectual property of Rebellions Inc. and may not be
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
+ from typing import TYPE_CHECKING
23
24
 
24
- from .controlnet import (
25
- RBLNMultiControlNetModel,
26
- RBLNStableDiffusionControlNetImg2ImgPipeline,
27
- RBLNStableDiffusionControlNetPipeline,
28
- RBLNStableDiffusionXLControlNetImg2ImgPipeline,
29
- RBLNStableDiffusionXLControlNetPipeline,
30
- )
31
- from .stable_diffusion import (
32
- RBLNStableDiffusionImg2ImgPipeline,
33
- RBLNStableDiffusionPipeline,
34
- )
35
- from .stable_diffusion_xl import RBLNStableDiffusionXLImg2ImgPipeline, RBLNStableDiffusionXLPipeline
25
+ from transformers.utils import _LazyModule
26
+
27
+
28
+ _import_structure = {
29
+ "controlnet": [
30
+ "RBLNMultiControlNetModel",
31
+ "RBLNStableDiffusionControlNetImg2ImgPipeline",
32
+ "RBLNStableDiffusionControlNetPipeline",
33
+ "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
34
+ "RBLNStableDiffusionXLControlNetPipeline",
35
+ ],
36
+ "stable_diffusion": [
37
+ "RBLNStableDiffusionImg2ImgPipeline",
38
+ "RBLNStableDiffusionPipeline",
39
+ "RBLNStableDiffusionInpaintPipeline",
40
+ ],
41
+ "stable_diffusion_xl": [
42
+ "RBLNStableDiffusionXLImg2ImgPipeline",
43
+ "RBLNStableDiffusionXLPipeline",
44
+ "RBLNStableDiffusionXLInpaintPipeline",
45
+ ],
46
+ "stable_diffusion_3": [
47
+ "RBLNStableDiffusion3Pipeline",
48
+ "RBLNStableDiffusion3Img2ImgPipeline",
49
+ "RBLNStableDiffusion3InpaintPipeline",
50
+ ],
51
+ }
52
+ if TYPE_CHECKING:
53
+ from .controlnet import (
54
+ RBLNMultiControlNetModel,
55
+ RBLNStableDiffusionControlNetImg2ImgPipeline,
56
+ RBLNStableDiffusionControlNetPipeline,
57
+ RBLNStableDiffusionXLControlNetImg2ImgPipeline,
58
+ RBLNStableDiffusionXLControlNetPipeline,
59
+ )
60
+ from .stable_diffusion import (
61
+ RBLNStableDiffusionImg2ImgPipeline,
62
+ RBLNStableDiffusionInpaintPipeline,
63
+ RBLNStableDiffusionPipeline,
64
+ )
65
+ from .stable_diffusion_3 import (
66
+ RBLNStableDiffusion3Img2ImgPipeline,
67
+ RBLNStableDiffusion3InpaintPipeline,
68
+ RBLNStableDiffusion3Pipeline,
69
+ )
70
+ from .stable_diffusion_xl import (
71
+ RBLNStableDiffusionXLImg2ImgPipeline,
72
+ RBLNStableDiffusionXLInpaintPipeline,
73
+ RBLNStableDiffusionXLPipeline,
74
+ )
75
+ else:
76
+ import sys
77
+
78
+ sys.modules[__name__] = _LazyModule(
79
+ __name__,
80
+ globals()["__file__"],
81
+ _import_structure,
82
+ module_spec=__spec__,
83
+ )
@@ -27,12 +27,9 @@ from pathlib import Path
27
27
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
28
28
 
29
29
  import torch
30
- from diffusers import ControlNetModel
31
30
  from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
32
- from optimum.exporters import TasksManager
33
- from transformers import AutoConfig, AutoModel
34
31
 
35
- from ....modeling_base import RBLNModel
32
+ from ....modeling import RBLNModel
36
33
  from ....modeling_config import RBLNConfig
37
34
  from ...models.controlnet import RBLNControlNetModel
38
35
 
@@ -44,6 +41,9 @@ logger = logging.getLogger(__name__)
44
41
 
45
42
 
46
43
  class RBLNMultiControlNetModel(RBLNModel):
44
+ hf_library_name = "diffusers"
45
+ _hf_class = MultiControlNetModel
46
+
47
47
  def __init__(
48
48
  self,
49
49
  models: List[RBLNControlNetModel],
@@ -52,26 +52,12 @@ class RBLNMultiControlNetModel(RBLNModel):
52
52
  self.nets = models
53
53
  self.dtype = torch.float32
54
54
 
55
- @classmethod
56
- def from_pretrained(cls, *args, **kwargs):
57
- def get_model_from_task(
58
- task: str,
59
- model_name_or_path: Union[str, Path],
60
- **kwargs,
61
- ):
62
- return MultiControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
63
-
64
- tasktmp = TasksManager.get_model_from_task
65
- configtmp = AutoConfig.from_pretrained
66
- modeltmp = AutoModel.from_pretrained
67
- TasksManager.get_model_from_task = get_model_from_task
68
- AutoConfig.from_pretrained = ControlNetModel.load_config
69
- AutoModel.from_pretrained = MultiControlNetModel.from_pretrained
70
- rt = super().from_pretrained(*args, **kwargs)
71
- AutoConfig.from_pretrained = configtmp
72
- AutoModel.from_pretrained = modeltmp
73
- TasksManager.get_model_from_task = tasktmp
74
- return rt
55
+ @property
56
+ def compiled_models(self):
57
+ cm = []
58
+ for net in self.nets:
59
+ cm.extend(net.compiled_models)
60
+ return cm
75
61
 
76
62
  @classmethod
77
63
  def _from_pretrained(
@@ -111,7 +97,7 @@ class RBLNMultiControlNetModel(RBLNModel):
111
97
  sample: torch.FloatTensor,
112
98
  timestep: Union[torch.Tensor, float, int],
113
99
  encoder_hidden_states: torch.Tensor,
114
- controlnet_cond: List[torch.tensor],
100
+ controlnet_cond: List[torch.Tensor],
115
101
  conditioning_scale: List[float],
116
102
  class_labels: Optional[torch.Tensor] = None,
117
103
  timestep_cond: Optional[torch.Tensor] = None,