diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  import contextlib
2
2
  import copy
3
+ import gc
3
4
  import math
4
5
  import random
5
6
  from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -23,6 +24,9 @@ from .utils import (
23
24
  if is_transformers_available():
24
25
  import transformers
25
26
 
27
+ if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
28
+ import deepspeed
29
+
26
30
  if is_peft_available():
27
31
  from peft import set_peft_model_state_dict
28
32
 
@@ -35,9 +39,13 @@ if is_torch_npu_available():
35
39
 
36
40
  def set_seed(seed: int):
37
41
  """
38
- Args:
39
42
  Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
43
+
44
+ Args:
40
45
  seed (`int`): The seed to set.
46
+
47
+ Returns:
48
+ `None`
41
49
  """
42
50
  random.seed(seed)
43
51
  np.random.seed(seed)
@@ -53,6 +61,17 @@ def compute_snr(noise_scheduler, timesteps):
53
61
  """
54
62
  Computes SNR as per
55
63
  https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
64
+ for the given timesteps using the provided noise scheduler.
65
+
66
+ Args:
67
+ noise_scheduler (`NoiseScheduler`):
68
+ An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
69
+ the SNR values.
70
+ timesteps (`torch.Tensor`):
71
+ A tensor of timesteps for which the SNR is computed.
72
+
73
+ Returns:
74
+ `torch.Tensor`: A tensor containing the computed SNR values for each timestep.
56
75
  """
57
76
  alphas_cumprod = noise_scheduler.alphas_cumprod
58
77
  sqrt_alphas_cumprod = alphas_cumprod**0.5
@@ -193,6 +212,13 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
193
212
 
194
213
 
195
214
  def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
215
+ """
216
+ Casts the training parameters of the model to the specified data type.
217
+
218
+ Args:
219
+ model: The PyTorch model whose parameters will be cast.
220
+ dtype: The data type to which the model parameters will be cast.
221
+ """
196
222
  if not isinstance(model, list):
197
223
  model = [model]
198
224
  for m in model:
@@ -224,7 +250,8 @@ def _set_state_dict_into_text_encoder(
224
250
  def compute_density_for_timestep_sampling(
225
251
  weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
226
252
  ):
227
- """Compute the density for sampling the timesteps when doing SD3 training.
253
+ """
254
+ Compute the density for sampling the timesteps when doing SD3 training.
228
255
 
229
256
  Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
230
257
 
@@ -243,7 +270,8 @@ def compute_density_for_timestep_sampling(
243
270
 
244
271
 
245
272
  def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
246
- """Computes loss weighting scheme for SD3 training.
273
+ """
274
+ Computes loss weighting scheme for SD3 training.
247
275
 
248
276
  Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
249
277
 
@@ -259,6 +287,20 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
259
287
  return weighting
260
288
 
261
289
 
290
+ def free_memory():
291
+ """
292
+ Runs garbage collection. Then clears the cache of the available accelerator.
293
+ """
294
+ gc.collect()
295
+
296
+ if torch.cuda.is_available():
297
+ torch.cuda.empty_cache()
298
+ elif torch.backends.mps.is_available():
299
+ torch.mps.empty_cache()
300
+ elif is_torch_npu_available():
301
+ torch_npu.npu.empty_cache()
302
+
303
+
262
304
  # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
263
305
  class EMAModel:
264
306
  """
@@ -351,7 +393,7 @@ class EMAModel:
351
393
 
352
394
  @classmethod
353
395
  def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
354
- _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
396
+ _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
355
397
  model = model_cls.from_pretrained(path)
356
398
 
357
399
  ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
@@ -417,15 +459,13 @@ class EMAModel:
417
459
  self.cur_decay_value = decay
418
460
  one_minus_decay = 1 - decay
419
461
 
420
- context_manager = contextlib.nullcontext
421
- if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
422
- import deepspeed
462
+ context_manager = contextlib.nullcontext()
423
463
 
424
464
  if self.foreach:
425
- if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
465
+ if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
426
466
  context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
427
467
 
428
- with context_manager():
468
+ with context_manager:
429
469
  params_grad = [param for param in parameters if param.requires_grad]
430
470
  s_params_grad = [
431
471
  s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
@@ -444,10 +484,10 @@ class EMAModel:
444
484
 
445
485
  else:
446
486
  for s_param, param in zip(self.shadow_params, parameters):
447
- if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
487
+ if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
448
488
  context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
449
489
 
450
- with context_manager():
490
+ with context_manager:
451
491
  if param.requires_grad:
452
492
  s_param.sub_(one_minus_decay * (s_param - param))
453
493
  else:
@@ -481,7 +521,8 @@ class EMAModel:
481
521
  self.shadow_params = [p.pin_memory() for p in self.shadow_params]
482
522
 
483
523
  def to(self, device=None, dtype=None, non_blocking=False) -> None:
484
- r"""Move internal buffers of the ExponentialMovingAverage to `device`.
524
+ r"""
525
+ Move internal buffers of the ExponentialMovingAverage to `device`.
485
526
 
486
527
  Args:
487
528
  device: like `device` argument to `torch.Tensor.to`
@@ -515,23 +556,25 @@ class EMAModel:
515
556
 
516
557
  def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
517
558
  r"""
559
+ Saves the current parameters for restoring later.
560
+
518
561
  Args:
519
- Save the current parameters for restoring later.
520
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
521
- temporarily stored.
562
+ parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
522
563
  """
523
564
  self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
524
565
 
525
566
  def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
526
567
  r"""
527
- Args:
528
- Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
529
- affecting the original optimization process. Store the parameters before the `copy_to()` method. After
568
+ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
569
+ without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After
530
570
  validation (or model saving), use this to restore the former parameters.
571
+
572
+ Args:
531
573
  parameters: Iterable of `torch.nn.Parameter`; the parameters to be
532
574
  updated with the stored parameters. If `None`, the parameters with which this
533
575
  `ExponentialMovingAverage` was initialized will be used.
534
576
  """
577
+
535
578
  if self.temp_stored_params is None:
536
579
  raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
537
580
  if self.foreach:
@@ -547,9 +590,10 @@ class EMAModel:
547
590
 
548
591
  def load_state_dict(self, state_dict: dict) -> None:
549
592
  r"""
550
- Args:
551
593
  Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
552
594
  ema state dict.
595
+
596
+ Args:
553
597
  state_dict (dict): EMA state. Should be an object returned
554
598
  from a call to :meth:`state_dict`.
555
599
  """
@@ -23,6 +23,7 @@ from .constants import (
23
23
  DEPRECATED_REVISION_ARGS,
24
24
  DIFFUSERS_DYNAMIC_MODULE_NAME,
25
25
  FLAX_WEIGHTS_NAME,
26
+ GGUF_FILE_EXTENSION,
26
27
  HF_MODULES_CACHE,
27
28
  HUGGINGFACE_CO_RESOLVE_ENDPOINT,
28
29
  MIN_PEFT_VERSION,
@@ -62,9 +63,12 @@ from .import_utils import (
62
63
  is_accelerate_available,
63
64
  is_accelerate_version,
64
65
  is_bitsandbytes_available,
66
+ is_bitsandbytes_version,
65
67
  is_bs4_available,
66
68
  is_flax_available,
67
69
  is_ftfy_available,
70
+ is_gguf_available,
71
+ is_gguf_version,
68
72
  is_google_colab,
69
73
  is_inflect_available,
70
74
  is_invisible_watermark_available,
@@ -85,6 +89,8 @@ from .import_utils import (
85
89
  is_torch_npu_available,
86
90
  is_torch_version,
87
91
  is_torch_xla_available,
92
+ is_torch_xla_version,
93
+ is_torchao_available,
88
94
  is_torchsde_available,
89
95
  is_torchvision_available,
90
96
  is_transformers_available,
@@ -94,7 +100,7 @@ from .import_utils import (
94
100
  is_xformers_available,
95
101
  requires_backends,
96
102
  )
97
- from .loading_utils import load_image, load_video
103
+ from .loading_utils import get_module_from_name, load_image, load_video
98
104
  from .logging import get_logger
99
105
  from .outputs import BaseOutput
100
106
  from .peft_utils import (
@@ -34,6 +34,7 @@ ONNX_WEIGHTS_NAME = "model.onnx"
34
34
  SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
35
35
  SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
36
36
  SAFETENSORS_FILE_EXTENSION = "safetensors"
37
+ GGUF_FILE_EXTENSION = "gguf"
37
38
  ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
38
39
  HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
39
40
  DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
@@ -2,6 +2,21 @@
2
2
  from ..utils import DummyObject, requires_backends
3
3
 
4
4
 
5
+ class AllegroTransformer3DModel(metaclass=DummyObject):
6
+ _backends = ["torch"]
7
+
8
+ def __init__(self, *args, **kwargs):
9
+ requires_backends(self, ["torch"])
10
+
11
+ @classmethod
12
+ def from_config(cls, *args, **kwargs):
13
+ requires_backends(cls, ["torch"])
14
+
15
+ @classmethod
16
+ def from_pretrained(cls, *args, **kwargs):
17
+ requires_backends(cls, ["torch"])
18
+
19
+
5
20
  class AsymmetricAutoencoderKL(metaclass=DummyObject):
6
21
  _backends = ["torch"]
7
22
 
@@ -32,6 +47,21 @@ class AuraFlowTransformer2DModel(metaclass=DummyObject):
32
47
  requires_backends(cls, ["torch"])
33
48
 
34
49
 
50
+ class AutoencoderDC(metaclass=DummyObject):
51
+ _backends = ["torch"]
52
+
53
+ def __init__(self, *args, **kwargs):
54
+ requires_backends(self, ["torch"])
55
+
56
+ @classmethod
57
+ def from_config(cls, *args, **kwargs):
58
+ requires_backends(cls, ["torch"])
59
+
60
+ @classmethod
61
+ def from_pretrained(cls, *args, **kwargs):
62
+ requires_backends(cls, ["torch"])
63
+
64
+
35
65
  class AutoencoderKL(metaclass=DummyObject):
36
66
  _backends = ["torch"]
37
67
 
@@ -47,6 +77,21 @@ class AutoencoderKL(metaclass=DummyObject):
47
77
  requires_backends(cls, ["torch"])
48
78
 
49
79
 
80
+ class AutoencoderKLAllegro(metaclass=DummyObject):
81
+ _backends = ["torch"]
82
+
83
+ def __init__(self, *args, **kwargs):
84
+ requires_backends(self, ["torch"])
85
+
86
+ @classmethod
87
+ def from_config(cls, *args, **kwargs):
88
+ requires_backends(cls, ["torch"])
89
+
90
+ @classmethod
91
+ def from_pretrained(cls, *args, **kwargs):
92
+ requires_backends(cls, ["torch"])
93
+
94
+
50
95
  class AutoencoderKLCogVideoX(metaclass=DummyObject):
51
96
  _backends = ["torch"]
52
97
 
@@ -62,6 +107,51 @@ class AutoencoderKLCogVideoX(metaclass=DummyObject):
62
107
  requires_backends(cls, ["torch"])
63
108
 
64
109
 
110
+ class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
111
+ _backends = ["torch"]
112
+
113
+ def __init__(self, *args, **kwargs):
114
+ requires_backends(self, ["torch"])
115
+
116
+ @classmethod
117
+ def from_config(cls, *args, **kwargs):
118
+ requires_backends(cls, ["torch"])
119
+
120
+ @classmethod
121
+ def from_pretrained(cls, *args, **kwargs):
122
+ requires_backends(cls, ["torch"])
123
+
124
+
125
+ class AutoencoderKLLTXVideo(metaclass=DummyObject):
126
+ _backends = ["torch"]
127
+
128
+ def __init__(self, *args, **kwargs):
129
+ requires_backends(self, ["torch"])
130
+
131
+ @classmethod
132
+ def from_config(cls, *args, **kwargs):
133
+ requires_backends(cls, ["torch"])
134
+
135
+ @classmethod
136
+ def from_pretrained(cls, *args, **kwargs):
137
+ requires_backends(cls, ["torch"])
138
+
139
+
140
+ class AutoencoderKLMochi(metaclass=DummyObject):
141
+ _backends = ["torch"]
142
+
143
+ def __init__(self, *args, **kwargs):
144
+ requires_backends(self, ["torch"])
145
+
146
+ @classmethod
147
+ def from_config(cls, *args, **kwargs):
148
+ requires_backends(cls, ["torch"])
149
+
150
+ @classmethod
151
+ def from_pretrained(cls, *args, **kwargs):
152
+ requires_backends(cls, ["torch"])
153
+
154
+
65
155
  class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
66
156
  _backends = ["torch"]
67
157
 
@@ -122,6 +212,21 @@ class CogVideoXTransformer3DModel(metaclass=DummyObject):
122
212
  requires_backends(cls, ["torch"])
123
213
 
124
214
 
215
+ class CogView3PlusTransformer2DModel(metaclass=DummyObject):
216
+ _backends = ["torch"]
217
+
218
+ def __init__(self, *args, **kwargs):
219
+ requires_backends(self, ["torch"])
220
+
221
+ @classmethod
222
+ def from_config(cls, *args, **kwargs):
223
+ requires_backends(cls, ["torch"])
224
+
225
+ @classmethod
226
+ def from_pretrained(cls, *args, **kwargs):
227
+ requires_backends(cls, ["torch"])
228
+
229
+
125
230
  class ConsistencyDecoderVAE(metaclass=DummyObject):
126
231
  _backends = ["torch"]
127
232
 
@@ -152,6 +257,21 @@ class ControlNetModel(metaclass=DummyObject):
152
257
  requires_backends(cls, ["torch"])
153
258
 
154
259
 
260
+ class ControlNetUnionModel(metaclass=DummyObject):
261
+ _backends = ["torch"]
262
+
263
+ def __init__(self, *args, **kwargs):
264
+ requires_backends(self, ["torch"])
265
+
266
+ @classmethod
267
+ def from_config(cls, *args, **kwargs):
268
+ requires_backends(cls, ["torch"])
269
+
270
+ @classmethod
271
+ def from_pretrained(cls, *args, **kwargs):
272
+ requires_backends(cls, ["torch"])
273
+
274
+
155
275
  class ControlNetXSAdapter(metaclass=DummyObject):
156
276
  _backends = ["torch"]
157
277
 
@@ -182,6 +302,36 @@ class DiTTransformer2DModel(metaclass=DummyObject):
182
302
  requires_backends(cls, ["torch"])
183
303
 
184
304
 
305
+ class FluxControlNetModel(metaclass=DummyObject):
306
+ _backends = ["torch"]
307
+
308
+ def __init__(self, *args, **kwargs):
309
+ requires_backends(self, ["torch"])
310
+
311
+ @classmethod
312
+ def from_config(cls, *args, **kwargs):
313
+ requires_backends(cls, ["torch"])
314
+
315
+ @classmethod
316
+ def from_pretrained(cls, *args, **kwargs):
317
+ requires_backends(cls, ["torch"])
318
+
319
+
320
+ class FluxMultiControlNetModel(metaclass=DummyObject):
321
+ _backends = ["torch"]
322
+
323
+ def __init__(self, *args, **kwargs):
324
+ requires_backends(self, ["torch"])
325
+
326
+ @classmethod
327
+ def from_config(cls, *args, **kwargs):
328
+ requires_backends(cls, ["torch"])
329
+
330
+ @classmethod
331
+ def from_pretrained(cls, *args, **kwargs):
332
+ requires_backends(cls, ["torch"])
333
+
334
+
185
335
  class FluxTransformer2DModel(metaclass=DummyObject):
186
336
  _backends = ["torch"]
187
337
 
@@ -242,6 +392,21 @@ class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
242
392
  requires_backends(cls, ["torch"])
243
393
 
244
394
 
395
+ class HunyuanVideoTransformer3DModel(metaclass=DummyObject):
396
+ _backends = ["torch"]
397
+
398
+ def __init__(self, *args, **kwargs):
399
+ requires_backends(self, ["torch"])
400
+
401
+ @classmethod
402
+ def from_config(cls, *args, **kwargs):
403
+ requires_backends(cls, ["torch"])
404
+
405
+ @classmethod
406
+ def from_pretrained(cls, *args, **kwargs):
407
+ requires_backends(cls, ["torch"])
408
+
409
+
245
410
  class I2VGenXLUNet(metaclass=DummyObject):
246
411
  _backends = ["torch"]
247
412
 
@@ -287,6 +452,21 @@ class LatteTransformer3DModel(metaclass=DummyObject):
287
452
  requires_backends(cls, ["torch"])
288
453
 
289
454
 
455
+ class LTXVideoTransformer3DModel(metaclass=DummyObject):
456
+ _backends = ["torch"]
457
+
458
+ def __init__(self, *args, **kwargs):
459
+ requires_backends(self, ["torch"])
460
+
461
+ @classmethod
462
+ def from_config(cls, *args, **kwargs):
463
+ requires_backends(cls, ["torch"])
464
+
465
+ @classmethod
466
+ def from_pretrained(cls, *args, **kwargs):
467
+ requires_backends(cls, ["torch"])
468
+
469
+
290
470
  class LuminaNextDiT2DModel(metaclass=DummyObject):
291
471
  _backends = ["torch"]
292
472
 
@@ -302,6 +482,21 @@ class LuminaNextDiT2DModel(metaclass=DummyObject):
302
482
  requires_backends(cls, ["torch"])
303
483
 
304
484
 
485
+ class MochiTransformer3DModel(metaclass=DummyObject):
486
+ _backends = ["torch"]
487
+
488
+ def __init__(self, *args, **kwargs):
489
+ requires_backends(self, ["torch"])
490
+
491
+ @classmethod
492
+ def from_config(cls, *args, **kwargs):
493
+ requires_backends(cls, ["torch"])
494
+
495
+ @classmethod
496
+ def from_pretrained(cls, *args, **kwargs):
497
+ requires_backends(cls, ["torch"])
498
+
499
+
305
500
  class ModelMixin(metaclass=DummyObject):
306
501
  _backends = ["torch"]
307
502
 
@@ -347,6 +542,21 @@ class MultiAdapter(metaclass=DummyObject):
347
542
  requires_backends(cls, ["torch"])
348
543
 
349
544
 
545
+ class MultiControlNetModel(metaclass=DummyObject):
546
+ _backends = ["torch"]
547
+
548
+ def __init__(self, *args, **kwargs):
549
+ requires_backends(self, ["torch"])
550
+
551
+ @classmethod
552
+ def from_config(cls, *args, **kwargs):
553
+ requires_backends(cls, ["torch"])
554
+
555
+ @classmethod
556
+ def from_pretrained(cls, *args, **kwargs):
557
+ requires_backends(cls, ["torch"])
558
+
559
+
350
560
  class PixArtTransformer2DModel(metaclass=DummyObject):
351
561
  _backends = ["torch"]
352
562
 
@@ -377,6 +587,21 @@ class PriorTransformer(metaclass=DummyObject):
377
587
  requires_backends(cls, ["torch"])
378
588
 
379
589
 
590
+ class SanaTransformer2DModel(metaclass=DummyObject):
591
+ _backends = ["torch"]
592
+
593
+ def __init__(self, *args, **kwargs):
594
+ requires_backends(self, ["torch"])
595
+
596
+ @classmethod
597
+ def from_config(cls, *args, **kwargs):
598
+ requires_backends(cls, ["torch"])
599
+
600
+ @classmethod
601
+ def from_pretrained(cls, *args, **kwargs):
602
+ requires_backends(cls, ["torch"])
603
+
604
+
380
605
  class SD3ControlNetModel(metaclass=DummyObject):
381
606
  _backends = ["torch"]
382
607
 
@@ -975,6 +1200,21 @@ class StableDiffusionMixin(metaclass=DummyObject):
975
1200
  requires_backends(cls, ["torch"])
976
1201
 
977
1202
 
1203
+ class DiffusersQuantizer(metaclass=DummyObject):
1204
+ _backends = ["torch"]
1205
+
1206
+ def __init__(self, *args, **kwargs):
1207
+ requires_backends(self, ["torch"])
1208
+
1209
+ @classmethod
1210
+ def from_config(cls, *args, **kwargs):
1211
+ requires_backends(cls, ["torch"])
1212
+
1213
+ @classmethod
1214
+ def from_pretrained(cls, *args, **kwargs):
1215
+ requires_backends(cls, ["torch"])
1216
+
1217
+
978
1218
  class AmusedScheduler(metaclass=DummyObject):
979
1219
  _backends = ["torch"]
980
1220