diffusers: 0.30.3-py3-none-any.whl → 0.32.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/free_noise_utils.py

@@ -12,16 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
 
 from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
+from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from ..models.transformers.transformer_2d import Transformer2DModel
 from ..models.unets.unet_motion_model import (
+    AnimateDiffTransformer3D,
     CrossAttnDownBlockMotion,
     DownBlockMotion,
     UpBlockMotion,
 )
+from ..pipelines.pipeline_utils import DiffusionPipeline
 from ..utils import logging
 from ..utils.torch_utils import randn_tensor
 
@@ -29,6 +34,114 @@ from ..utils.torch_utils import randn_tensor
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class SplitInferenceModule(nn.Module):
+    r"""
+    A wrapper module class that splits inputs along a specified dimension before performing a forward pass.
+
+    This module is useful when you need to perform inference on large tensors in a memory-efficient way by breaking
+    them into smaller chunks, processing each chunk separately, and then reassembling the results.
+
+    Args:
+        module (`nn.Module`):
+            The underlying PyTorch module that will be applied to each chunk of split inputs.
+        split_size (`int`, defaults to `1`):
+            The size of each chunk after splitting the input tensor.
+        split_dim (`int`, defaults to `0`):
+            The dimension along which the input tensors are split.
+        input_kwargs_to_split (`List[str]`, defaults to `["hidden_states"]`):
+            A list of keyword arguments (strings) that represent the input tensors to be split.
+
+    Workflow:
+        1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using
+           `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`.
+        2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments
+           that were passed.
+        3. The output tensors from each split are concatenated back together along `split_dim` before returning.
+
+    Example:
+    ```python
+    >>> import torch
+    >>> import torch.nn as nn
+
+    >>> model = nn.Linear(1000, 1000)
+    >>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"])
+
+    >>> input_tensor = torch.randn(42, 1000)
+    >>> # Will split the tensor into 21 slices of shape [2, 1000].
+    >>> output = split_module(input=input_tensor)
+    ```
+
+    It is also possible to nest `SplitInferenceModule` across different split dimensions for more complex
+    multi-dimensional splitting.
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        split_size: int = 1,
+        split_dim: int = 0,
+        input_kwargs_to_split: List[str] = ["hidden_states"],
+    ) -> None:
+        super().__init__()
+
+        self.module = module
+        self.split_size = split_size
+        self.split_dim = split_dim
+        self.input_kwargs_to_split = set(input_kwargs_to_split)
+
+    def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+        r"""Forward method for the `SplitInferenceModule`.
+
+        This method processes the input by splitting specified keyword arguments along a given dimension, running the
+        underlying module on each split, and then concatenating the results. The splitting is controlled by the
+        `split_size` and `split_dim` parameters specified during initialization.
+
+        Args:
+            *args (`Any`):
+                Positional arguments that are passed directly to the `module` without modification.
+            **kwargs (`Dict[str, torch.Tensor]`):
+                Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the
+                entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword
+                arguments are passed unchanged.
+
+        Returns:
+            `Union[torch.Tensor, Tuple[torch.Tensor]]`:
+                The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred
+                without it.
+                - If the underlying module returns a single tensor, the result will be a single concatenated tensor
+                  along the same `split_dim` after processing all splits.
+                - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated
+                  along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors.
+        """
+        split_inputs = {}
+
+        # 1. Split inputs that were specified during initialization and also present in passed kwargs
+        for key in list(kwargs.keys()):
+            if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
+                continue
+            split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
+            kwargs.pop(key)
+
+        # 2. Invoke forward pass across each split
+        results = []
+        for split_input in zip(*split_inputs.values()):
+            inputs = dict(zip(split_inputs.keys(), split_input))
+            inputs.update(kwargs)
+
+            intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
+            results.append(intermediate_tensor_or_tensor_tuple)
+
+        # 3. Concatenate split results to obtain final outputs
+        if isinstance(results[0], torch.Tensor):
+            return torch.cat(results, dim=self.split_dim)
+        elif isinstance(results[0], tuple):
+            return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)])
+        else:
+            raise ValueError(
+                "In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
+            )
+
+
 class AnimateDiffFreeNoiseMixin:
     r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
 
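The docstring above notes that `SplitInferenceModule` can be nested across different split dimensions. A minimal sketch of what that looks like (the module and tensor shapes are hypothetical, not taken from this diff):

```python
import torch
import torch.nn as nn

# Hypothetical module whose peak memory we want to bound along two dimensions.
model = nn.Linear(64, 64)

# The inner wrapper splits along dim 1; the outer wrapper splits the same call along
# dim 0, so each underlying forward pass only ever sees a [4, 8, 64] chunk.
inner = SplitInferenceModule(model, split_size=8, split_dim=1, input_kwargs_to_split=["input"])
outer = SplitInferenceModule(inner, split_size=4, split_dim=0, input_kwargs_to_split=["input"])

x = torch.randn(16, 32, 64)
out = outer(input=x)  # same result as model(x), computed in 4 x 4 = 16 chunks
assert out.shape == (16, 32, 64)
```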
@@ -69,6 +182,9 @@ class AnimateDiffFreeNoiseMixin:
                     motion_module.transformer_blocks[i].load_state_dict(
                         basic_transfomer_block.state_dict(), strict=True
                     )
+                    motion_module.transformer_blocks[i].set_chunk_feed_forward(
+                        basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim
+                    )
 
     def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
         r"""Helper function to disable FreeNoise in transformer blocks."""
@@ -97,6 +213,145 @@ class AnimateDiffFreeNoiseMixin:
                     motion_module.transformer_blocks[i].load_state_dict(
                         free_noise_transfomer_block.state_dict(), strict=True
                     )
+                    motion_module.transformer_blocks[i].set_chunk_feed_forward(
+                        free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim
+                    )
+
+    def _check_inputs_free_noise(
+        self,
+        prompt,
+        negative_prompt,
+        prompt_embeds,
+        negative_prompt_embeds,
+        num_frames,
+    ) -> None:
+        if not isinstance(prompt, (str, dict)):
+            raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}")
+
+        if negative_prompt is not None:
+            if not isinstance(negative_prompt, (str, dict)):
+                raise ValueError(
+                    f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}"
+                )
+
+        if prompt_embeds is not None or negative_prompt_embeds is not None:
+            raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.")
+
+        frame_indices = [isinstance(x, int) for x in prompt.keys()]
+        frame_prompts = [isinstance(x, str) for x in prompt.values()]
+        min_frame = min(list(prompt.keys()))
+        max_frame = max(list(prompt.keys()))
+
+        if not all(frame_indices):
+            raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.")
+        if not all(frame_prompts):
+            raise ValueError("Expected str values in `prompt` dict for FreeNoise.")
+        if min_frame != 0:
+            raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.")
+        if max_frame >= num_frames:
+            raise ValueError(
+                f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing."
+            )
+
+    def _encode_prompt_free_noise(
+        self,
+        prompt: Union[str, Dict[int, str]],
+        num_frames: int,
+        device: torch.device,
+        num_videos_per_prompt: int,
+        do_classifier_free_guidance: bool,
+        negative_prompt: Optional[Union[str, Dict[int, str]]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ) -> torch.Tensor:
+        if negative_prompt is None:
+            negative_prompt = ""
+
+        # Ensure that we have a dictionary of prompts
+        if isinstance(prompt, str):
+            prompt = {0: prompt}
+        if isinstance(negative_prompt, str):
+            negative_prompt = {0: negative_prompt}
+
+        self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames)
+
+        # Sort the prompts based on frame indices
+        prompt = dict(sorted(prompt.items()))
+        negative_prompt = dict(sorted(negative_prompt.items()))
+
+        # Ensure that we have a prompt for the last frame index
+        prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]]
+        negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]]
+
+        frame_indices = list(prompt.keys())
+        frame_prompts = list(prompt.values())
+        frame_negative_indices = list(negative_prompt.keys())
+        frame_negative_prompts = list(negative_prompt.values())
+
+        # Generate and interpolate positive prompts
+        prompt_embeds, _ = self.encode_prompt(
+            prompt=frame_prompts,
+            device=device,
+            num_images_per_prompt=num_videos_per_prompt,
+            do_classifier_free_guidance=False,
+            negative_prompt=None,
+            prompt_embeds=None,
+            negative_prompt_embeds=None,
+            lora_scale=lora_scale,
+            clip_skip=clip_skip,
+        )
+
+        shape = (num_frames, *prompt_embeds.shape[1:])
+        prompt_interpolation_embeds = prompt_embeds.new_zeros(shape)
+
+        for i in range(len(frame_indices) - 1):
+            start_frame = frame_indices[i]
+            end_frame = frame_indices[i + 1]
+            start_tensor = prompt_embeds[i].unsqueeze(0)
+            end_tensor = prompt_embeds[i + 1].unsqueeze(0)
+
+            prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(
+                start_frame, end_frame, start_tensor, end_tensor
+            )
+
+        # Generate and interpolate negative prompts
+        negative_prompt_embeds = None
+        negative_prompt_interpolation_embeds = None
+
+        if do_classifier_free_guidance:
+            _, negative_prompt_embeds = self.encode_prompt(
+                prompt=[""] * len(frame_negative_prompts),
+                device=device,
+                num_images_per_prompt=num_videos_per_prompt,
+                do_classifier_free_guidance=True,
+                negative_prompt=frame_negative_prompts,
+                prompt_embeds=None,
+                negative_prompt_embeds=None,
+                lora_scale=lora_scale,
+                clip_skip=clip_skip,
+            )
+
+            negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape)
+
+            for i in range(len(frame_negative_indices) - 1):
+                start_frame = frame_negative_indices[i]
+                end_frame = frame_negative_indices[i + 1]
+                start_tensor = negative_prompt_embeds[i].unsqueeze(0)
+                end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
+
+                negative_prompt_interpolation_embeds[
+                    start_frame : end_frame + 1
+                ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
+
+        prompt_embeds = prompt_interpolation_embeds
+        negative_prompt_embeds = negative_prompt_interpolation_embeds
+
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        return prompt_embeds, negative_prompt_embeds
 
     def _prepare_latents_free_noise(
         self,
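`_encode_prompt_free_noise` is what lets FreeNoise pipelines accept a dict mapping frame indices to prompts, with embeddings interpolated between the keyed frames. A minimal sketch of how a caller might drive this (checkpoint ids and prompts are illustrative, assuming an AnimateDiff pipeline on a CUDA device):

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

# Illustrative checkpoint ids -- substitute the adapter/base model you actually use.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

pipe.enable_free_noise(context_length=16, context_stride=4)

# Integer keys are frame indices; index 0 is required, and embeddings are
# interpolated between keyed frames (linearly by default, via `_lerp` below).
output = pipe(
    prompt={0: "a caterpillar on a leaf", 64: "a butterfly on a flower"},
    negative_prompt="bad quality, worst quality",
    num_frames=128,
    num_inference_steps=25,
)
```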
@@ -172,12 +427,29 @@ class AnimateDiffFreeNoiseMixin:
             latents = latents[:, :, :num_frames]
         return latents
 
+    def _lerp(
+        self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor
+    ) -> torch.Tensor:
+        num_indices = end_index - start_index + 1
+        interpolated_tensors = []
+
+        for i in range(num_indices):
+            alpha = i / (num_indices - 1)
+            interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor
+            interpolated_tensors.append(interpolated_tensor)
+
+        interpolated_tensors = torch.cat(interpolated_tensors)
+        return interpolated_tensors
+
     def enable_free_noise(
         self,
         context_length: Optional[int] = 16,
         context_stride: int = 4,
         weighting_scheme: str = "pyramid",
         noise_type: str = "shuffle_context",
+        prompt_interpolation_callback: Optional[
+            Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor]
+        ] = None,
     ) -> None:
         r"""
         Enable long video generation using FreeNoise.
@@ -195,13 +467,27 @@ class AnimateDiffFreeNoiseMixin:
             weighting_scheme (`str`, defaults to `pyramid`):
                 Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting
                 schemes are supported currently:
+                    - "flat"
+                        Performs weighted averaging with a flat weight pattern: [1, 1, 1, 1, 1].
                     - "pyramid"
-                        Peforms weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
+                        Performs weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
+                    - "delayed_reverse_sawtooth"
+                        Performs weighted averaging with low weights for earlier frames and high-to-low weights for
+                        later frames: [0.01, 0.01, 3, 2, 1].
             noise_type (`str`, defaults to "shuffle_context"):
-                TODO
+                Must be one of ["shuffle_context", "repeat_context", "random"].
+                    - "shuffle_context"
+                        Shuffles a fixed batch of `context_length` latents to create a final latent of size
+                        `num_frames`. This is usually the best setting for most generation scenarios. However, there
+                        might be visible repetition noticeable in the kinds of motion/animation generated.
+                    - "repeat_context"
+                        Repeats a fixed batch of `context_length` latents to create a final latent of size
+                        `num_frames`.
+                    - "random"
+                        The final latents are random without any repetition.
         """
 
-        allowed_weighting_scheme = ["pyramid"]
+        allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"]
         allowed_noise_type = ["shuffle_context", "repeat_context", "random"]
 
         if context_length > self.motion_adapter.config.motion_max_seq_length:
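Besides the default `_lerp`, `enable_free_noise` now accepts a `prompt_interpolation_callback`. A hedged sketch of a drop-in replacement (a hypothetical smoothstep easing, written to match the four-argument call sites in `_encode_prompt_free_noise` above; `pipe` is the pipeline from the earlier sketch):

```python
import torch


def smoothstep_callback(
    start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor
) -> torch.Tensor:
    # Hypothetical ease-in/ease-out interpolation between two prompt embeddings.
    # `end_index > start_index` always holds here, so `num_indices >= 2`.
    num_indices = end_index - start_index + 1
    frames = []
    for i in range(num_indices):
        t = i / (num_indices - 1)
        t = t * t * (3 - 2 * t)  # smoothstep easing instead of a straight line
        frames.append((1 - t) * start_tensor + t * end_tensor)
    return torch.cat(frames)


pipe.enable_free_noise(context_length=16, context_stride=4, prompt_interpolation_callback=smoothstep_callback)
```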
@@ -219,18 +505,92 @@ class AnimateDiffFreeNoiseMixin:
         self._free_noise_context_stride = context_stride
         self._free_noise_weighting_scheme = weighting_scheme
         self._free_noise_noise_type = noise_type
+        self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp
+
+        if hasattr(self.unet.mid_block, "motion_modules"):
+            blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
+        else:
+            blocks = [*self.unet.down_blocks, *self.unet.up_blocks]
 
-        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
         for block in blocks:
             self._enable_free_noise_in_block(block)
 
     def disable_free_noise(self) -> None:
+        r"""Disable the FreeNoise sampling mechanism."""
         self._free_noise_context_length = None
 
+        if hasattr(self.unet.mid_block, "motion_modules"):
+            blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
+        else:
+            blocks = [*self.unet.down_blocks, *self.unet.up_blocks]
+
         blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
         for block in blocks:
             self._disable_free_noise_in_block(block)
 
+    def _enable_split_inference_motion_modules_(
+        self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int
+    ) -> None:
+        for motion_module in motion_modules:
+            motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"])
+
+            for i in range(len(motion_module.transformer_blocks)):
+                motion_module.transformer_blocks[i] = SplitInferenceModule(
+                    motion_module.transformer_blocks[i],
+                    spatial_split_size,
+                    0,
+                    ["hidden_states", "encoder_hidden_states"],
+                )
+
+            motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"])
+
+    def _enable_split_inference_attentions_(
+        self, attentions: List[Transformer2DModel], temporal_split_size: int
+    ) -> None:
+        for i in range(len(attentions)):
+            attentions[i] = SplitInferenceModule(
+                attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"]
+            )
+
+    def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None:
+        for i in range(len(resnets)):
+            resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"])
+
+    def _enable_split_inference_samplers_(
+        self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int
+    ) -> None:
+        for i in range(len(samplers)):
+            samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"])
+
+    def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None:
+        r"""
+        Enable FreeNoise memory optimizations by utilizing
+        [`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks.
+
+        Args:
+            spatial_split_size (`int`, defaults to `256`):
+                The split size across spatial dimensions for internal blocks. This is used in facilitating split
+                inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion
+                modeling blocks.
+            temporal_split_size (`int`, defaults to `16`):
+                The split size across temporal dimensions for internal blocks. This is used in facilitating split
+                inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in
+                spatial attention, resnets, downsampling and upsampling blocks.
+        """
+        # TODO(aryan): Discuss on what's the best way to provide more control to users
+        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
+        for block in blocks:
+            if getattr(block, "motion_modules", None) is not None:
+                self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size)
+            if getattr(block, "attentions", None) is not None:
+                self._enable_split_inference_attentions_(block.attentions, temporal_split_size)
+            if getattr(block, "resnets", None) is not None:
+                self._enable_split_inference_resnets_(block.resnets, temporal_split_size)
+            if getattr(block, "downsamplers", None) is not None:
+                self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size)
+            if getattr(block, "upsamplers", None) is not None:
+                self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size)
+
     @property
     def free_noise_enabled(self):
         return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
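Usage note: with the pieces above in place, turning on the memory optimization is a single call on any FreeNoise-capable pipeline (continuing the hypothetical `pipe` from the earlier sketch):

```python
# Wraps motion modules, spatial attentions, resnets and samplers in SplitInferenceModule.
# Smaller split sizes trade inference speed for lower peak memory.
pipe.enable_free_noise_split_inference(spatial_split_size=256, temporal_split_size=16)
```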
diffusers/pipelines/hunyuan_video/__init__.py (new file)

@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_transformers_objects import *
+    else:
+        from .pipeline_hunyuan_video import HunyuanVideoPipeline
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
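This new package lazily exposes `HunyuanVideoPipeline`. A hedged sketch of how the export is typically consumed (the checkpoint id, dtypes, and frame settings are illustrative, not taken from this diff):

```python
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

# Illustrative checkpoint id -- substitute the HunyuanVideo weights you actually use.
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
pipe.vae.enable_tiling()  # reduces VAE memory usage for long videos
pipe.to("cuda")

output = pipe(
    prompt="A cat walks on the grass, realistic style.",
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=15)
```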