diffusers-0.32.1-py3-none-any.whl → diffusers-0.33.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ import torch
18
18
  import torch.nn.functional as F
19
19
  from torch import nn
20
20
 
21
- from ...utils import deprecate, is_torch_version, logging
21
+ from ...utils import deprecate, logging
22
22
  from ...utils.torch_utils import apply_freeu
23
23
  from ..activations import get_activation
24
24
  from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
@@ -737,25 +737,9 @@ class UNetMidBlock2D(nn.Module):
737
737
  hidden_states = self.resnets[0](hidden_states, temb)
738
738
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
739
739
  if torch.is_grad_enabled() and self.gradient_checkpointing:
740
-
741
- def create_custom_forward(module, return_dict=None):
742
- def custom_forward(*inputs):
743
- if return_dict is not None:
744
- return module(*inputs, return_dict=return_dict)
745
- else:
746
- return module(*inputs)
747
-
748
- return custom_forward
749
-
750
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
751
740
  if attn is not None:
752
741
  hidden_states = attn(hidden_states, temb=temb)
753
- hidden_states = torch.utils.checkpoint.checkpoint(
754
- create_custom_forward(resnet),
755
- hidden_states,
756
- temb,
757
- **ckpt_kwargs,
758
- )
742
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
759
743
  else:
760
744
  if attn is not None:
761
745
  hidden_states = attn(hidden_states, temb=temb)
@@ -883,17 +867,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
883
867
  hidden_states = self.resnets[0](hidden_states, temb)
884
868
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
885
869
  if torch.is_grad_enabled() and self.gradient_checkpointing:
886
-
887
- def create_custom_forward(module, return_dict=None):
888
- def custom_forward(*inputs):
889
- if return_dict is not None:
890
- return module(*inputs, return_dict=return_dict)
891
- else:
892
- return module(*inputs)
893
-
894
- return custom_forward
895
-
896
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
897
870
  hidden_states = attn(
898
871
  hidden_states,
899
872
  encoder_hidden_states=encoder_hidden_states,
@@ -902,12 +875,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
902
875
  encoder_attention_mask=encoder_attention_mask,
903
876
  return_dict=False,
904
877
  )[0]
905
- hidden_states = torch.utils.checkpoint.checkpoint(
906
- create_custom_forward(resnet),
907
- hidden_states,
908
- temb,
909
- **ckpt_kwargs,
910
- )
878
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
911
879
  else:
912
880
  hidden_states = attn(
913
881
  hidden_states,
@@ -1156,23 +1124,7 @@ class AttnDownBlock2D(nn.Module):
1156
1124
 
1157
1125
  for resnet, attn in zip(self.resnets, self.attentions):
1158
1126
  if torch.is_grad_enabled() and self.gradient_checkpointing:
1159
-
1160
- def create_custom_forward(module, return_dict=None):
1161
- def custom_forward(*inputs):
1162
- if return_dict is not None:
1163
- return module(*inputs, return_dict=return_dict)
1164
- else:
1165
- return module(*inputs)
1166
-
1167
- return custom_forward
1168
-
1169
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1170
- hidden_states = torch.utils.checkpoint.checkpoint(
1171
- create_custom_forward(resnet),
1172
- hidden_states,
1173
- temb,
1174
- **ckpt_kwargs,
1175
- )
1127
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
1176
1128
  hidden_states = attn(hidden_states, **cross_attention_kwargs)
1177
1129
  output_states = output_states + (hidden_states,)
1178
1130
  else:
@@ -1304,23 +1256,7 @@ class CrossAttnDownBlock2D(nn.Module):
1304
1256
 
1305
1257
  for i, (resnet, attn) in enumerate(blocks):
1306
1258
  if torch.is_grad_enabled() and self.gradient_checkpointing:
1307
-
1308
- def create_custom_forward(module, return_dict=None):
1309
- def custom_forward(*inputs):
1310
- if return_dict is not None:
1311
- return module(*inputs, return_dict=return_dict)
1312
- else:
1313
- return module(*inputs)
1314
-
1315
- return custom_forward
1316
-
1317
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1318
- hidden_states = torch.utils.checkpoint.checkpoint(
1319
- create_custom_forward(resnet),
1320
- hidden_states,
1321
- temb,
1322
- **ckpt_kwargs,
1323
- )
1259
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
1324
1260
  hidden_states = attn(
1325
1261
  hidden_states,
1326
1262
  encoder_hidden_states=encoder_hidden_states,
@@ -1418,21 +1354,7 @@ class DownBlock2D(nn.Module):
1418
1354
 
1419
1355
  for resnet in self.resnets:
1420
1356
  if torch.is_grad_enabled() and self.gradient_checkpointing:
1421
-
1422
- def create_custom_forward(module):
1423
- def custom_forward(*inputs):
1424
- return module(*inputs)
1425
-
1426
- return custom_forward
1427
-
1428
- if is_torch_version(">=", "1.11.0"):
1429
- hidden_states = torch.utils.checkpoint.checkpoint(
1430
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
1431
- )
1432
- else:
1433
- hidden_states = torch.utils.checkpoint.checkpoint(
1434
- create_custom_forward(resnet), hidden_states, temb
1435
- )
1357
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
1436
1358
  else:
1437
1359
  hidden_states = resnet(hidden_states, temb)
1438
1360
 
@@ -1906,21 +1828,7 @@ class ResnetDownsampleBlock2D(nn.Module):
1906
1828
 
1907
1829
  for resnet in self.resnets:
1908
1830
  if torch.is_grad_enabled() and self.gradient_checkpointing:
1909
-
1910
- def create_custom_forward(module):
1911
- def custom_forward(*inputs):
1912
- return module(*inputs)
1913
-
1914
- return custom_forward
1915
-
1916
- if is_torch_version(">=", "1.11.0"):
1917
- hidden_states = torch.utils.checkpoint.checkpoint(
1918
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
1919
- )
1920
- else:
1921
- hidden_states = torch.utils.checkpoint.checkpoint(
1922
- create_custom_forward(resnet), hidden_states, temb
1923
- )
1831
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
1924
1832
  else:
1925
1833
  hidden_states = resnet(hidden_states, temb)
1926
1834
 
@@ -2058,17 +1966,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
2058
1966
 
2059
1967
  for resnet, attn in zip(self.resnets, self.attentions):
2060
1968
  if torch.is_grad_enabled() and self.gradient_checkpointing:
2061
-
2062
- def create_custom_forward(module, return_dict=None):
2063
- def custom_forward(*inputs):
2064
- if return_dict is not None:
2065
- return module(*inputs, return_dict=return_dict)
2066
- else:
2067
- return module(*inputs)
2068
-
2069
- return custom_forward
2070
-
2071
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1969
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
2072
1970
  hidden_states = attn(
2073
1971
  hidden_states,
2074
1972
  encoder_hidden_states=encoder_hidden_states,
@@ -2153,21 +2051,7 @@ class KDownBlock2D(nn.Module):
2153
2051
 
2154
2052
  for resnet in self.resnets:
2155
2053
  if torch.is_grad_enabled() and self.gradient_checkpointing:
2156
-
2157
- def create_custom_forward(module):
2158
- def custom_forward(*inputs):
2159
- return module(*inputs)
2160
-
2161
- return custom_forward
2162
-
2163
- if is_torch_version(">=", "1.11.0"):
2164
- hidden_states = torch.utils.checkpoint.checkpoint(
2165
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
2166
- )
2167
- else:
2168
- hidden_states = torch.utils.checkpoint.checkpoint(
2169
- create_custom_forward(resnet), hidden_states, temb
2170
- )
2054
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
2171
2055
  else:
2172
2056
  hidden_states = resnet(hidden_states, temb)
2173
2057
 
@@ -2262,22 +2146,10 @@ class KCrossAttnDownBlock2D(nn.Module):
2262
2146
 
2263
2147
  for resnet, attn in zip(self.resnets, self.attentions):
2264
2148
  if torch.is_grad_enabled() and self.gradient_checkpointing:
2265
-
2266
- def create_custom_forward(module, return_dict=None):
2267
- def custom_forward(*inputs):
2268
- if return_dict is not None:
2269
- return module(*inputs, return_dict=return_dict)
2270
- else:
2271
- return module(*inputs)
2272
-
2273
- return custom_forward
2274
-
2275
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
2276
- hidden_states = torch.utils.checkpoint.checkpoint(
2277
- create_custom_forward(resnet),
2149
+ hidden_states = self._gradient_checkpointing_func(
2150
+ resnet,
2278
2151
  hidden_states,
2279
2152
  temb,
2280
- **ckpt_kwargs,
2281
2153
  )
2282
2154
  hidden_states = attn(
2283
2155
  hidden_states,
@@ -2423,23 +2295,7 @@ class AttnUpBlock2D(nn.Module):
2423
2295
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2424
2296
 
2425
2297
  if torch.is_grad_enabled() and self.gradient_checkpointing:
2426
-
2427
- def create_custom_forward(module, return_dict=None):
2428
- def custom_forward(*inputs):
2429
- if return_dict is not None:
2430
- return module(*inputs, return_dict=return_dict)
2431
- else:
2432
- return module(*inputs)
2433
-
2434
- return custom_forward
2435
-
2436
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
2437
- hidden_states = torch.utils.checkpoint.checkpoint(
2438
- create_custom_forward(resnet),
2439
- hidden_states,
2440
- temb,
2441
- **ckpt_kwargs,
2442
- )
2298
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
2443
2299
  hidden_states = attn(hidden_states)
2444
2300
  else:
2445
2301
  hidden_states = resnet(hidden_states, temb)
@@ -2588,23 +2444,7 @@ class CrossAttnUpBlock2D(nn.Module):
2588
2444
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2589
2445
 
2590
2446
  if torch.is_grad_enabled() and self.gradient_checkpointing:
2591
-
2592
- def create_custom_forward(module, return_dict=None):
2593
- def custom_forward(*inputs):
2594
- if return_dict is not None:
2595
- return module(*inputs, return_dict=return_dict)
2596
- else:
2597
- return module(*inputs)
2598
-
2599
- return custom_forward
2600
-
2601
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
2602
- hidden_states = torch.utils.checkpoint.checkpoint(
2603
- create_custom_forward(resnet),
2604
- hidden_states,
2605
- temb,
2606
- **ckpt_kwargs,
2607
- )
2447
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
2608
2448
  hidden_states = attn(
2609
2449
  hidden_states,
2610
2450
  encoder_hidden_states=encoder_hidden_states,
@@ -2721,21 +2561,7 @@ class UpBlock2D(nn.Module):
2721
2561
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2722
2562
 
2723
2563
  if torch.is_grad_enabled() and self.gradient_checkpointing:
2724
-
2725
- def create_custom_forward(module):
2726
- def custom_forward(*inputs):
2727
- return module(*inputs)
2728
-
2729
- return custom_forward
2730
-
2731
- if is_torch_version(">=", "1.11.0"):
2732
- hidden_states = torch.utils.checkpoint.checkpoint(
2733
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
2734
- )
2735
- else:
2736
- hidden_states = torch.utils.checkpoint.checkpoint(
2737
- create_custom_forward(resnet), hidden_states, temb
2738
- )
2564
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
2739
2565
  else:
2740
2566
  hidden_states = resnet(hidden_states, temb)
2741
2567
 
@@ -3251,21 +3077,7 @@ class ResnetUpsampleBlock2D(nn.Module):
3251
3077
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
3252
3078
 
3253
3079
  if torch.is_grad_enabled() and self.gradient_checkpointing:
3254
-
3255
- def create_custom_forward(module):
3256
- def custom_forward(*inputs):
3257
- return module(*inputs)
3258
-
3259
- return custom_forward
3260
-
3261
- if is_torch_version(">=", "1.11.0"):
3262
- hidden_states = torch.utils.checkpoint.checkpoint(
3263
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
3264
- )
3265
- else:
3266
- hidden_states = torch.utils.checkpoint.checkpoint(
3267
- create_custom_forward(resnet), hidden_states, temb
3268
- )
3080
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
3269
3081
  else:
3270
3082
  hidden_states = resnet(hidden_states, temb)
3271
3083
 
@@ -3409,17 +3221,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
3409
3221
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
3410
3222
 
3411
3223
  if torch.is_grad_enabled() and self.gradient_checkpointing:
3412
-
3413
- def create_custom_forward(module, return_dict=None):
3414
- def custom_forward(*inputs):
3415
- if return_dict is not None:
3416
- return module(*inputs, return_dict=return_dict)
3417
- else:
3418
- return module(*inputs)
3419
-
3420
- return custom_forward
3421
-
3422
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
3224
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
3423
3225
  hidden_states = attn(
3424
3226
  hidden_states,
3425
3227
  encoder_hidden_states=encoder_hidden_states,
@@ -3512,21 +3314,7 @@ class KUpBlock2D(nn.Module):
3512
3314
 
3513
3315
  for resnet in self.resnets:
3514
3316
  if torch.is_grad_enabled() and self.gradient_checkpointing:
3515
-
3516
- def create_custom_forward(module):
3517
- def custom_forward(*inputs):
3518
- return module(*inputs)
3519
-
3520
- return custom_forward
3521
-
3522
- if is_torch_version(">=", "1.11.0"):
3523
- hidden_states = torch.utils.checkpoint.checkpoint(
3524
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
3525
- )
3526
- else:
3527
- hidden_states = torch.utils.checkpoint.checkpoint(
3528
- create_custom_forward(resnet), hidden_states, temb
3529
- )
3317
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
3530
3318
  else:
3531
3319
  hidden_states = resnet(hidden_states, temb)
3532
3320
 
@@ -3640,22 +3428,10 @@ class KCrossAttnUpBlock2D(nn.Module):
3640
3428
 
3641
3429
  for resnet, attn in zip(self.resnets, self.attentions):
3642
3430
  if torch.is_grad_enabled() and self.gradient_checkpointing:
3643
-
3644
- def create_custom_forward(module, return_dict=None):
3645
- def custom_forward(*inputs):
3646
- if return_dict is not None:
3647
- return module(*inputs, return_dict=return_dict)
3648
- else:
3649
- return module(*inputs)
3650
-
3651
- return custom_forward
3652
-
3653
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
3654
- hidden_states = torch.utils.checkpoint.checkpoint(
3655
- create_custom_forward(resnet),
3431
+ hidden_states = self._gradient_checkpointing_func(
3432
+ resnet,
3656
3433
  hidden_states,
3657
3434
  temb,
3658
- **ckpt_kwargs,
3659
3435
  )
3660
3436
  hidden_states = attn(
3661
3437
  hidden_states,
@@ -166,6 +166,7 @@ class UNet2DConditionModel(
166
166
 
167
167
  _supports_gradient_checkpointing = True
168
168
  _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
169
+ _skip_layerwise_casting_patterns = ["norm"]
169
170
 
170
171
  @register_to_config
171
172
  def __init__(
@@ -833,10 +834,6 @@ class UNet2DConditionModel(
833
834
  for module in self.children():
834
835
  fn_recursive_set_attention_slice(module, reversed_slice_size)
835
836
 
836
- def _set_gradient_checkpointing(self, module, value=False):
837
- if hasattr(module, "gradient_checkpointing"):
838
- module.gradient_checkpointing = value
839
-
840
837
  def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
841
838
  r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
842
839
 
@@ -915,10 +912,11 @@ class UNet2DConditionModel(
915
912
  # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
916
913
  # This would be a good case for the `match` statement (Python 3.10+)
917
914
  is_mps = sample.device.type == "mps"
915
+ is_npu = sample.device.type == "npu"
918
916
  if isinstance(timestep, float):
919
- dtype = torch.float32 if is_mps else torch.float64
917
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
920
918
  else:
921
- dtype = torch.int32 if is_mps else torch.int64
919
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
922
920
  timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
923
921
  elif len(timesteps.shape) == 0:
924
922
  timesteps = timesteps[None].to(sample.device)
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple, Union
17
17
  import torch
18
18
  from torch import nn
19
19
 
20
- from ...utils import deprecate, is_torch_version, logging
20
+ from ...utils import deprecate, logging
21
21
  from ...utils.torch_utils import apply_freeu
22
22
  from ..attention import Attention
23
23
  from ..resnet import (
@@ -1078,31 +1078,14 @@ class UNetMidBlockSpatioTemporal(nn.Module):
1078
1078
  )
1079
1079
 
1080
1080
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
1081
- if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO
1082
-
1083
- def create_custom_forward(module, return_dict=None):
1084
- def custom_forward(*inputs):
1085
- if return_dict is not None:
1086
- return module(*inputs, return_dict=return_dict)
1087
- else:
1088
- return module(*inputs)
1089
-
1090
- return custom_forward
1091
-
1092
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1081
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1093
1082
  hidden_states = attn(
1094
1083
  hidden_states,
1095
1084
  encoder_hidden_states=encoder_hidden_states,
1096
1085
  image_only_indicator=image_only_indicator,
1097
1086
  return_dict=False,
1098
1087
  )[0]
1099
- hidden_states = torch.utils.checkpoint.checkpoint(
1100
- create_custom_forward(resnet),
1101
- hidden_states,
1102
- temb,
1103
- image_only_indicator,
1104
- **ckpt_kwargs,
1105
- )
1088
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
1106
1089
  else:
1107
1090
  hidden_states = attn(
1108
1091
  hidden_states,
@@ -1110,11 +1093,7 @@ class UNetMidBlockSpatioTemporal(nn.Module):
1110
1093
  image_only_indicator=image_only_indicator,
1111
1094
  return_dict=False,
1112
1095
  )[0]
1113
- hidden_states = resnet(
1114
- hidden_states,
1115
- temb,
1116
- image_only_indicator=image_only_indicator,
1117
- )
1096
+ hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
1118
1097
 
1119
1098
  return hidden_states
1120
1099
 
@@ -1169,34 +1148,9 @@ class DownBlockSpatioTemporal(nn.Module):
1169
1148
  output_states = ()
1170
1149
  for resnet in self.resnets:
1171
1150
  if torch.is_grad_enabled() and self.gradient_checkpointing:
1172
-
1173
- def create_custom_forward(module):
1174
- def custom_forward(*inputs):
1175
- return module(*inputs)
1176
-
1177
- return custom_forward
1178
-
1179
- if is_torch_version(">=", "1.11.0"):
1180
- hidden_states = torch.utils.checkpoint.checkpoint(
1181
- create_custom_forward(resnet),
1182
- hidden_states,
1183
- temb,
1184
- image_only_indicator,
1185
- use_reentrant=False,
1186
- )
1187
- else:
1188
- hidden_states = torch.utils.checkpoint.checkpoint(
1189
- create_custom_forward(resnet),
1190
- hidden_states,
1191
- temb,
1192
- image_only_indicator,
1193
- )
1151
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
1194
1152
  else:
1195
- hidden_states = resnet(
1196
- hidden_states,
1197
- temb,
1198
- image_only_indicator=image_only_indicator,
1199
- )
1153
+ hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
1200
1154
 
1201
1155
  output_states = output_states + (hidden_states,)
1202
1156
 
@@ -1281,25 +1235,8 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
1281
1235
 
1282
1236
  blocks = list(zip(self.resnets, self.attentions))
1283
1237
  for resnet, attn in blocks:
1284
- if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO
1285
-
1286
- def create_custom_forward(module, return_dict=None):
1287
- def custom_forward(*inputs):
1288
- if return_dict is not None:
1289
- return module(*inputs, return_dict=return_dict)
1290
- else:
1291
- return module(*inputs)
1292
-
1293
- return custom_forward
1294
-
1295
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1296
- hidden_states = torch.utils.checkpoint.checkpoint(
1297
- create_custom_forward(resnet),
1298
- hidden_states,
1299
- temb,
1300
- image_only_indicator,
1301
- **ckpt_kwargs,
1302
- )
1238
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1239
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
1303
1240
 
1304
1241
  hidden_states = attn(
1305
1242
  hidden_states,
@@ -1308,11 +1245,7 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
1308
1245
  return_dict=False,
1309
1246
  )[0]
1310
1247
  else:
1311
- hidden_states = resnet(
1312
- hidden_states,
1313
- temb,
1314
- image_only_indicator=image_only_indicator,
1315
- )
1248
+ hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
1316
1249
  hidden_states = attn(
1317
1250
  hidden_states,
1318
1251
  encoder_hidden_states=encoder_hidden_states,
@@ -1385,34 +1318,9 @@ class UpBlockSpatioTemporal(nn.Module):
1385
1318
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1386
1319
 
1387
1320
  if torch.is_grad_enabled() and self.gradient_checkpointing:
1388
-
1389
- def create_custom_forward(module):
1390
- def custom_forward(*inputs):
1391
- return module(*inputs)
1392
-
1393
- return custom_forward
1394
-
1395
- if is_torch_version(">=", "1.11.0"):
1396
- hidden_states = torch.utils.checkpoint.checkpoint(
1397
- create_custom_forward(resnet),
1398
- hidden_states,
1399
- temb,
1400
- image_only_indicator,
1401
- use_reentrant=False,
1402
- )
1403
- else:
1404
- hidden_states = torch.utils.checkpoint.checkpoint(
1405
- create_custom_forward(resnet),
1406
- hidden_states,
1407
- temb,
1408
- image_only_indicator,
1409
- )
1321
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
1410
1322
  else:
1411
- hidden_states = resnet(
1412
- hidden_states,
1413
- temb,
1414
- image_only_indicator=image_only_indicator,
1415
- )
1323
+ hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
1416
1324
 
1417
1325
  if self.upsamplers is not None:
1418
1326
  for upsampler in self.upsamplers:
@@ -1495,25 +1403,8 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
1495
1403
 
1496
1404
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1497
1405
 
1498
- if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO
1499
-
1500
- def create_custom_forward(module, return_dict=None):
1501
- def custom_forward(*inputs):
1502
- if return_dict is not None:
1503
- return module(*inputs, return_dict=return_dict)
1504
- else:
1505
- return module(*inputs)
1506
-
1507
- return custom_forward
1508
-
1509
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1510
- hidden_states = torch.utils.checkpoint.checkpoint(
1511
- create_custom_forward(resnet),
1512
- hidden_states,
1513
- temb,
1514
- image_only_indicator,
1515
- **ckpt_kwargs,
1516
- )
1406
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1407
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
1517
1408
  hidden_states = attn(
1518
1409
  hidden_states,
1519
1410
  encoder_hidden_states=encoder_hidden_states,
@@ -1521,11 +1412,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
1521
1412
  return_dict=False,
1522
1413
  )[0]
1523
1414
  else:
1524
- hidden_states = resnet(
1525
- hidden_states,
1526
- temb,
1527
- image_only_indicator=image_only_indicator,
1528
- )
1415
+ hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
1529
1416
  hidden_states = attn(
1530
1417
  hidden_states,
1531
1418
  encoder_hidden_states=encoder_hidden_states,