diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +595 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
  384. diffusers-0.33.1.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
diffusers/models/unets/unet_3d_condition.py CHANGED
@@ -37,11 +37,7 @@ from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
 from ..transformers.transformer_temporal import TransformerTemporalModel
 from .unet_3d_blocks import (
-    CrossAttnDownBlock3D,
-    CrossAttnUpBlock3D,
-    DownBlock3D,
     UNetMidBlock3DCrossAttn,
-    UpBlock3D,
     get_down_block,
     get_up_block,
 )
@@ -97,6 +93,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     """
 
     _supports_gradient_checkpointing = False
+    _skip_layerwise_casting_patterns = ["norm", "time_embedding"]
 
     @register_to_config
     def __init__(
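The new `_skip_layerwise_casting_patterns` attribute pairs with the layerwise-casting hooks added in `diffusers/hooks/layerwise_casting.py` (file 8 in the manifest): submodules whose names match these patterns keep their original precision, since norm and time-embedding layers are numerically sensitive to low-precision storage. A minimal sketch of how this is typically driven, assuming the `enable_layerwise_casting` API that 0.33 adds on `ModelMixin` and a hypothetical checkpoint id:

```python
import torch
from diffusers import UNet3DConditionModel

# "some/repo" is a placeholder, not a real Hub id.
unet = UNet3DConditionModel.from_pretrained("some/repo", subfolder="unet")
unet.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn,  # weights kept in fp8 between layers
    compute_dtype=torch.bfloat16,       # upcast per layer for each forward
)
# Submodules matching ["norm", "time_embedding"] are skipped automatically.
```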
@@ -471,10 +468,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
         self.set_attn_processor(processor)
 
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
-            module.gradient_checkpointing = value
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1, s2, b1, b2):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -624,10 +617,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
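The recurring `is_npu` additions extend the existing mps special-case: like mps, the Ascend NPU backend lacks 64-bit float/int tensor support, so scalar timesteps are materialized as 32-bit tensors there. An illustrative standalone helper capturing the rule (the function name is ours, not a diffusers API):

```python
import torch

def timestep_dtype(device: torch.device, timestep_is_float: bool) -> torch.dtype:
    # mps and npu backends cannot allocate float64/int64 tensors
    limited = device.type in ("mps", "npu")
    if timestep_is_float:
        return torch.float32 if limited else torch.float64
    return torch.int32 if limited else torch.int64

assert timestep_dtype(torch.device("cpu"), True) is torch.float64
```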
@@ -644,8 +638,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         t_emb = t_emb.to(dtype=self.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )
 
         # 2. pre-process
         sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
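The `repeat_interleave` calls here (and in the other video UNets below) now pass `output_size`, which declares the result length up front so PyTorch does not have to inspect the repeat count to size the output, avoiding a host-device sync and playing better with `torch.compile`. The values are unchanged:

```python
import torch

emb = torch.randn(2, 320)
num_frames = 16
a = emb.repeat_interleave(num_frames, dim=0)
b = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
assert torch.equal(a, b)  # same [32, 320] result; b skips output-shape inference
```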
diffusers/models/unets/unet_i2vgen_xl.py CHANGED
@@ -35,11 +35,7 @@ from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
 from ..transformers.transformer_temporal import TransformerTemporalModel
 from .unet_3d_blocks import (
-    CrossAttnDownBlock3D,
-    CrossAttnUpBlock3D,
-    DownBlock3D,
     UNetMidBlock3DCrossAttn,
-    UpBlock3D,
     get_down_block,
     get_up_block,
 )
@@ -436,11 +432,6 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
         self.set_attn_processor(processor)
 
-    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel._set_gradient_checkpointing
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
-            module.gradient_checkpointing = value
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1, s2, b1, b2):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -575,10 +566,11 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timesteps, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
@@ -600,7 +592,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
         # 3. time + FPS embeddings.
         emb = t_emb + fps_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
 
         # 4. context embeddings.
         # The context embeddings consist of both text embeddings from the input prompt
@@ -628,7 +620,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         image_emb = self.context_embedding(image_embeddings)
         image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim)
         context_emb = torch.cat([context_emb, image_emb], dim=1)
-        context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
+        context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames)
 
         image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
             image_latents.shape[0] * image_latents.shape[2],
diffusers/models/unets/unet_kandinsky3.py CHANGED
@@ -205,10 +205,6 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
         """
         self.set_attn_processor(AttnProcessor())
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
         if encoder_attention_mask is not None:
             encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
diffusers/models/unets/unet_motion_model.py CHANGED
@@ -22,7 +22,7 @@ import torch.utils.checkpoint
 
 from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
-from ...utils import BaseOutput, deprecate, is_torch_version, logging
+from ...utils import BaseOutput, deprecate, logging
 from ...utils.torch_utils import apply_freeu
 from ..attention import BasicTransformerBlock
 from ..attention_processor import (
@@ -324,25 +324,7 @@ class DownBlockMotion(nn.Module):
         blocks = zip(self.resnets, self.motion_modules)
         for resnet, motion_module in blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        use_reentrant=False,
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)
 
@@ -514,23 +496,7 @@ class CrossAttnDownBlockMotion(nn.Module):
         blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
         for i, (resnet, attn, motion_module) in enumerate(blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)
 
@@ -543,10 +509,7 @@ class CrossAttnDownBlockMotion(nn.Module):
                 return_dict=False,
             )[0]
 
-            hidden_states = motion_module(
-                hidden_states,
-                num_frames=num_frames,
-            )
+            hidden_states = motion_module(hidden_states, num_frames=num_frames)
 
             # apply additional residuals to the output of the last pair of resnet and attention blocks
             if i == len(blocks) - 1 and additional_residuals is not None:
@@ -733,23 +696,7 @@ class CrossAttnUpBlockMotion(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)
 
@@ -762,10 +709,7 @@ class CrossAttnUpBlockMotion(nn.Module):
                 return_dict=False,
             )[0]
 
-            hidden_states = motion_module(
-                hidden_states,
-                num_frames=num_frames,
-            )
+            hidden_states = motion_module(hidden_states, num_frames=num_frames)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -896,24 +840,7 @@ class UpBlockMotion(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        use_reentrant=False,
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)
 
@@ -1080,34 +1007,12 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
             )[0]
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(motion_module),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
+                hidden_states = self._gradient_checkpointing_func(
+                    motion_module, hidden_states, None, None, None, num_frames, None
                 )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
-                hidden_states = motion_module(
-                    hidden_states,
-                    num_frames=num_frames,
-                )
+                hidden_states = motion_module(hidden_states, None, None, None, num_frames, None)
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)
 
         return hidden_states
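Across these motion blocks (and the Stable Cascade and UVit hunks below), the inlined `create_custom_forward` wrappers and `is_torch_version` branches collapse into a single `self._gradient_checkpointing_func` supplied by `ModelMixin`. A plausible minimal shape for such a helper, assuming non-reentrant checkpointing as in the removed branches (the real implementation lives in `diffusers.models.modeling_utils`):

```python
import functools
import torch
import torch.nn as nn

# Sketch: one partial over torch.utils.checkpoint.checkpoint replaces the
# per-block wrapper boilerplate; nn.Module is callable, so no closure is needed.
_gradient_checkpointing_func = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)

block = nn.Linear(8, 8)  # stand-in for a resnet/motion block
hidden_states = torch.randn(2, 8, requires_grad=True)
out = _gradient_checkpointing_func(block, hidden_states)  # recomputed on backward
out.sum().backward()
```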
@@ -1301,6 +1206,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["norm"]
 
     @register_to_config
     def __init__(
@@ -1965,10 +1871,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
 
         self.set_attn_processor(processor)
 
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
-            module.gradient_checkpointing = value
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -2114,10 +2016,11 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
@@ -2156,7 +2059,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
             aug_emb = self.add_embedding(add_embeds)
 
         emb = emb if aug_emb is None else emb + aug_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
 
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:
@@ -2165,7 +2068,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
                 )
             image_embeds = added_cond_kwargs.get("image_embeds")
             image_embeds = self.encoder_hid_proj(image_embeds)
-            image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
+            image_embeds = [
+                image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames)
+                for image_embed in image_embeds
+            ]
             encoder_hidden_states = (encoder_hidden_states, image_embeds)
 
         # 2. pre-process
diffusers/models/unets/unet_spatio_temporal_condition.py CHANGED
@@ -320,10 +320,6 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
         self.set_attn_processor(processor)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
     def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
         """
@@ -402,10 +398,11 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
@@ -434,9 +431,11 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         sample = sample.flatten(0, 1)
         # Repeat the embeddings num_video_frames times
         # emb: [batch, channels] -> [batch * frames, channels]
-        emb = emb.repeat_interleave(num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
         # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )
 
         # 2. pre-process
         sample = self.conv_in(sample)
diffusers/models/unets/unet_stable_cascade.py CHANGED
@@ -387,9 +387,6 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
         self.gradient_checkpointing = False
 
-    def _set_gradient_checkpointing(self, value=False):
-        self.gradient_checkpointing = value
-
     def _init_weights(self, m):
         if isinstance(m, (nn.Conv2d, nn.Linear)):
             torch.nn.init.xavier_uniform_(m.weight)
@@ -456,29 +453,18 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             for down_block, downscaler, repmap in block_group:
                 x = downscaler(x)
                 for i in range(len(repmap) + 1):
                     for block in down_block:
                         if isinstance(block, SDCascadeResBlock):
-                            x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
+                            x = self._gradient_checkpointing_func(block, x)
                         elif isinstance(block, SDCascadeAttnBlock):
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, clip, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, clip)
                         elif isinstance(block, SDCascadeTimestepBlock):
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, r_embed, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, r_embed)
                         else:
-                            x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), use_reentrant=False)
+                            x = self._gradient_checkpointing_func(block)
                     if i < len(repmap):
                         x = repmap[i](x)
                 level_outputs.insert(0, x)
@@ -505,13 +491,6 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             for i, (up_block, upscaler, repmap) in enumerate(block_group):
                 for j in range(len(repmap) + 1):
                     for k, block in enumerate(up_block):
@@ -523,19 +502,13 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                                 x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
                             )
                             x = x.to(orig_type)
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, skip, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, skip)
                         elif isinstance(block, SDCascadeAttnBlock):
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, clip, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, clip)
                         elif isinstance(block, SDCascadeTimestepBlock):
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, r_embed, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, r_embed)
                         else:
-                            x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
+                            x = self._gradient_checkpointing_func(block, x)
                     if j < len(repmap):
                         x = repmap[j](x)
                 x = upscaler(x)
diffusers/models/unets/uvit_2d.py CHANGED
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -148,9 +148,6 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
 
         self.gradient_checkpointing = False
 
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        pass
-
     def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
         encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
         encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
diffusers/optimization.py CHANGED
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -258,7 +258,7 @@ def get_polynomial_decay_schedule_with_warmup(
 
     lr_init = optimizer.defaults["lr"]
     if not (lr_init > lr_end):
-        raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
+        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
 
     def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
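Beyond the doubled-word fix in the message, the guard encodes a real constraint: polynomial decay interpolates from the optimizer's initial lr down to `lr_end`, so `lr_end` must be strictly smaller. Typical usage that satisfies the check:

```python
import torch
from diffusers.optimization import get_polynomial_decay_schedule_with_warmup

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = get_polynomial_decay_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
    lr_end=1e-6,  # must be < initial lr (1e-4) or the ValueError above is raised
    power=1.0,
)
```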