diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -339,6 +339,23 @@ if _imageio_available:
339
339
  except importlib_metadata.PackageNotFoundError:
340
340
  _imageio_available = False
341
341
 
342
# Optional-dependency probing for gguf: the module must be importable AND have installed distribution metadata.
# If metadata lookup fails, the package is treated as unavailable and `_gguf_version` stays unbound.
_is_gguf_available = importlib.util.find_spec("gguf") is not None
if _is_gguf_available:
    try:
        _gguf_version = importlib_metadata.version("gguf")
    except importlib_metadata.PackageNotFoundError:
        _is_gguf_available = False
    else:
        logger.debug(f"Successfully import gguf version {_gguf_version}")


# Same probing for torchao; `_torchao_version` is only bound when detection succeeds.
_is_torchao_available = importlib.util.find_spec("torchao") is not None
if _is_torchao_available:
    try:
        _torchao_version = importlib_metadata.version("torchao")
    except importlib_metadata.PackageNotFoundError:
        _is_torchao_available = False
    else:
        logger.debug(f"Successfully import torchao version {_torchao_version}")
358
+
342
359
 
343
360
  def is_torch_available():
344
361
  return _torch_available
@@ -460,6 +477,14 @@ def is_imageio_available():
460
477
  return _imageio_available
461
478
 
462
479
 
480
def is_gguf_available():
    """Return the `gguf` availability flag computed when this module was imported."""
    return _is_gguf_available


def is_torchao_available():
    """Return the `torchao` availability flag computed when this module was imported."""
    return _is_torchao_available
486
+
487
+
463
488
  # docstyle-ignore
464
489
  FLAX_IMPORT_ERROR = """
465
490
  {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -593,6 +618,16 @@ IMAGEIO_IMPORT_ERROR = """
593
618
  {0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg`
594
619
  """
595
620
 
621
# Error templates for missing optional backends; `{0}` is filled with the name of the object that needs them.
# The `# docstyle-ignore` markers keep these strings from being re-wrapped, which would otherwise split the
# copy-pasteable install command across two lines.

# docstyle-ignore
GGUF_IMPORT_ERROR = """
{0} requires the gguf library but it was not found in your environment. You can install it with pip: `pip install gguf`
"""

# docstyle-ignore
TORCHAO_IMPORT_ERROR = """
{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install torchao`
"""
630
+
596
631
  BACKENDS_MAPPING = OrderedDict(
597
632
  [
598
633
  ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -618,6 +653,8 @@ BACKENDS_MAPPING = OrderedDict(
618
653
  ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
619
654
  ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
620
655
  ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
656
+ ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
657
+ ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
621
658
  ]
622
659
  )
623
660
 
@@ -668,8 +705,9 @@ class DummyObject(type):
668
705
  # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
669
706
  def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
670
707
  """
671
- Args:
672
708
  Compares a library version to some requirement using a given operation.
709
+
710
+ Args:
673
711
  library_or_version (`str` or `packaging.version.Version`):
674
712
  A library name or a version to check.
675
713
  operation (`str`):
@@ -688,8 +726,9 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
688
726
  # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
689
727
  def is_torch_version(operation: str, version: str):
690
728
  """
691
- Args:
692
729
  Compares the current PyTorch version to a given reference with an operation.
730
+
731
+ Args:
693
732
  operation (`str`):
694
733
  A string representation of an operator, such as `">"` or `"<="`
695
734
  version (`str`):
@@ -698,10 +737,26 @@ def is_torch_version(operation: str, version: str):
698
737
  return compare_versions(parse(_torch_version), operation, version)
699
738
 
700
739
 
701
- def is_transformers_version(operation: str, version: str):
740
def is_torch_xla_version(operation: str, version: str):
    """
    Compares the current torch_xla version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A string version of torch_xla

    Returns:
        `bool`: `False` when torch_xla is not available, otherwise the result of the version comparison.
    """
    # `is_torch_xla_available` must be *called*: a bare function object is always truthy, so testing the name
    # itself would never short-circuit and `_torch_xla_version` would be parsed even without torch_xla installed.
    if not is_torch_xla_available():
        return False
    return compare_versions(parse(_torch_xla_version), operation, version)
753
+
754
+
755
+ def is_transformers_version(operation: str, version: str):
756
+ """
704
757
  Compares the current Transformers version to a given reference with an operation.
758
+
759
+ Args:
705
760
  operation (`str`):
706
761
  A string representation of an operator, such as `">"` or `"<="`
707
762
  version (`str`):
@@ -714,8 +769,9 @@ def is_transformers_version(operation: str, version: str):
714
769
 
715
770
  def is_accelerate_version(operation: str, version: str):
716
771
  """
717
- Args:
718
772
  Compares the current Accelerate version to a given reference with an operation.
773
+
774
+ Args:
719
775
  operation (`str`):
720
776
  A string representation of an operator, such as `">"` or `"<="`
721
777
  version (`str`):
@@ -728,8 +784,9 @@ def is_accelerate_version(operation: str, version: str):
728
784
 
729
785
  def is_peft_version(operation: str, version: str):
730
786
  """
731
- Args:
732
787
  Compares the current PEFT version to a given reference with an operation.
788
+
789
+ Args:
733
790
  operation (`str`):
734
791
  A string representation of an operator, such as `">"` or `"<="`
735
792
  version (`str`):
@@ -740,10 +797,40 @@ def is_peft_version(operation: str, version: str):
740
797
  return compare_versions(parse(_peft_version), operation, version)
741
798
 
742
799
 
743
- def is_k_diffusion_version(operation: str, version: str):
800
def is_bitsandbytes_version(operation: str, version: str):
    """
    Compares the current bitsandbytes version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string

    Returns:
        `bool`: `False` when no bitsandbytes version was recorded, otherwise the result of the version comparison.
    """
    if not _bitsandbytes_version:
        return False
    return compare_versions(parse(_bitsandbytes_version), operation, version)
812
+
813
+
814
def is_gguf_version(operation: str, version: str):
    """
    Compares the current gguf version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string

    Returns:
        `bool`: `False` when gguf is not available, otherwise the result of the version comparison.
    """
    # `_gguf_version` is only bound when gguf was detected, so availability must be checked first.
    if not _is_gguf_available:
        return False
    return compare_versions(parse(_gguf_version), operation, version)
827
+
828
+
829
+ def is_k_diffusion_version(operation: str, version: str):
830
+ """
746
831
  Compares the current k-diffusion version to a given reference with an operation.
832
+
833
+ Args:
747
834
  operation (`str`):
748
835
  A string representation of an operator, such as `">"` or `"<="`
749
836
  version (`str`):
@@ -756,8 +843,9 @@ def is_k_diffusion_version(operation: str, version: str):
756
843
 
757
844
  def get_objects_from_module(module):
758
845
  """
759
- Args:
760
846
  Returns a dict of object names and values in a module, while skipping private/internal objects
847
+
848
+ Args:
761
849
  module (ModuleType):
762
850
  Module to extract the objects from.
763
851
 
@@ -775,7 +863,9 @@ def get_objects_from_module(module):
775
863
 
776
864
 
777
865
# NOTE(review): this derives from BaseException rather than Exception, so a plain `except Exception:` will not
# catch it — presumably intentional; confirm before changing the base class.
class OptionalDependencyNotAvailable(BaseException):
    """
    Raised when an optional dependency of Diffusers cannot be found in the current environment.
    """
779
869
 
780
870
 
781
871
  class _LazyModule(ModuleType):
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  import tempfile
3
- from typing import Callable, List, Optional, Union
3
+ from typing import Any, Callable, List, Optional, Tuple, Union
4
+ from urllib.parse import unquote, urlparse
4
5
 
5
6
  import PIL.Image
6
7
  import PIL.ImageOps
@@ -80,12 +81,22 @@ def load_video(
80
81
  )
81
82
 
82
83
  if is_url:
83
- video_data = requests.get(video, stream=True).raw
84
- suffix = os.path.splitext(video)[1] or ".mp4"
84
+ response = requests.get(video, stream=True)
85
+ if response.status_code != 200:
86
+ raise ValueError(f"Failed to download video. Status code: {response.status_code}")
87
+
88
+ parsed_url = urlparse(video)
89
+ file_name = os.path.basename(unquote(parsed_url.path))
90
+
91
+ suffix = os.path.splitext(file_name)[1] or ".mp4"
85
92
  video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name
93
+
86
94
  was_tempfile_created = True
95
+
96
+ video_data = response.iter_content(chunk_size=8192)
87
97
  with open(video_path, "wb") as f:
88
- f.write(video_data.read())
98
+ for chunk in video_data:
99
+ f.write(chunk)
89
100
 
90
101
  video = video_path
91
102
 
@@ -124,3 +135,16 @@ def load_video(
124
135
  pil_images = convert_method(pil_images)
125
136
 
126
137
  return pil_images
138
+
139
+
140
+ # Taken from `transformers`.
141
def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
    """
    Resolve a possibly dotted attribute path to the object that owns its final component.

    Args:
        module: The root object the attribute walk starts from.
        tensor_name (`str`): A name such as `"weight"` or `"a.b.weight"`.

    Returns:
        `Tuple[Any, str]`: The object holding the last path component, and that component's name. A name without
        dots is returned unchanged together with `module`.

    Raises:
        ValueError: If an intermediate attribute exists but is `None`.
        AttributeError: Propagated from `getattr` when an intermediate attribute is missing.
    """
    if "." not in tensor_name:
        return module, tensor_name

    *parent_names, leaf_name = tensor_name.split(".")
    owner = module
    for attr_name in parent_names:
        child = getattr(owner, attr_name)
        if child is None:
            raise ValueError(f"{owner} has no attribute {attr_name}.")
        owner = child
    return owner, leaf_name
@@ -134,14 +134,14 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
134
134
  """
135
135
  from peft.tuners.tuners_utils import BaseTunerLayer
136
136
 
137
- if weight == 1.0:
137
+ if weight is None or weight == 1.0:
138
138
  return
139
139
 
140
140
  for module in model.modules():
141
141
  if isinstance(module, BaseTunerLayer):
142
- if weight is not None and weight != 0:
142
+ if weight != 0:
143
143
  module.unscale_layer(weight)
144
- elif weight is not None and weight == 0:
144
+ else:
145
145
  for adapter_name in module.active_adapters:
146
146
  # if weight == 0 unscale should re-set the scale to the original value.
147
147
  module.set_scale(adapter_name, 1.0)
@@ -180,6 +180,8 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
180
180
  # layer names without the Diffusers specific
181
181
  target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
182
182
  use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
183
+ # for now we know that the "bias" keys are only associated with `lora_B`.
184
+ lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)
183
185
 
184
186
  lora_config_kwargs = {
185
187
  "r": r,
@@ -188,6 +190,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
188
190
  "alpha_pattern": alpha_pattern,
189
191
  "target_modules": target_modules,
190
192
  "use_dora": use_dora,
193
+ "lora_bias": lora_bias,
191
194
  }
192
195
  return lora_config_kwargs
193
196
 
@@ -1,5 +1,6 @@
1
1
  import functools
2
2
  import importlib
3
+ import importlib.metadata
3
4
  import inspect
4
5
  import io
5
6
  import logging
@@ -27,8 +28,11 @@ from packaging import version
27
28
 
28
29
  from .import_utils import (
29
30
  BACKENDS_MAPPING,
31
+ is_accelerate_available,
32
+ is_bitsandbytes_available,
30
33
  is_compel_available,
31
34
  is_flax_available,
35
+ is_gguf_available,
32
36
  is_note_seq_available,
33
37
  is_onnx_available,
34
38
  is_opencv_available,
@@ -36,6 +40,7 @@ from .import_utils import (
36
40
  is_timm_available,
37
41
  is_torch_available,
38
42
  is_torch_version,
43
+ is_torchao_available,
39
44
  is_torchsde_available,
40
45
  is_transformers_available,
41
46
  )
@@ -54,6 +59,7 @@ _required_transformers_version = is_transformers_available() and version.parse(
54
59
  ) > version.parse("4.33")
55
60
 
56
61
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
# Memory threshold (compared in GiB by `require_big_gpu_with_torch_cuda`); overridable via the
# `BIG_GPU_MEMORY` environment variable.
BIG_GPU_MEMORY = int(os.environ.get("BIG_GPU_MEMORY", 40))
57
63
 
58
64
  if is_torch_available():
59
65
  import torch
@@ -252,6 +258,18 @@ def require_torch_2(test_case):
252
258
  )
253
259
 
254
260
 
261
def require_torch_version_greater_equal(torch_version):
    """
    Build a test decorator that skips the test unless PyTorch is installed at `torch_version` or newer.
    """
    skip_reason = f"test requires torch with the version greater than or equal to {torch_version}"

    def decorator(test_case):
        # Short-circuit so the version check only runs when torch is importable.
        has_recent_torch = is_torch_available() and is_torch_version(">=", torch_version)
        return unittest.skipUnless(has_recent_torch, skip_reason)(test_case)

    return decorator
271
+
272
+
255
273
  def require_torch_gpu(test_case):
256
274
  """Decorator marking a test that requires CUDA and PyTorch."""
257
275
  return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
@@ -295,6 +313,26 @@ def require_torch_accelerator_with_fp64(test_case):
295
313
  )
296
314
 
297
315
 
316
def require_big_gpu_with_torch_cuda(test_case):
    """
    Decorator marking a test that requires a big CUDA GPU for execution. Some example pipelines: Flux, SD3, Cog,
    etc.

    The test is skipped when PyTorch or CUDA is unavailable, or when CUDA device 0 reports less total memory than
    `BIG_GPU_MEMORY` GiB (bytes / 1024**3). `BIG_GPU_MEMORY` is read from the `BIG_GPU_MEMORY` environment variable
    and defaults to 40.
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    # Imported lazily so this module stays importable without torch.
    import torch

    if not torch.cuda.is_available():
        return unittest.skip("test requires PyTorch CUDA")(test_case)

    # Only the first visible device is inspected.
    device_properties = torch.cuda.get_device_properties(0)
    total_memory = device_properties.total_memory / (1024**3)
    return unittest.skipUnless(
        total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
    )(test_case)
334
+
335
+
298
336
  def require_torch_accelerator_with_training(test_case):
299
337
  """Decorator marking a test that requires an accelerator with support for training."""
300
338
  return unittest.skipUnless(
@@ -337,6 +375,14 @@ def require_note_seq(test_case):
337
375
  return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)
338
376
 
339
377
 
378
def require_accelerator(test_case):
    """
    Decorator marking a test that needs a hardware accelerator backend. The test is skipped when the selected
    `torch_device` is the CPU, i.e. when no accelerator is available.
    """
    has_accelerator = torch_device != "cpu"
    skip_reason = "test requires a hardware accelerator"
    return unittest.skipUnless(has_accelerator, skip_reason)(test_case)
384
+
385
+
340
386
  def require_torchsde(test_case):
341
387
  """
342
388
  Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
@@ -359,6 +405,20 @@ def require_timm(test_case):
359
405
  return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
360
406
 
361
407
 
408
def require_bitsandbytes(test_case):
    """
    Decorator marking a test that needs the `bitsandbytes` package; the test is skipped when it is not installed.
    """
    skip_reason = "test requires bitsandbytes"
    bnb_installed = is_bitsandbytes_available()
    return unittest.skipUnless(bnb_installed, skip_reason)(test_case)
413
+
414
+
415
def require_accelerate(test_case):
    """
    Decorator marking a test that needs the `accelerate` package; the test is skipped when it is not installed.
    """
    skip_reason = "test requires accelerate"
    accelerate_installed = is_accelerate_available()
    return unittest.skipUnless(accelerate_installed, skip_reason)(test_case)
420
+
421
+
362
422
  def require_peft_version_greater(peft_version):
363
423
  """
364
424
  Decorator marking a test that requires PEFT backend with a specific version, this would require some specific
@@ -376,9 +436,27 @@ def require_peft_version_greater(peft_version):
376
436
  return decorator
377
437
 
378
438
 
439
def require_transformers_version_greater(transformers_version):
    """
    Decorator marking a test that requires `transformers` newer than `transformers_version`.

    The installed version is reduced to its base version (pre/post/dev/local segments dropped) and must be strictly
    greater than `transformers_version`; otherwise, or when `transformers` is not installed, the test is skipped.

    Args:
        transformers_version: Version string the installed `transformers` must exceed.

    Returns:
        A decorator that wraps the test case with `unittest.skipUnless`.
    """

    def decorator(test_case):
        # `is_transformers_available()` short-circuits so metadata is only queried when the package is installed.
        correct_transformers_version = is_transformers_available() and version.parse(
            version.parse(importlib.metadata.version("transformers")).base_version
        ) > version.parse(transformers_version)
        return unittest.skipUnless(
            correct_transformers_version,
            f"test requires transformers with the version greater than {transformers_version}",
        )(test_case)

    return decorator
455
+
456
+
379
457
  def require_accelerate_version_greater(accelerate_version):
380
458
  def decorator(test_case):
381
- correct_accelerate_version = is_peft_available() and version.parse(
459
+ correct_accelerate_version = is_accelerate_available() and version.parse(
382
460
  version.parse(importlib.metadata.version("accelerate")).base_version
383
461
  ) > version.parse(accelerate_version)
384
462
  return unittest.skipUnless(
@@ -388,6 +466,42 @@ def require_accelerate_version_greater(accelerate_version):
388
466
  return decorator
389
467
 
390
468
 
469
def require_bitsandbytes_version_greater(bnb_version):
    """
    Build a decorator that skips a test unless `bitsandbytes` is installed with a base version strictly greater
    than `bnb_version`.

    Args:
        bnb_version: Version string the installed `bitsandbytes` must exceed.

    Returns:
        A decorator that wraps the test case with `unittest.skipUnless`.
    """

    def decorator(test_case):
        if is_bitsandbytes_available():
            # Drop pre/post/dev/local segments before comparing.
            installed = version.parse(importlib.metadata.version("bitsandbytes")).base_version
            meets_requirement = version.parse(installed) > version.parse(bnb_version)
        else:
            meets_requirement = False
        reason = f"Test requires bitsandbytes with the version greater than {bnb_version}."
        return unittest.skipUnless(meets_requirement, reason)(test_case)

    return decorator
479
+
480
+
481
def require_gguf_version_greater_or_equal(gguf_version):
    """
    Build a decorator that skips a test unless `gguf` is installed with a base version greater than or equal to
    `gguf_version`.

    Args:
        gguf_version: Minimum acceptable `gguf` version string (inclusive).

    Returns:
        A decorator that wraps the test case with `unittest.skipUnless`.
    """

    def decorator(test_case):
        # Short-circuit so package metadata is only queried when gguf is installed.
        correct_gguf_version = is_gguf_available() and version.parse(
            version.parse(importlib.metadata.version("gguf")).base_version
        ) >= version.parse(gguf_version)
        # The message must match the inclusive (>=) check above.
        return unittest.skipUnless(
            correct_gguf_version,
            f"Test requires gguf with the version greater than or equal to {gguf_version}.",
        )(test_case)

    return decorator
491
+
492
+
493
def require_torchao_version_greater_or_equal(torchao_version):
    """
    Build a decorator that skips a test unless `torchao` is installed with a base version greater than or equal to
    `torchao_version`.

    Args:
        torchao_version: Minimum acceptable `torchao` version string (inclusive).

    Returns:
        A decorator that wraps the test case with `unittest.skipUnless`.
    """

    def decorator(test_case):
        # Short-circuit so package metadata is only queried when torchao is installed.
        correct_torchao_version = is_torchao_available() and version.parse(
            version.parse(importlib.metadata.version("torchao")).base_version
        ) >= version.parse(torchao_version)
        # The message must match the inclusive (>=) check above.
        return unittest.skipUnless(
            correct_torchao_version,
            f"Test requires torchao with version greater than or equal to {torchao_version}.",
        )(test_case)

    return decorator
503
+
504
+
391
505
  def deprecate_after_peft_backend(test_case):
392
506
  """
393
507
  Decorator marking a test that will be skipped after PEFT backend
@@ -102,6 +102,9 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
102
102
  # Non-power of 2 images must be float32
103
103
  if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
104
104
  x = x.to(dtype=torch.float32)
105
+ # fftn does not support bfloat16
106
+ elif x.dtype == torch.bfloat16:
107
+ x = x.to(dtype=torch.float32)
105
108
 
106
109
  # FFT
107
110
  x_freq = fftn(x, dim=(-2, -1))