diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -138,10 +138,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
138
138
  self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
139
139
  self.tile_overlap_factor = 0.25
140
140
 
141
- def _set_gradient_checkpointing(self, module, value=False):
142
- if isinstance(module, (Encoder, Decoder)):
143
- module.gradient_checkpointing = value
144
-
145
141
  def enable_tiling(self, use_tiling: bool = True):
146
142
  r"""
147
143
  Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -103,7 +103,7 @@ class AllegroTemporalConvLayer(nn.Module):
103
103
  if self.down_sample:
104
104
  identity = hidden_states[:, :, ::2]
105
105
  elif self.up_sample:
106
- identity = hidden_states.repeat_interleave(2, dim=2)
106
+ identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2)
107
107
  else:
108
108
  identity = hidden_states
109
109
 
@@ -507,19 +507,12 @@ class AllegroEncoder3D(nn.Module):
507
507
  sample = sample + residual
508
508
 
509
509
  if torch.is_grad_enabled() and self.gradient_checkpointing:
510
-
511
- def create_custom_forward(module):
512
- def custom_forward(*inputs):
513
- return module(*inputs)
514
-
515
- return custom_forward
516
-
517
510
  # Down blocks
518
511
  for down_block in self.down_blocks:
519
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
512
+ sample = self._gradient_checkpointing_func(down_block, sample)
520
513
 
521
514
  # Mid block
522
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
515
+ sample = self._gradient_checkpointing_func(self.mid_block, sample)
523
516
  else:
524
517
  # Down blocks
525
518
  for down_block in self.down_blocks:
@@ -647,19 +640,12 @@ class AllegroDecoder3D(nn.Module):
647
640
  upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
648
641
 
649
642
  if torch.is_grad_enabled() and self.gradient_checkpointing:
650
-
651
- def create_custom_forward(module):
652
- def custom_forward(*inputs):
653
- return module(*inputs)
654
-
655
- return custom_forward
656
-
657
643
  # Mid block
658
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
644
+ sample = self._gradient_checkpointing_func(self.mid_block, sample)
659
645
 
660
646
  # Up blocks
661
647
  for up_block in self.up_blocks:
662
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
648
+ sample = self._gradient_checkpointing_func(up_block, sample)
663
649
 
664
650
  else:
665
651
  # Mid block
@@ -809,10 +795,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
809
795
  sample_size - self.tile_overlap_w,
810
796
  )
811
797
 
812
- def _set_gradient_checkpointing(self, module, value=False):
813
- if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
814
- module.gradient_checkpointing = value
815
-
816
798
  def enable_tiling(self) -> None:
817
799
  r"""
818
800
  Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -105,6 +105,7 @@ class CogVideoXCausalConv3d(nn.Module):
105
105
  self.width_pad = width_pad
106
106
  self.time_pad = time_pad
107
107
  self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
108
+ self.const_padding_conv3d = (0, self.width_pad, self.height_pad)
108
109
 
109
110
  self.temporal_dim = 2
110
111
  self.time_kernel_size = time_kernel_size
@@ -117,6 +118,8 @@ class CogVideoXCausalConv3d(nn.Module):
117
118
  kernel_size=kernel_size,
118
119
  stride=stride,
119
120
  dilation=dilation,
121
+ padding=0 if self.pad_mode == "replicate" else self.const_padding_conv3d,
122
+ padding_mode="zeros",
120
123
  )
121
124
 
122
125
  def fake_context_parallel_forward(
@@ -137,9 +140,7 @@ class CogVideoXCausalConv3d(nn.Module):
137
140
  if self.pad_mode == "replicate":
138
141
  conv_cache = None
139
142
  else:
140
- padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
141
143
  conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
142
- inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
143
144
 
144
145
  output = self.conv(inputs)
145
146
  return output, conv_cache
@@ -421,15 +422,8 @@ class CogVideoXDownBlock3D(nn.Module):
421
422
  conv_cache_key = f"resnet_{i}"
422
423
 
423
424
  if torch.is_grad_enabled() and self.gradient_checkpointing:
424
-
425
- def create_custom_forward(module):
426
- def create_forward(*inputs):
427
- return module(*inputs)
428
-
429
- return create_forward
430
-
431
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
432
- create_custom_forward(resnet),
425
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
426
+ resnet,
433
427
  hidden_states,
434
428
  temb,
435
429
  zq,
@@ -523,15 +517,8 @@ class CogVideoXMidBlock3D(nn.Module):
523
517
  conv_cache_key = f"resnet_{i}"
524
518
 
525
519
  if torch.is_grad_enabled() and self.gradient_checkpointing:
526
-
527
- def create_custom_forward(module):
528
- def create_forward(*inputs):
529
- return module(*inputs)
530
-
531
- return create_forward
532
-
533
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
534
- create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
520
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
521
+ resnet, hidden_states, temb, zq, conv_cache.get(conv_cache_key)
535
522
  )
536
523
  else:
537
524
  hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -637,15 +624,8 @@ class CogVideoXUpBlock3D(nn.Module):
637
624
  conv_cache_key = f"resnet_{i}"
638
625
 
639
626
  if torch.is_grad_enabled() and self.gradient_checkpointing:
640
-
641
- def create_custom_forward(module):
642
- def create_forward(*inputs):
643
- return module(*inputs)
644
-
645
- return create_forward
646
-
647
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
648
- create_custom_forward(resnet),
627
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
628
+ resnet,
649
629
  hidden_states,
650
630
  temb,
651
631
  zq,
@@ -774,18 +754,11 @@ class CogVideoXEncoder3D(nn.Module):
774
754
  hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
775
755
 
776
756
  if torch.is_grad_enabled() and self.gradient_checkpointing:
777
-
778
- def create_custom_forward(module):
779
- def custom_forward(*inputs):
780
- return module(*inputs)
781
-
782
- return custom_forward
783
-
784
757
  # 1. Down
785
758
  for i, down_block in enumerate(self.down_blocks):
786
759
  conv_cache_key = f"down_block_{i}"
787
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
788
- create_custom_forward(down_block),
760
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
761
+ down_block,
789
762
  hidden_states,
790
763
  temb,
791
764
  None,
@@ -793,8 +766,8 @@ class CogVideoXEncoder3D(nn.Module):
793
766
  )
794
767
 
795
768
  # 2. Mid
796
- hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
797
- create_custom_forward(self.mid_block),
769
+ hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
770
+ self.mid_block,
798
771
  hidden_states,
799
772
  temb,
800
773
  None,
@@ -940,16 +913,9 @@ class CogVideoXDecoder3D(nn.Module):
940
913
  hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
941
914
 
942
915
  if torch.is_grad_enabled() and self.gradient_checkpointing:
943
-
944
- def create_custom_forward(module):
945
- def custom_forward(*inputs):
946
- return module(*inputs)
947
-
948
- return custom_forward
949
-
950
916
  # 1. Mid
951
- hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
952
- create_custom_forward(self.mid_block),
917
+ hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
918
+ self.mid_block,
953
919
  hidden_states,
954
920
  temb,
955
921
  sample,
@@ -959,8 +925,8 @@ class CogVideoXDecoder3D(nn.Module):
959
925
  # 2. Up
960
926
  for i, up_block in enumerate(self.up_blocks):
961
927
  conv_cache_key = f"up_block_{i}"
962
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
963
- create_custom_forward(up_block),
928
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
929
+ up_block,
964
930
  hidden_states,
965
931
  temb,
966
932
  sample,
@@ -1122,10 +1088,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
1122
1088
  self.tile_overlap_factor_height = 1 / 6
1123
1089
  self.tile_overlap_factor_width = 1 / 5
1124
1090
 
1125
- def _set_gradient_checkpointing(self, module, value=False):
1126
- if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
1127
- module.gradient_checkpointing = value
1128
-
1129
1091
  def enable_tiling(
1130
1092
  self,
1131
1093
  tile_sample_min_height: Optional[int] = None,
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional, Tuple, Union
15
+ from typing import Optional, Tuple, Union
16
16
 
17
17
  import numpy as np
18
18
  import torch
@@ -21,7 +21,7 @@ import torch.nn.functional as F
21
21
  import torch.utils.checkpoint
22
22
 
23
23
  from ...configuration_utils import ConfigMixin, register_to_config
24
- from ...utils import is_torch_version, logging
24
+ from ...utils import logging
25
25
  from ...utils.accelerate_utils import apply_forward_hook
26
26
  from ..activations import get_activation
27
27
  from ..attention_processor import Attention
@@ -36,11 +36,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
36
  def prepare_causal_attention_mask(
37
37
  num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
38
38
  ) -> torch.Tensor:
39
- seq_len = num_frames * height_width
40
- mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
41
- for i in range(seq_len):
42
- i_frame = i // height_width
43
- mask[i, : (i_frame + 1) * height_width] = 0
39
+ indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
40
+ indices_blocks = indices.repeat_interleave(height_width)
41
+ x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
42
+ mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)
43
+
44
44
  if batch_size is not None:
45
45
  mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
46
46
  return mask
@@ -252,21 +252,7 @@ class HunyuanVideoMidBlock3D(nn.Module):
252
252
 
253
253
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
254
254
  if torch.is_grad_enabled() and self.gradient_checkpointing:
255
-
256
- def create_custom_forward(module, return_dict=None):
257
- def custom_forward(*inputs):
258
- if return_dict is not None:
259
- return module(*inputs, return_dict=return_dict)
260
- else:
261
- return module(*inputs)
262
-
263
- return custom_forward
264
-
265
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
266
-
267
- hidden_states = torch.utils.checkpoint.checkpoint(
268
- create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs
269
- )
255
+ hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states)
270
256
 
271
257
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
272
258
  if attn is not None:
@@ -278,9 +264,7 @@ class HunyuanVideoMidBlock3D(nn.Module):
278
264
  hidden_states = attn(hidden_states, attention_mask=attention_mask)
279
265
  hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
280
266
 
281
- hidden_states = torch.utils.checkpoint.checkpoint(
282
- create_custom_forward(resnet), hidden_states, **ckpt_kwargs
283
- )
267
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
284
268
 
285
269
  else:
286
270
  hidden_states = self.resnets[0](hidden_states)
@@ -350,22 +334,8 @@ class HunyuanVideoDownBlock3D(nn.Module):
350
334
 
351
335
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
352
336
  if torch.is_grad_enabled() and self.gradient_checkpointing:
353
-
354
- def create_custom_forward(module, return_dict=None):
355
- def custom_forward(*inputs):
356
- if return_dict is not None:
357
- return module(*inputs, return_dict=return_dict)
358
- else:
359
- return module(*inputs)
360
-
361
- return custom_forward
362
-
363
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
364
-
365
337
  for resnet in self.resnets:
366
- hidden_states = torch.utils.checkpoint.checkpoint(
367
- create_custom_forward(resnet), hidden_states, **ckpt_kwargs
368
- )
338
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
369
339
  else:
370
340
  for resnet in self.resnets:
371
341
  hidden_states = resnet(hidden_states)
@@ -426,22 +396,8 @@ class HunyuanVideoUpBlock3D(nn.Module):
426
396
 
427
397
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
428
398
  if torch.is_grad_enabled() and self.gradient_checkpointing:
429
-
430
- def create_custom_forward(module, return_dict=None):
431
- def custom_forward(*inputs):
432
- if return_dict is not None:
433
- return module(*inputs, return_dict=return_dict)
434
- else:
435
- return module(*inputs)
436
-
437
- return custom_forward
438
-
439
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
440
-
441
399
  for resnet in self.resnets:
442
- hidden_states = torch.utils.checkpoint.checkpoint(
443
- create_custom_forward(resnet), hidden_states, **ckpt_kwargs
444
- )
400
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
445
401
 
446
402
  else:
447
403
  for resnet in self.resnets:
@@ -545,26 +501,10 @@ class HunyuanVideoEncoder3D(nn.Module):
545
501
  hidden_states = self.conv_in(hidden_states)
546
502
 
547
503
  if torch.is_grad_enabled() and self.gradient_checkpointing:
548
-
549
- def create_custom_forward(module, return_dict=None):
550
- def custom_forward(*inputs):
551
- if return_dict is not None:
552
- return module(*inputs, return_dict=return_dict)
553
- else:
554
- return module(*inputs)
555
-
556
- return custom_forward
557
-
558
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
559
-
560
504
  for down_block in self.down_blocks:
561
- hidden_states = torch.utils.checkpoint.checkpoint(
562
- create_custom_forward(down_block), hidden_states, **ckpt_kwargs
563
- )
505
+ hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
564
506
 
565
- hidden_states = torch.utils.checkpoint.checkpoint(
566
- create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
567
- )
507
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
568
508
  else:
569
509
  for down_block in self.down_blocks:
570
510
  hidden_states = down_block(hidden_states)
@@ -667,26 +607,10 @@ class HunyuanVideoDecoder3D(nn.Module):
667
607
  hidden_states = self.conv_in(hidden_states)
668
608
 
669
609
  if torch.is_grad_enabled() and self.gradient_checkpointing:
670
-
671
- def create_custom_forward(module, return_dict=None):
672
- def custom_forward(*inputs):
673
- if return_dict is not None:
674
- return module(*inputs, return_dict=return_dict)
675
- else:
676
- return module(*inputs)
677
-
678
- return custom_forward
679
-
680
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
681
-
682
- hidden_states = torch.utils.checkpoint.checkpoint(
683
- create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
684
- )
610
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
685
611
 
686
612
  for up_block in self.up_blocks:
687
- hidden_states = torch.utils.checkpoint.checkpoint(
688
- create_custom_forward(up_block), hidden_states, **ckpt_kwargs
689
- )
613
+ hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
690
614
  else:
691
615
  hidden_states = self.mid_block(hidden_states)
692
616
 
@@ -786,7 +710,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
786
710
  self.use_tiling = False
787
711
 
788
712
  # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
789
- # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
713
+ # at a fixed frame batch size (based on `self.tile_sample_min_num_frames`), the memory requirement can be lowered.
790
714
  self.use_framewise_encoding = True
791
715
  self.use_framewise_decoding = True
792
716
 
@@ -800,10 +724,6 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
800
724
  self.tile_sample_stride_width = 192
801
725
  self.tile_sample_stride_num_frames = 12
802
726
 
803
- def _set_gradient_checkpointing(self, module, value=False):
804
- if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
805
- module.gradient_checkpointing = value
806
-
807
727
  def enable_tiling(
808
728
  self,
809
729
  tile_sample_min_height: Optional[int] = None,
@@ -868,7 +788,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
868
788
  def _encode(self, x: torch.Tensor) -> torch.Tensor:
869
789
  batch_size, num_channels, num_frames, height, width = x.shape
870
790
 
871
- if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
791
+ if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames:
872
792
  return self._temporal_tiled_encode(x)
873
793
 
874
794
  if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):