diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0

diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

@@ -19,7 +19,7 @@ import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...image_processor import PipelineImageInput
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...models.unets.unet_motion_model import MotionAdapter
@@ -31,7 +31,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin
@@ -40,8 +40,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import AnimateDiffPipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -178,6 +186,7 @@ class AnimateDiffVideoToVideoPipeline(
     StableDiffusionLoraLoaderMixin,
     FreeInitMixin,
     AnimateDiffFreeNoiseMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for video-to-video generation.
@@ -216,7 +225,7 @@ class AnimateDiffVideoToVideoPipeline(
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
+        unet: Union[UNet2DConditionModel, UNetMotionModel],
         motion_adapter: MotionAdapter,
         scheduler: Union[
             DDIMScheduler,
@@ -243,7 +252,7 @@ class AnimateDiffVideoToVideoPipeline(
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
 
     def encode_prompt(
@@ -1037,6 +1046,9 @@ class AnimateDiffVideoToVideoPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 10. Post-processing
         if output_type == "latent":
             video = latents
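
The recurring change in this release is visible in the hunks above: each pipeline adds a module-level guard around an optional torch_xla import and then calls xm.mark_step() once per denoising iteration so that, on XLA devices, the lazily traced graph is cut and dispatched every step. A minimal sketch of the shared pattern, using diffusers' own is_torch_xla_available() helper; denoising_step_hook is a hypothetical stand-in for the spot after progress_bar.update() where the new call sits:

from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def denoising_step_hook():
    # Hypothetical helper, not a diffusers API: represents the end of one
    # scheduler step inside a pipeline's denoising loop.
    if XLA_AVAILABLE:
        # Cuts the lazily built XLA graph and dispatches it, keeping per-step
        # work bounded instead of tracing the whole loop as one graph.
        xm.mark_step()
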

diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py

@@ -20,7 +20,7 @@ import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...image_processor import PipelineImageInput
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import (
     AutoencoderKL,
     ControlNetModel,
@@ -39,7 +39,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin
@@ -48,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import AnimateDiffPipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -196,6 +204,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
     StableDiffusionLoraLoaderMixin,
     FreeInitMixin,
     AnimateDiffFreeNoiseMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for video-to-video generation with ControlNet guidance.
@@ -238,7 +247,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
+        unet: Union[UNet2DConditionModel, UNetMotionModel],
         motion_adapter: MotionAdapter,
         controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: Union[
@@ -270,7 +279,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
         self.control_video_processor = VideoProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -1325,6 +1334,9 @@ class AnimateDiffVideoToVideoControlNetPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 11. Post-processing
         if output_type == "latent":
             video = latents
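
Both AnimateDiff video-to-video pipelines also gain FromSingleFileMixin, which contributes the from_single_file classmethod, and the unet argument is now typed to accept either a UNet2DConditionModel or an already-merged UNetMotionModel. A hedged loading sketch; the checkpoint path is hypothetical, and whether a particular single-file checkpoint maps cleanly onto this pipeline's components depends on the single-file conversion support shipped in 0.33.0:

import torch
from diffusers import AnimateDiffVideoToVideoPipeline

# Hypothetical local checkpoint; any single-file checkpoint recognized by the
# 0.33.0 single-file loaders could be passed here instead of a model repo.
pipe = AnimateDiffVideoToVideoPipeline.from_single_file(
    "./checkpoints/animatediff_v2v.safetensors",
    torch_dtype=torch.float16,
)
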

diffusers/pipelines/audioldm/pipeline_audioldm.py

@@ -22,13 +22,21 @@ from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaT
 
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -94,7 +102,7 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
             scheduler=scheduler,
             vocoder=vocoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
 
     def _encode_prompt(
         self,
@@ -530,6 +538,9 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post-processing
         mel_spectrogram = self.decode_latents(latents)
 

diffusers/pipelines/audioldm2/modeling_audioldm2.py

@@ -38,7 +38,7 @@ from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
 from ...models.transformers.transformer_2d import Transformer2DModel
 from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
 from ...models.unets.unet_2d_condition import UNet2DConditionOutput
-from ...utils import BaseOutput, is_torch_version, logging
+from ...utils import BaseOutput, logging
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -673,11 +673,6 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         sample: torch.Tensor,
@@ -768,10 +763,11 @@
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
@@ -1113,23 +1109,7 @@ class CrossAttnDownBlock2D(nn.Module):
 
         for i in range(num_layers):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.resnets[i]),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
                 for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                     if cross_attention_dim is not None and idx <= 1:
                         forward_encoder_hidden_states = encoder_hidden_states
@@ -1140,8 +1120,8 @@ class CrossAttnDownBlock2D(nn.Module):
                     else:
                         forward_encoder_hidden_states = None
                         forward_encoder_attention_mask = None
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
+                    hidden_states = self._gradient_checkpointing_func(
+                        self.attentions[i * num_attention_per_layer + idx],
                         hidden_states,
                         forward_encoder_hidden_states,
                         None,  # timestep
@@ -1149,7 +1129,6 @@ class CrossAttnDownBlock2D(nn.Module):
                         cross_attention_kwargs,
                         attention_mask,
                         forward_encoder_attention_mask,
-                        **ckpt_kwargs,
                     )[0]
             else:
                 hidden_states = self.resnets[i](hidden_states, temb)
@@ -1291,17 +1270,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
 
         for i in range(len(self.resnets[1:])):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                     if cross_attention_dim is not None and idx <= 1:
                         forward_encoder_hidden_states = encoder_hidden_states
@@ -1312,8 +1280,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                     else:
                         forward_encoder_hidden_states = None
                         forward_encoder_attention_mask = None
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
+                    hidden_states = self._gradient_checkpointing_func(
+                        self.attentions[i * num_attention_per_layer + idx],
                         hidden_states,
                         forward_encoder_hidden_states,
                         None,  # timestep
@@ -1321,14 +1289,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         cross_attention_kwargs,
                         attention_mask,
                         forward_encoder_attention_mask,
-                        **ckpt_kwargs,
                     )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.resnets[i + 1]),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb)
             else:
                 for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                     if cross_attention_dim is not None and idx <= 1:
@@ -1465,23 +1427,7 @@ class CrossAttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.resnets[i]),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
                 for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                     if cross_attention_dim is not None and idx <= 1:
                         forward_encoder_hidden_states = encoder_hidden_states
@@ -1492,8 +1438,8 @@ class CrossAttnUpBlock2D(nn.Module):
                     else:
                         forward_encoder_hidden_states = None
                         forward_encoder_attention_mask = None
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
+                    hidden_states = self._gradient_checkpointing_func(
+                        self.attentions[i * num_attention_per_layer + idx],
                         hidden_states,
                         forward_encoder_hidden_states,
                         None,  # timestep
@@ -1501,7 +1447,6 @@ class CrossAttnUpBlock2D(nn.Module):
                         cross_attention_kwargs,
                         attention_mask,
                         forward_encoder_attention_mask,
-                        **ckpt_kwargs,
                     )[0]
             else:
                 hidden_states = self.resnets[i](hidden_states, temb)
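
The large removal count in modeling_audioldm2.py comes from dropping the per-block create_custom_forward / torch.utils.checkpoint.checkpoint boilerplate (and the _set_gradient_checkpointing override) in favor of a single self._gradient_checkpointing_func call provided by the 0.33.0 model machinery. A rough sketch of what such a helper amounts to, assuming it is essentially non-reentrant activation checkpointing applied to a submodule call; the real implementation lives in diffusers' model utilities and is only approximated here:

import functools

import torch
import torch.nn as nn
import torch.utils.checkpoint


def make_gradient_checkpointing_func():
    # Assumption: equivalent to checkpoint(..., use_reentrant=False), which is
    # what the removed ckpt_kwargs branch selected on torch >= 1.11.
    return functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = nn.Linear(16, 16)
        self.gradient_checkpointing = True
        self._gradient_checkpointing_func = make_gradient_checkpointing_func()

    def forward(self, hidden_states):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Activations of self.resnet are recomputed during backward
            # instead of being stored, trading compute for memory.
            return self._gradient_checkpointing_func(self.resnet, hidden_states)
        return self.resnet(hidden_states)
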

diffusers/pipelines/audioldm2/pipeline_audioldm2.py

@@ -20,7 +20,7 @@ import torch
 from transformers import (
     ClapFeatureExtractor,
     ClapModel,
-    GPT2Model,
+    GPT2LMHeadModel,
     RobertaTokenizer,
     RobertaTokenizerFast,
     SpeechT5HifiGan,
@@ -48,8 +48,20 @@ from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditi
 if is_librosa_available():
     import librosa
 
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -184,7 +196,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
         text_encoder: ClapModel,
         text_encoder_2: Union[T5EncoderModel, VitsModel],
         projection_model: AudioLDM2ProjectionModel,
-        language_model: GPT2Model,
+        language_model: GPT2LMHeadModel,
         tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
         tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
         feature_extractor: ClapFeatureExtractor,
@@ -207,7 +219,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
             scheduler=scheduler,
             vocoder=vocoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
 
     # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
     def enable_vae_slicing(self):
@@ -225,7 +237,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
         """
         self.vae.disable_slicing()
 
-    def enable_model_cpu_offload(self, gpu_id=0):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
@@ -237,11 +249,26 @@
         else:
             raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
 
-        device = torch.device(f"cuda:{gpu_id}")
+        torch_device = torch.device(device)
+        device_index = torch_device.index
+
+        if gpu_id is not None and device_index is not None:
+            raise ValueError(
+                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
+                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
+            )
+
+        device_type = torch_device.type
+        device_str = device_type
+        if gpu_id or torch_device.index:
+            device_str = f"{device_str}:{gpu_id or torch_device.index}"
+        device = torch.device(device_str)
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+            device_mod = getattr(torch, device.type, None)
+            if hasattr(device_mod, "empty_cache") and device_mod.is_available():
+                device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
         model_sequence = [
             self.text_encoder.text_model,
@@ -292,9 +319,9 @@
             model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
 
             # forward pass to get next hidden states
-            output = self.language_model(**model_inputs, return_dict=True)
+            output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True)
 
-            next_hidden_states = output.last_hidden_state
+            next_hidden_states = output.hidden_states[-1]
 
             # Update the model input
             inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
@@ -764,7 +791,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
 
         if transcription is None:
             if self.text_encoder_2.config.model_type == "vits":
-                raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription")
+                raise ValueError("Cannot forward without transcription. Please make sure to have transcription")
         elif transcription is not None and (
             not isinstance(transcription, str) and not isinstance(transcription, list)
         ):
@@ -1033,6 +1060,9 @@ class AudioLDM2Pipeline(DiffusionPipeline):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         self.maybe_free_model_hooks()
 
         # 8. Post-processing
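
Two user-visible changes in the AudioLDM2 pipeline: the language_model component is now typed as GPT2LMHeadModel, so the generation loop requests output_hidden_states=True and reads hidden_states[-1] in place of the no-longer-available last_hidden_state, and enable_model_cpu_offload now follows the device-agnostic signature used by the base DiffusionPipeline. A hedged usage sketch; "cvssp/audioldm2" is the commonly used checkpoint id, substitute your own as needed:

import torch
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16)

# 0.33.0 accepts either a device string (optionally with an index) or an
# integer gpu_id; passing both an indexed device and a gpu_id now raises
# a ValueError in the rewritten method.
pipe.enable_model_cpu_offload(device="cuda")
# pipe.enable_model_cpu_offload(gpu_id=1)  # offloads onto "cuda:1" instead
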

diffusers/pipelines/aura_flow/pipeline_aura_flow.py

@@ -12,20 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers import T5Tokenizer, UMT5EncoderModel
 
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
 from ...models import AuraFlowTransformer2DModel, AutoencoderKL
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -124,6 +132,10 @@ class AuraFlowPipeline(DiffusionPipeline):
 
     _optional_components = []
     model_cpu_offload_seq = "text_encoder->transformer->vae"
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+    ]
 
     def __init__(
         self,
@@ -139,9 +151,7 @@ class AuraFlowPipeline(DiffusionPipeline):
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
-        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     def check_inputs(
@@ -154,10 +164,19 @@ class AuraFlowPipeline(DiffusionPipeline):
         negative_prompt_embeds=None,
         prompt_attention_mask=None,
         negative_prompt_attention_mask=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
+            )
 
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -380,6 +399,14 @@ class AuraFlowPipeline(DiffusionPipeline):
         self.vae.decoder.conv_in.to(dtype)
         self.vae.decoder.mid_block.to(dtype)
 
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -401,6 +428,10 @@ class AuraFlowPipeline(DiffusionPipeline):
         max_sequence_length: int = 256,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
         Function invoked when calling the pipeline for generation.
@@ -455,6 +486,15 @@ class AuraFlowPipeline(DiffusionPipeline):
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
 
         Examples:
@@ -476,8 +516,11 @@ class AuraFlowPipeline(DiffusionPipeline):
             negative_prompt_embeds,
             prompt_attention_mask,
             negative_prompt_attention_mask,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )
 
+        self._guidance_scale = guidance_scale
+
         # 2. Determine batch size.
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -534,6 +577,7 @@
 
         # 6. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
@@ -560,10 +604,22 @@
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if output_type == "latent":
             image = latents
         else:
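
AuraFlowPipeline picks up the step-end callback interface already used by most other diffusers pipelines: callback_on_step_end receives the pipeline, the step index, the timestep, and a dict containing the tensors named in callback_on_step_end_tensor_inputs, and the dict it returns replaces those tensors for the next step. A hedged usage sketch; "fal/AuraFlow" is the publicly hosted checkpoint and the callback body is purely illustrative:

import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")


def log_latent_norm(pipeline, step, timestep, callback_kwargs):
    # Inspect the latents at the end of each step and hand them back unchanged.
    latents = callback_kwargs["latents"]
    print(f"step {step}: latent norm {latents.norm().item():.3f}")
    return {"latents": latents}


image = pipe(
    "a watercolor painting of a fox",
    callback_on_step_end=log_latent_norm,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
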