diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
--- a/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+++ b/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -367,6 +367,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
             width // self.vae_scale_factor_spatial,
         )
 
+        # For CogVideoX 1.5, add padding frames to the latent shape (the padded frames are not used)
+        if self.transformer.config.patch_size_t is not None:
+            shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]
+
         image = image.unsqueeze(2)  # [B, C, F, H, W]
 
         if isinstance(generator, list):
@@ -377,7 +381,13 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
             image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
 
         image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-        image_latents = self.vae_scaling_factor_image * image_latents
+
+        if not self.vae.config.invert_scale_latents:
+            image_latents = self.vae_scaling_factor_image * image_latents
+        else:
+            # This is awkward but required because the CogVideoX team forgot to multiply the
+            # scaling factor during training :)
+            image_latents = 1 / self.vae_scaling_factor_image * image_latents
 
         padding_shape = (
             batch_size,
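The `invert_scale_latents` branch matters when mixing checkpoints: judging from the comment above, the CogVideoX 1.5 VAE checkpoints ship with the flag set, so their image latents are divided rather than multiplied by the scaling factor. A minimal way to see which branch a given checkpoint takes (the repo id is illustrative, and `getattr` guards older configs that predate the flag):

```python
import torch
from diffusers import CogVideoXImageToVideoPipeline

# Illustrative checkpoint; any CogVideoX image-to-video repo works here.
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX1.5-5B-I2V", torch_dtype=torch.bfloat16
)

# Older VAE configs may not define the flag at all, hence the getattr default.
if getattr(pipe.vae.config, "invert_scale_latents", False):
    print("image latents will be divided by the scaling factor")
else:
    print("image latents will be multiplied by the scaling factor")
```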
@@ -386,9 +396,15 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
             height // self.vae_scale_factor_spatial,
             width // self.vae_scale_factor_spatial,
         )
+
         latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
         image_latents = torch.cat([image_latents, latent_padding], dim=1)
 
+        # Select the first frame along the second dimension
+        if self.transformer.config.patch_size_t is not None:
+            first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
+            image_latents = torch.cat([first_frame, image_latents], dim=1)
+
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
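The first-frame trick above is easiest to see on a dummy tensor. A minimal sketch (no VAE involved) of how an odd latent frame count becomes divisible by an assumed `patch_size_t = 2`:

```python
import torch

patch_size_t = 2  # assumed CogVideoX 1.5 value
image_latents = torch.zeros(1, 13, 16, 60, 90)  # [B, F, C, H, W] with an odd F

# 13 % 2 == 1, so one copy of the first frame is prepended.
first_frame = image_latents[:, : image_latents.size(1) % patch_size_t, ...]
image_latents = torch.cat([first_frame, image_latents], dim=1)

print(image_latents.shape)  # torch.Size([1, 14, 16, 60, 90])
```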
@@ -522,21 +538,39 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
-        grid_crops_coords = get_resize_crop_region_for_grid(
-            (grid_height, grid_width), base_size_width, base_size_height
-        )
-        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
-            embed_dim=self.transformer.config.attention_head_dim,
-            crops_coords=grid_crops_coords,
-            grid_size=(grid_height, grid_width),
-            temporal_size=num_frames,
-        )
+        p = self.transformer.config.patch_size
+        p_t = self.transformer.config.patch_size_t
+
+        base_size_width = self.transformer.config.sample_width // p
+        base_size_height = self.transformer.config.sample_height // p
+
+        if p_t is None:
+            # CogVideoX 1.0
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+                device=device,
+            )
+        else:
+            # CogVideoX 1.5
+            base_num_frames = (num_frames + p_t - 1) // p_t
+
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+                device=device,
+            )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property
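In the 1.5 branch, `get_3d_rotary_pos_embed` receives `grid_type="slice"` with a `max_size`, which (going by the new arguments) computes frequencies over the base grid and slices out the region it needs instead of resize-cropping, and the temporal axis is ceil-divided by `patch_size_t`. The ceil division is the one piece of arithmetic worth double-checking:

```python
# Ceil division used for the temporal RoPE size in the CogVideoX 1.5 branch.
def base_num_frames(num_frames: int, p_t: int) -> int:
    return (num_frames + p_t - 1) // p_t

assert base_num_frames(13, 2) == 7  # 13 latent frames -> 7 temporal positions
assert base_num_frames(14, 2) == 7  # padding 13 up to 14 adds no position
```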
@@ -562,8 +596,8 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
         image: PipelineImageInput,
         prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        height: int = 480,
-        width: int = 720,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         num_frames: int = 49,
         num_inference_steps: int = 50,
         timesteps: Optional[List[int]] = None,
@@ -666,14 +700,13 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
             `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
 
-        if num_frames > 49:
-            raise ValueError(
-                "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
-            )
-
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
+        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+        num_frames = num_frames or self.transformer.config.sample_frames
+
         num_videos_per_prompt = 1
 
         # 1. Check inputs. Raise error if not correct
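With the defaults now read from the transformer config, the old hard-coded 480×720 is still recovered for CogVideoX 1.0 checkpoints, whose configs (as assumed below) carry `sample_height=60` and `sample_width=90` with a spatial VAE factor of 8, while 1.5 checkpoints get their own native size:

```python
# Assumed CogVideoX 1.0 config values; 1.5 checkpoints carry different ones.
sample_height, sample_width, vae_scale_factor_spatial = 60, 90, 8

height = sample_height * vae_scale_factor_spatial  # 480
width = sample_width * vae_scale_factor_spatial    # 720
print(height, width)  # 480 720
```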
@@ -726,6 +759,15 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
         self._num_timesteps = len(timesteps)
 
         # 5. Prepare latents
+        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+        # For CogVideoX 1.5, pad the latent frames so the count is divisible by patch_size_t
+        patch_size_t = self.transformer.config.patch_size_t
+        additional_frames = 0
+        if patch_size_t is not None and latent_frames % patch_size_t != 0:
+            additional_frames = patch_size_t - latent_frames % patch_size_t
+            num_frames += additional_frames * self.vae_scale_factor_temporal
+
         image = self.video_processor.preprocess(image, height=height, width=width).to(
             device, dtype=prompt_embeds.dtype
         )
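Because `num_frames` is bumped before preprocessing, a 49-frame request on a 1.5 checkpoint is actually rendered with extra frames that are sliced off after decoding. A hypothetical helper (not part of diffusers) reproducing the bookkeeping, assuming the usual temporal VAE factor of 4:

```python
from typing import Optional, Tuple

def padded_frame_counts(
    num_frames: int, patch_size_t: Optional[int], vae_scale_factor_temporal: int = 4
) -> Tuple[int, int]:
    """Return (additional latent frames, adjusted num_frames); illustrative only."""
    latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
    additional = 0
    if patch_size_t is not None and latent_frames % patch_size_t != 0:
        additional = patch_size_t - latent_frames % patch_size_t
    return additional, num_frames + additional * vae_scale_factor_temporal

print(padded_frame_counts(49, None))  # (0, 49): CogVideoX 1.0, unchanged
print(padded_frame_counts(49, 2))     # (1, 53): CogVideoX 1.5, one padded latent frame
```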
@@ -754,6 +796,9 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
             else None
         )
 
+        # 8. Create ofs embeds if required
+        ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
+
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
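`Tensor.new_full` creates the ofs embedding on the same device and with the same dtype as the latents without spelling either out; the fill value of 2.0 is simply what the pipeline hard-codes. A one-liner illustration:

```python
import torch

latents = torch.randn(1, 14, 16, 60, 90, dtype=torch.bfloat16)
ofs_emb = latents.new_full((1,), fill_value=2.0)

print(ofs_emb, ofs_emb.dtype)  # tensor([2.], dtype=torch.bfloat16) torch.bfloat16
```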
@@ -778,6 +823,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
+                    ofs=ofs_emb,
                     image_rotary_emb=image_rotary_emb,
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
@@ -823,6 +869,8 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
                 progress_bar.update()
 
         if not output_type == "latent":
+            # Discard any padding frames that were added for CogVideoX 1.5
+            latents = latents[:, additional_frames:]
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
         else:
--- a/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -373,12 +373,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
 
         if latents is None:
             if isinstance(generator, list):
-                if len(generator) != batch_size:
-                    raise ValueError(
-                        f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                        f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                    )
-
                 init_latents = [
                     retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
                 ]
@@ -518,21 +512,39 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
-        grid_crops_coords = get_resize_crop_region_for_grid(
-            (grid_height, grid_width), base_size_width, base_size_height
-        )
-        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
-            embed_dim=self.transformer.config.attention_head_dim,
-            crops_coords=grid_crops_coords,
-            grid_size=(grid_height, grid_width),
-            temporal_size=num_frames,
-        )
+        p = self.transformer.config.patch_size
+        p_t = self.transformer.config.patch_size_t
+
+        base_size_width = self.transformer.config.sample_width // p
+        base_size_height = self.transformer.config.sample_height // p
+
+        if p_t is None:
+            # CogVideoX 1.0
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+                device=device,
+            )
+        else:
+            # CogVideoX 1.5
+            base_num_frames = (num_frames + p_t - 1) // p_t
+
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+                device=device,
+            )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property
@@ -558,8 +570,8 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
         video: List[Image.Image] = None,
         prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        height: int = 480,
-        width: int = 720,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         num_inference_steps: int = 50,
         timesteps: Optional[List[int]] = None,
         strength: float = 0.8,
@@ -662,6 +674,10 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
+        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+        num_frames = len(video) if latents is None else latents.size(1)
+
         num_videos_per_prompt = 1
 
         # 1. Check inputs. Raise error if not correct
@@ -717,6 +733,16 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
         self._num_timesteps = len(timesteps)
 
         # 5. Prepare latents
+        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+        # For CogVideoX 1.5, the latent frame count must already be divisible by patch_size_t
+        patch_size_t = self.transformer.config.patch_size_t
+        if patch_size_t is not None and latent_frames % patch_size_t != 0:
+            raise ValueError(
+                f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video "
+                f"contains {latent_frames=}, which is not divisible."
+            )
+
         if latents is None:
             video = self.video_processor.preprocess_video(video, height=height, width=width)
             video = video.to(device=device, dtype=prompt_embeds.dtype)
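Unlike the image-to-video pipeline, video-to-video cannot invent padding frames for a clip the user supplied, so it raises instead. When trimming a source video for a 1.5 checkpoint, pick a length whose latent frame count is already divisible; a quick check mirroring the `ValueError` condition, with the same assumed factors as above:

```python
def is_valid_v2v_length(
    num_frames: int, patch_size_t: int = 2, vae_scale_factor_temporal: int = 4
) -> bool:
    """Illustrative mirror of the pipeline's divisibility check."""
    latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
    return latent_frames % patch_size_t == 0

print(is_valid_v2v_length(49))  # False: 13 latent frames
print(is_valid_v2v_length(53))  # True: 14 latent frames
```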
--- a/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -38,7 +38,7 @@ EXAMPLE_DOC_STRING = """
         >>> import torch
         >>> from diffusers import CogView3PlusPipeline
 
-        >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3B", torch_dtype=torch.bfloat16)
+        >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
 
         >>> prompt = "A photo of an astronaut riding a horse on mars"
--- a/diffusers/pipelines/controlnet/__init__.py
+++ b/diffusers/pipelines/controlnet/__init__.py
(The registry renders this file as fully removed and re-added — 80 lines before, 86 after — but the two versions are identical apart from the registration of the three new ControlNet Union pipelines. The substantive change reduces to:)
@@ -31,3 +31,6 @@
     _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
     _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
+    _import_structure["pipeline_controlnet_union_inpaint_sd_xl"] = ["StableDiffusionXLControlNetUnionInpaintPipeline"]
+    _import_structure["pipeline_controlnet_union_sd_xl"] = ["StableDiffusionXLControlNetUnionPipeline"]
+    _import_structure["pipeline_controlnet_union_sd_xl_img2img"] = ["StableDiffusionXLControlNetUnionImg2ImgPipeline"]
     try:
@@ -58,3 +61,6 @@
         from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
         from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
+        from .pipeline_controlnet_union_inpaint_sd_xl import StableDiffusionXLControlNetUnionInpaintPipeline
+        from .pipeline_controlnet_union_sd_xl import StableDiffusionXLControlNetUnionPipeline
+        from .pipeline_controlnet_union_sd_xl_img2img import StableDiffusionXLControlNetUnionImg2ImgPipeline
 
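The three new union pipelines pair with the `ControlNetUnionModel` added under `diffusers/models/controlnets/controlnet_union.py` (file 45 above). A hedged loading sketch; the repo ids are illustrative, and the claim that a single `control_mode` index selects the active condition type (canny, depth, pose, ...) at call time is an assumption based on the xinsir ControlNet-Union release rather than anything this diff shows:

```python
import torch
from diffusers import ControlNetUnionModel, StableDiffusionXLControlNetUnionPipeline

# Repo ids are illustrative, not pinned by this diff.
controlnet = ControlNetUnionModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")
```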
--- a/diffusers/pipelines/controlnet/multicontrolnet.py
+++ b/diffusers/pipelines/controlnet/multicontrolnet.py
@@ -1,183 +1,12 @@
-import os
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-
-from ...models.controlnet import ControlNetModel, ControlNetOutput
-from ...models.modeling_utils import ModelMixin
-from ...utils import logging
+from ...models.controlnets.multicontrolnet import MultiControlNetModel
+from ...utils import deprecate, logging
 
 
 logger = logging.get_logger(__name__)
 
 
-class MultiControlNetModel(ModelMixin):
-    r"""
-    Multiple `ControlNetModel` wrapper class for Multi-ControlNet
-
-    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
-    compatible with `ControlNetModel`.
-
-    Args:
-        controlnets (`List[ControlNetModel]`):
-            Provides additional conditioning to the unet during the denoising process. You must set multiple
-            `ControlNetModel` as a list.
-    """
-
-    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
-        super().__init__()
-        self.nets = nn.ModuleList(controlnets)
-
-    def forward(
-        self,
-        sample: torch.Tensor,
-        timestep: Union[torch.Tensor, float, int],
-        encoder_hidden_states: torch.Tensor,
-        controlnet_cond: List[torch.tensor],
-        conditioning_scale: List[float],
-        class_labels: Optional[torch.Tensor] = None,
-        timestep_cond: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        guess_mode: bool = False,
-        return_dict: bool = True,
-    ) -> Union[ControlNetOutput, Tuple]:
-        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
-            down_samples, mid_sample = controlnet(
-                sample=sample,
-                timestep=timestep,
-                encoder_hidden_states=encoder_hidden_states,
-                controlnet_cond=image,
-                conditioning_scale=scale,
-                class_labels=class_labels,
-                timestep_cond=timestep_cond,
-                attention_mask=attention_mask,
-                added_cond_kwargs=added_cond_kwargs,
-                cross_attention_kwargs=cross_attention_kwargs,
-                guess_mode=guess_mode,
-                return_dict=return_dict,
-            )
-
-            # merge samples
-            if i == 0:
-                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
-            else:
-                down_block_res_samples = [
-                    samples_prev + samples_curr
-                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
-                ]
-                mid_block_res_sample += mid_sample
-
-        return down_block_res_samples, mid_block_res_sample
-
-    def save_pretrained(
-        self,
-        save_directory: Union[str, os.PathLike],
-        is_main_process: bool = True,
-        save_function: Callable = None,
-        safe_serialization: bool = True,
-        variant: Optional[str] = None,
-    ):
-        """
-        Save a model and its configuration file to a directory, so that it can be re-loaded using the
-        `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method.
-
-        Arguments:
-            save_directory (`str` or `os.PathLike`):
-                Directory to which to save. Will be created if it doesn't exist.
-            is_main_process (`bool`, *optional*, defaults to `True`):
-                Whether the process calling this is the main process or not. Useful when in distributed training like
-                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
-                the main process to avoid race conditions.
-            save_function (`Callable`):
-                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
-                need to replace `torch.save` by another method. Can be configured with the environment variable
-                `DIFFUSERS_SAVE_MODE`.
-            safe_serialization (`bool`, *optional*, defaults to `True`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
-            variant (`str`, *optional*):
-                If specified, weights are saved in the format pytorch_model.<variant>.bin.
-        """
-        for idx, controlnet in enumerate(self.nets):
-            suffix = "" if idx == 0 else f"_{idx}"
-            controlnet.save_pretrained(
-                save_directory + suffix,
-                is_main_process=is_main_process,
-                save_function=save_function,
-                safe_serialization=safe_serialization,
-                variant=variant,
-            )
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
-        r"""
-        Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
-
-        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
-        the model, you should first set it back in training mode with `model.train()`.
-
-        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
-        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
-        task.
-
-        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
-        weights are discarded.
-
-        Parameters:
-            pretrained_model_path (`os.PathLike`):
-                A path to a *directory* containing model weights saved using
-                [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
-                `./my_model_directory/controlnet`.
-            torch_dtype (`str` or `torch.dtype`, *optional*):
-                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
-                will be automatically derived from the model's weights.
-            output_loading_info(`bool`, *optional*, defaults to `False`):
-                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
-            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
-                A map that specifies where each submodule should go. It doesn't need to be refined to each
-                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
-                same device.
-
-                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
-                more information about each option see [designing a device
-                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
-            max_memory (`Dict`, *optional*):
-                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
-                GPU and the available CPU RAM if unset.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
-                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
-                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
-                setting this argument to `True` will raise an error.
-            variant (`str`, *optional*):
-                If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
-                ignored when using `from_flax`.
-            use_safetensors (`bool`, *optional*, defaults to `None`):
-                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
-                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
-                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
-        """
-        idx = 0
-        controlnets = []
-
-        # load controlnet and append to list until no controlnet directory exists anymore
-        # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained`
-        # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
-        model_path_to_load = pretrained_model_path
-        while os.path.isdir(model_path_to_load):
-            controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
-            controlnets.append(controlnet)
-
-            idx += 1
-            model_path_to_load = pretrained_model_path + f"_{idx}"
-
-        logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
-
-        if len(controlnets) == 0:
-            raise ValueError(
-                f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
-            )
-
-        return cls(controlnets)
+class MultiControlNetModel(MultiControlNetModel):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `MultiControlNetModel` from `diffusers.pipelines.controlnet.multicontrolnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel`, instead."
+        deprecate("diffusers.pipelines.controlnet.multicontrolnet.MultiControlNetModel", "0.34", deprecation_message)
+        super().__init__(*args, **kwargs)
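The shim keeps old imports working while pointing at the relocated implementation; both imports below resolve to equivalent classes (the legacy one is a thin subclass that only emits the warning) until the removal slated for 0.34:

```python
# Canonical location as of 0.32:
from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel

# Legacy location: still importable, but warns and is slated for removal in 0.34.
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel  # noqa: F811
```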